from dataclasses import dataclass, fields
from enum import Enum
from typing import (
Annotated,
Any,
Callable,
ItemsView,
Iterator,
KeysView,
Literal,
Protocol,
TypeVar,
ValuesView,
)
import numpy as np
from annotated_types import Ge, Gt, Le, Lt
from numpy._typing import NDArray
# Alias used where a pytree is expected. It is plain ``Any``, so it documents intent
# only and adds no static checking.
PyTree = Any
# Registry keyed by a type or a string; every entry maps a string name to a
# one-argument callable.
PyTreeRegistry = dict[type | str, dict[str, Callable[[Any], Any]]]
# Alias used where a scalar is expected. Also plain ``Any``.
Scalar = Any
# Generic type variable; BatchEvaluator.__call__ uses it to tie the element type of
# the returned list to the return type of ``func``.
T = TypeVar("T")
class AggregationLevel(Enum):
    """Aggregation level of objective functions and solvers.

    Every member's value is the lower-case string spelling of its name, so a
    member can be looked up from that string, e.g.
    ``AggregationLevel("least_squares")``.
    """

    SCALAR = "scalar"
    LEAST_SQUARES = "least_squares"
    LIKELIHOOD = "likelihood"
class Direction(str, Enum):
    """Direction of optimization.

    The ``str`` mixin makes every member an actual string, so members compare
    equal to their plain string values (``Direction.MINIMIZE == "minimize"``).
    """

    MINIMIZE = "minimize"
    MAXIMIZE = "maximize"
@dataclass(frozen=True)
class DictLikeAccess:
"""Useful base class for replacing string-based dictionaries with dataclass
instances and keeping backward compatability regarding read access to the data
structure."""
def __getitem__(self, key: str) -> Any:
if key in self.__dict__:
return getattr(self, key)
else:
raise KeyError(f"{key} not found in {self.__class__.__name__}")
def __iter__(self) -> Iterator[str]:
return iter(self._dict_repr())
def _dict_repr(self) -> dict[str, Any]:
return {field.name: getattr(self, field.name) for field in fields(self)}
[docs]
def keys(self) -> KeysView[str]:
return self._dict_repr().keys()
[docs]
def items(self) -> ItemsView[str, Any]:
return self._dict_repr().items()
[docs]
def values(self) -> ValuesView[str]:
return self._dict_repr().values()
@dataclass(frozen=True)
class TupleLikeAccess:
"""Useful base class for replacing tuples with dataclass instances and keeping
backward compatability regarding read access to the data structure."""
def __getitem__(self, index: int | slice) -> Any:
field_values = [getattr(self, field.name) for field in fields(self)]
return field_values[index]
def __len__(self) -> int:
return len(fields(self))
def __iter__(self) -> Iterator[str]:
for field in fields(self):
yield getattr(self, field.name)
class ErrorHandling(Enum):
    """Error handling strategy of the optimization algorithm.

    This is a plain ``Enum`` without a ``str`` mixin, so members do not compare
    equal to their string values; convert explicitly with
    ``ErrorHandling("raise")`` when starting from a string.
    """

    RAISE = "raise"
    RAISE_STRICT = "raise_strict"
    CONTINUE = "continue"
class EvalTask(Enum):
    """Task of the evaluation function.

    Every member's value is the lower-case string spelling of its name.
    """

    FUN = "fun"
    JAC = "jac"
    FUN_AND_JAC = "fun_and_jac"
    EXPLORATION = "exploration"
class BatchEvaluator(Protocol):
    """Structural type of a callable that evaluates ``func`` for a list of arguments.

    Any callable whose signature is compatible with ``__call__`` below satisfies
    this protocol; the protocol itself contains no implementation.
    """

    def __call__(
        self,
        func: Callable[..., T],
        arguments: list[Any],
        n_cores: int = 1,
        error_handling: (
            ErrorHandling | Literal["raise", "continue"]
        ) = ErrorHandling.CONTINUE,
        unpack_symbol: Literal["*", "**"] | None = None,
    ) -> list[T]:
        """Evaluate ``func`` for the given ``arguments``.

        Args:
            func: The callable to evaluate.
            arguments: The list of inputs for ``func``.
            n_cores: Defaults to 1. Presumably the number of parallel workers;
                confirm against the concrete implementations.
            error_handling: An ``ErrorHandling`` member or one of the strings
                "raise" and "continue". Defaults to ``ErrorHandling.CONTINUE``.
            unpack_symbol: "*", "**" or None. Presumably selects whether each
                element is passed as ``func(arg)``, ``func(*arg)`` or
                ``func(**arg)``; confirm against the concrete implementations.

        Returns:
            A list whose elements have the return type of ``func``.

        """
        ...
# Constrained aliases: the bounds are attached to the base type as annotated_types
# metadata through ``Annotated``. Nothing in this module enforces them; they are
# only available to tools that inspect the annotations.
# The string literal that follows each alias is its attribute docstring and has to
# stay directly below the assignment it describes.
PositiveInt = Annotated[int, Gt(0)]
"""Type alias for positive integers (greater than 0)."""
NonNegativeInt = Annotated[int, Ge(0)]
"""Type alias for non-negative integers (greater than or equal to 0)."""
PositiveFloat = Annotated[float, Gt(0)]
"""Type alias for positive floats (greater than 0)."""
NonNegativeFloat = Annotated[float, Ge(0)]
"""Type alias for non-negative floats (greater than or equal to 0)."""
ProbabilityFloat = Annotated[float, Ge(0), Le(1)]
"""Type alias for probability floats (between 0 and 1, inclusive)."""
NegativeFloat = Annotated[float, Lt(0)]
"""Type alias for negative floats (less than 0)."""
GtOneFloat = Annotated[float, Gt(1)]
"""Type alias for floats greater than 1."""
UnitIntervalFloat = Annotated[float, Gt(0), Le(1)]
"""Type alias for floats in (0, 1]."""
YesNoBool = Literal["yes", "no"] | bool
"""Type alias for boolean values represented as 'yes' or 'no' strings or as boolean
values."""
# The literal aliases below spell out the accepted plain-string options.
DirectionLiteral = Literal["minimize", "maximize"]
"""Type alias for optimization direction, either 'minimize' or 'maximize'."""
BatchEvaluatorLiteral = Literal["joblib", "pathos", "threading"]
"""Type alias for batch evaluator types, can be 'joblib', 'pathos', or 'threading'."""
# NOTE: covers only two of the three ErrorHandling values; "raise_strict" has no
# string literal counterpart here.
ErrorHandlingLiteral = Literal["raise", "continue"]
"""Type alias for error handling strategies, can be 'raise' or 'continue'."""
@dataclass(frozen=True)
class IterationHistory(DictLikeAccess):
    """History of iterations in a process.

    Instances are immutable and, through ``DictLikeAccess``, can also be read like
    a dictionary with the keys "params", "fun" and "time".

    Attributes:
        params: A list of parameters used in each iteration.
        fun: A list of criterion values obtained in each iteration.
        time: A list or array of runtimes associated with each iteration.

    """

    params: list[PyTree]
    fun: list[float]
    time: list[float] | NDArray[np.float64]
@dataclass(frozen=True)
class MultiStartIterationHistory(TupleLikeAccess):
    """History of multiple start iterations.

    Instances are immutable and, through ``TupleLikeAccess``, can also be indexed
    and unpacked like the tuple ``(history, local_histories, exploration)``.

    Attributes:
        history: The main iteration history, representing the best end value.
        local_histories: Optional, a list of local iteration histories.
        exploration: Optional, iteration history for exploration steps.

    """

    history: IterationHistory
    local_histories: list[IterationHistory] | None = None
    exploration: IterationHistory | None = None