import inspect
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal
import numpy as np
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten
from optimagic.config import DEFAULT_PALETTE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history import History
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
from optimagic.typing import IterationHistory, PyTree
from optimagic.visualization.backends import line_plot
from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle
BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = {
"plotly": {
"yanchor": "top",
"xanchor": "right",
"y": 0.95,
"x": 0.95,
},
"matplotlib": {
"loc": "upper right",
},
}
ResultOrPath = OptimizeResult | str | Path
def criterion_plot(
results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
names: list[str] | str | None = None,
max_evaluations: int | None = None,
backend: Literal["plotly", "matplotlib"] = "plotly",
template: str | None = None,
palette: list[str] | str = DEFAULT_PALETTE,
stack_multistart: bool = False,
monotone: bool = False,
show_exploration: bool = False,
) -> Any:
"""Plot the criterion history of an optimization.
Args:
results: An optimization result with collected history (or a list or dict of
such results), or path(s) to log files. If a dict is passed, its keys are used
as the names in the legend.
names: A name or list of names for the results, used as labels in the legend.
If given, it must have the same length as results.
max_evaluations: Clip the criterion history after that many entries.
backend: The backend to use for plotting. Default is "plotly".
template: The template for the figure. If not specified, the default template of
the backend is used.
palette: The coloring palette for traces. Default is the D3 qualitative palette.
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
monotone: If True, the criterion plot becomes monotone in the sense that at each
iteration the current best criterion value is displayed. Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
Returns:
The figure object containing the criterion plot.
"""
# ==================================================================================
# Process inputs
palette_cycle = get_palette_cycle(palette)
dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names)
# ==================================================================================
# Extract backend-agnostic plotting data from results
list_of_optimize_data = _retrieve_optimization_data_from_results(
results=dict_of_optimize_results_or_paths,
stack_multistart=stack_multistart,
show_exploration=show_exploration,
plot_name="criterion_plot",
)
lines, multistart_lines = _extract_criterion_plot_lines(
data=list_of_optimize_data,
max_evaluations=max_evaluations,
palette_cycle=palette_cycle,
stack_multistart=stack_multistart,
monotone=monotone,
)
# ==================================================================================
# Generate the figure
fig = line_plot(
lines=multistart_lines + lines,
backend=backend,
xlabel="No. of criterion evaluations",
ylabel="Criterion value",
template=template,
legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None),
)
return fig
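# A minimal usage sketch (hypothetical problem; any result with collected history
# works the same way, and a dict such as {"run a": res_a, "run b": res_b} labels
# the traces with the dict keys):
#
#     import optimagic as om
#     res = om.minimize(fun=lambda x: x @ x, params=np.arange(3.0), algorithm="scipy_lbfgsb")
#     fig = criterion_plot(res, monotone=True)
#     fig.show()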
def _harmonize_inputs_to_dict(
results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
names: list[str] | str | None,
) -> dict[str, ResultOrPath]:
"""Convert all valid inputs for results and names to dict[str, OptimizeResult]."""
# convert scalar case to list case
if not isinstance(names, list) and names is not None:
names = [names]
if isinstance(results, (OptimizeResult, str, Path)):
results = [results]
if names is not None and len(names) != len(results):
raise ValueError("len(results) needs to be equal to len(names).")
# handle dict case
if isinstance(results, dict):
if names is not None:
results_dict = dict(zip(names, list(results.values()), strict=False))
else:
results_dict = results
# unlabeled iterable of results
else:
if names is None:
names = [str(i) for i in range(len(results))]
results_dict = dict(zip(names, results, strict=False))
# convert keys to strings
results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()}
return results_dict
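# Illustration of the harmonization rules above (res, res_a and res_b stand for
# hypothetical OptimizeResult objects or log-file paths):
#
#     _harmonize_inputs_to_dict(res, names=None)             -> {"0": res}
#     _harmonize_inputs_to_dict([res_a, res_b], ["a", "b"])  -> {"a": res_a, "b": res_b}
#     _harmonize_inputs_to_dict({"x": res_a}, names=["a"])   -> {"a": res_a}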
def _convert_key_to_str(key: Any) -> str:
if inspect.isclass(key) and issubclass(key, Algorithm):
out = str(key.name)
elif isinstance(key, Algorithm):
out = str(key.name)
else:
out = str(key)
return out
def params_plot(
result: ResultOrPath,
selector: Callable[[PyTree], PyTree] | None = None,
max_evaluations: int | None = None,
backend: Literal["plotly", "matplotlib"] = "plotly",
template: str | None = None,
palette: list[str] | str = DEFAULT_PALETTE,
show_exploration: bool = False,
) -> Any:
"""Plot the params history of an optimization.
Args:
result: An optimization result with collected history, or a path to a log
file from which the history can be read.
selector: A callable that takes params and returns a subset of params.
If provided, only the selected subset of params is plotted.
max_evaluations: Clip the parameter history after that many criterion evaluations.
backend: The backend to use for plotting. Default is "plotly".
template: The template for the figure. If not specified, the default template of
the backend is used.
palette: The coloring palette for traces. Default is the D3 qualitative palette.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
Returns:
The figure object containing the params plot.
"""
# ==================================================================================
# Process inputs
palette_cycle = get_palette_cycle(palette)
# ==================================================================================
# Extract backend-agnostic plotting data from results
optimize_data = _retrieve_optimization_data_from_single_result(
result=result,
stack_multistart=True,
show_exploration=show_exploration,
plot_name="params_plot",
)
lines = _extract_params_plot_lines(
data=optimize_data,
selector=selector,
max_evaluations=max_evaluations,
palette_cycle=palette_cycle,
)
# ==================================================================================
# Generate the figure
fig = line_plot(
lines=lines,
backend=backend,
xlabel="No. of criterion evaluations",
ylabel="Parameter value",
template=template,
legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None),
)
return fig
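# A minimal usage sketch (hypothetical pytree params; the selector restricts the
# plot to the "b" entries of the params):
#
#     res = om.minimize(
#         fun=lambda p: p["a"] ** 2 + (p["b"] ** 2).sum(),
#         params={"a": 1.0, "b": np.ones(2)},
#         algorithm="scipy_lbfgsb",
#     )
#     fig = params_plot(res, selector=lambda p: p["b"], max_evaluations=100)
#     fig.show()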
@dataclass(frozen=True)
class _PlottingMultistartHistory:
"""Data container for an optimization history and metadata. Contains local histories
in case of multistart optimization.
This dataclass is only used internally.
"""
history: History
name: str | None
start_params: PyTree
is_multistart: bool
local_histories: list[History] | list[IterationHistory] | None
stacked_local_histories: History | None
def _retrieve_optimization_data_from_results(
results: dict[str, ResultOrPath],
stack_multistart: bool,
show_exploration: bool,
plot_name: str,
) -> list[_PlottingMultistartHistory]:
# Retrieves data from multiple results by iterating over the results dictionary
# and calling the single result retrieval function.
data = []
for name, res in results.items():
_data = _retrieve_optimization_data_from_single_result(
result=res,
stack_multistart=stack_multistart,
show_exploration=show_exploration,
plot_name=plot_name,
res_name=name,
)
data.append(_data)
return data
def _retrieve_optimization_data_from_single_result(
result: ResultOrPath,
stack_multistart: bool,
show_exploration: bool,
plot_name: str,
res_name: str | None = None,
) -> _PlottingMultistartHistory:
"""Retrieve data from a single result (OptimizeResult or database).
Args:
result: An optimization result with collected history, or path to it.
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
plot_name: Name of the plotting function that calls this function. Used for
raising errors.
res_name: Name of the result.
Returns:
A data object containing the history, metadata, and local histories of the
optimization result.
"""
if isinstance(result, OptimizeResult):
data = _retrieve_optimization_data_from_result_object(
res=result,
stack_multistart=stack_multistart,
show_exploration=show_exploration,
plot_name=plot_name,
res_name=res_name,
)
elif isinstance(result, (str, Path)):
data = _retrieve_optimization_data_from_database(
res=result,
stack_multistart=stack_multistart,
show_exploration=show_exploration,
res_name=res_name,
)
else:
msg = (
"result must be an OptimizeResult or a path to a log file, "
f"but is type {type(result)}."
)
raise TypeError(msg)
return data
def _retrieve_optimization_data_from_result_object(
res: OptimizeResult,
stack_multistart: bool,
show_exploration: bool,
plot_name: str,
res_name: str | None = None,
) -> _PlottingMultistartHistory:
"""Retrieve optimization data from result object.
Args:
res: An optimization result object.
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
plot_name: Name of the plotting function that calls this function. Used for
raising errors.
res_name: Name of the result.
Returns:
A data object containing the history, metadata, and local histories of the
optimization result.
"""
if res.history is None:
msg = f"{plot_name} requires an optimize result with history. Enable history "
"collection by setting collect_history=True when calling maximize or minimize."
raise ValueError(msg)
if res.multistart_info:
local_histories = [
opt.history
for opt in res.multistart_info.local_optima
if opt.history is not None
]
if stack_multistart:
stacked = _get_stacked_local_histories(local_histories, res.direction)
if show_exploration:
fun = res.multistart_info.exploration_results[::-1] + stacked.fun
params = res.multistart_info.exploration_sample[::-1] + stacked.params
stacked = History(
direction=stacked.direction,
fun=fun,
params=params,
# TODO: This needs to be fixed
start_time=len(fun) * [None], # type: ignore
stop_time=len(fun) * [None], # type: ignore
batches=len(fun) * [None], # type: ignore
task=len(fun) * [None], # type: ignore
)
else:
stacked = None
else:
local_histories = None
stacked = None
data = _PlottingMultistartHistory(
history=res.history,
name=res_name,
start_params=res.start_params,
is_multistart=res.multistart_info is not None,
local_histories=local_histories,
stacked_local_histories=stacked,
)
return data
def _retrieve_optimization_data_from_database(
res: str | Path,
stack_multistart: bool,
show_exploration: bool,
res_name: str | None = None,
) -> _PlottingMultistartHistory:
"""Retrieve optimization data from a database.
Args:
res: A path to an optimization database.
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
res_name: Name of the result.
Returns:
A data object containing the history, metadata, and local histories of the
optimization result.
"""
reader: LogReader = LogReader.from_options(SQLiteLogOptions(res))
_problem_table = reader.problem_df
direction = _problem_table["direction"].tolist()[-1]
multistart_history = reader.read_multistart_history(direction)
_history = multistart_history.history
local_histories = multistart_history.local_histories
exploration = multistart_history.exploration
if stack_multistart and local_histories is not None:
stacked = _get_stacked_local_histories(local_histories, direction, _history)
if show_exploration:
stacked["params"] = exploration["params"][::-1] + stacked["params"] # type: ignore
stacked["criterion"] = exploration["criterion"][::-1] + stacked["criterion"] # type: ignore
else:
stacked = None
history = History(
direction=direction,
fun=_history["fun"],
params=_history["params"],
start_time=_history["time"],
# TODO (@janosg): Retrieve `stop_time` from `hist` once it is available.
# https://github.com/optimagic-dev/optimagic/pull/553
stop_time=len(_history["fun"]) * [None], # type: ignore
task=len(_history["fun"]) * [None], # type: ignore
batches=list(range(len(_history["fun"]))),
)
data = _PlottingMultistartHistory(
history=history,
name=res_name,
start_params=reader.read_start_params(),
is_multistart=local_histories is not None,
local_histories=local_histories,
stacked_local_histories=stacked,
)
return data
def _get_stacked_local_histories(
local_histories: list[History] | list[IterationHistory],
direction: Any,
history: History | IterationHistory | None = None,
) -> History:
"""Stack local histories.
Local histories is a list of history objects that share the same structure. Their
entries are concatenated into a single stacked history. When the data is read from
the database, the main history is appended at the end.
"""
stacked: dict[str, list[Any]] = {"criterion": [], "params": [], "runtime": []}
for hist in local_histories:
stacked["criterion"].extend(hist.fun)
stacked["params"].extend(hist.params)
stacked["runtime"].extend(hist.time)
# append the additional history if necessary
if history is not None:
stacked["criterion"].extend(history.fun)
stacked["params"].extend(history.params)
stacked["runtime"].extend(history.time)
return History(
direction=direction,
fun=stacked["criterion"],
params=stacked["params"],
start_time=stacked["runtime"],
# TODO (@janosg): Retrieve `stop_time` from `hist` once it is available for the
# IterationHistory.
# https://github.com/optimagic-dev/optimagic/pull/553
stop_time=len(stacked["criterion"]) * [None], # type: ignore
task=len(stacked["criterion"]) * [None], # type: ignore
batches=list(range(len(stacked["criterion"]))),
)
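# Worked example of the stacking (hypothetical values): two local histories with
# fun values [3, 2] and [5, 1] are concatenated into a single history with
# fun == [3, 2, 5, 1]; a main history read from the database, if passed, is
# appended at the very end.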
def _extract_criterion_plot_lines(
data: list[_PlottingMultistartHistory],
max_evaluations: int | None,
palette_cycle: "itertools.cycle[str]",
stack_multistart: bool,
monotone: bool,
) -> tuple[list[LineData], list[LineData]]:
"""Extract lines for criterion plot from data.
Args:
data: Data retrieved from results or database.
max_evaluations: Clip the criterion history after that many entries.
palette_cycle: Cycle of colors for plotting.
stack_multistart: Whether to combine multistart histories into a single
history. Default is False.
monotone: If True, the criterion plot becomes monotone in the sense that at each
iteration the current best criterion value is displayed.
Returns:
Tuple containing
- lines: Main optimization paths.
- multistart_lines: Multistart optimization paths.
"""
fun_or_monotone_fun = "monotone_fun" if monotone else "fun"
# Collect multistart optimization paths
multistart_lines: list[LineData] = []
plot_multistart = len(data) == 1 and data[0].is_multistart and not stack_multistart
if plot_multistart and data[0].local_histories:
for i, local_history in enumerate(data[0].local_histories):
history = getattr(local_history, fun_or_monotone_fun)
if max_evaluations is not None and len(history) > max_evaluations:
history = history[:max_evaluations]
line_data = LineData(
x=np.arange(len(history)),
y=history,
color="#bab0ac",
name=str(i),
show_in_legend=False,
)
multistart_lines.append(line_data)
# Collect main optimization paths
lines: list[LineData] = []
for _data in data:
if stack_multistart and _data.stacked_local_histories is not None:
_history = _data.stacked_local_histories
else:
_history = _data.history
history = getattr(_history, fun_or_monotone_fun)
if max_evaluations is not None and len(history) > max_evaluations:
history = history[:max_evaluations]
line_data = LineData(
x=np.arange(len(history)),
y=history,
color=next(palette_cycle),
name="best result" if plot_multistart else _data.name,
show_in_legend=not plot_multistart,
)
lines.append(line_data)
return lines, multistart_lines
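# Example of the resulting layout (hypothetical single multistart result with three
# local optima and stack_multistart=False): the three local histories become grey
# background lines named "0", "1" and "2" that are hidden from the legend, and the
# main history is drawn on top as a colored line named "best result".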
def _extract_params_plot_lines(
data: _PlottingMultistartHistory,
selector: Callable[[PyTree], PyTree] | None,
max_evaluations: int | None,
palette_cycle: "itertools.cycle[str]",
) -> list[LineData]:
"""Extract lines for params plot from data.
Args:
data: Data retrieved from results or database.
selector: A callable that takes params and returns a subset of params.
If provided, only the selected subset of params is plotted.
max_evaluations: Clip the criterion history after that many entries.
palette_cycle: Cycle of colors for plotting.
Returns:
lines: Parameter histories.
"""
if data.stacked_local_histories is not None:
history = data.stacked_local_histories.params
else:
history = data.history.params
start_params = data.start_params
registry = get_registry(extended=True)
hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T
names = leaf_names(start_params, registry=registry)
if selector is not None:
flat, treedef = tree_flatten(start_params, registry=registry)
helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry)
selected = np.array(tree_just_flatten(selector(helper), registry=registry))
names = [names[i] for i in selected]
hist_arr = hist_arr[selected]
lines: list[LineData] = []
for name, _data in zip(names, hist_arr, strict=False):
if max_evaluations is not None and len(_data) > max_evaluations:
plot_data = _data[:max_evaluations]
else:
plot_data = _data
line_data = LineData(
x=np.arange(len(plot_data)),
y=plot_data,
color=next(palette_cycle),
name=name,
show_in_legend=True,
)
lines.append(line_data)
return lines
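# How the selector is translated into flat positions above (hypothetical params):
# for start_params = {"a": 1.0, "b": np.array([2.0, 3.0])} the helper tree is
# {"a": 0, "b": array([1, 2])}; selector=lambda p: p["b"] therefore flattens to the
# positions [1, 2], which are used to subset both `names` and the rows of `hist_arr`.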