import inspect
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import plotly.graph_objects as go
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten
from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history import History
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
from optimagic.typing import IterationHistory, PyTree
OptimizeResultOrPath = OptimizeResult | str | Path
def criterion_plot(
results: OptimizeResultOrPath
| list[OptimizeResultOrPath]
| dict[str, OptimizeResultOrPath],
names: list[str] | str | None = None,
max_evaluations: int | None = None,
template: str = PLOTLY_TEMPLATE,
palette: list[str] | str = PLOTLY_PALETTE,
stack_multistart: bool = False,
monotone: bool = False,
show_exploration: bool = False,
) -> go.Figure:
"""Plot the criterion history of an optimization.
Args:
results: A (list or dict of) optimization results with collected history.
If dict, then the key is used as the name in a legend.
        names: Names corresponding to the results (one name per result). Used as
            labels in the legend.
max_evaluations: Clip the criterion history after that many entries.
template: The template for the figure. Default is "plotly_white".
palette: The coloring palette for traces. Default is "qualitative.Set2".
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
monotone: If True, the criterion plot becomes monotone in the sense that at each
iteration the current best criterion value is displayed. Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
Returns:
The figure object containing the criterion plot.
"""
# ==================================================================================
# Process inputs
if not isinstance(palette, list):
palette = [palette]
palette_cycle = itertools.cycle(palette)
dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names)
# ==================================================================================
# Extract backend-agnostic plotting data from results
list_of_optimize_data = _retrieve_optimization_data(
results=dict_of_optimize_results_or_paths,
stack_multistart=stack_multistart,
show_exploration=show_exploration,
)
lines, multistart_lines = _extract_criterion_plot_lines(
data=list_of_optimize_data,
max_evaluations=max_evaluations,
palette_cycle=palette_cycle,
stack_multistart=stack_multistart,
monotone=monotone,
)
# ==================================================================================
# Generate the plotly figure
plot_config = PlotConfig(
template=template,
legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
)
fig = _plotly_line_plot(lines + multistart_lines, plot_config)
return fig
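# Example usage (a minimal sketch; assumes optimagic is importable as ``om`` and
# uses a toy quadratic criterion):
#
#     import numpy as np
#     import optimagic as om
#
#     res = om.minimize(
#         fun=lambda x: x @ x,
#         params=np.arange(5.0),
#         algorithm="scipy_lbfgsb",
#     )
#     fig = criterion_plot(res, monotone=True)
#     fig.show()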
def _harmonize_inputs_to_dict(
results: OptimizeResultOrPath
| list[OptimizeResultOrPath]
| dict[str, OptimizeResultOrPath],
names: list[str] | str | None,
) -> dict[str, OptimizeResult | str | Path]:
"""Convert all valid inputs for results and names to dict[str, OptimizeResult]."""
# convert scalar case to list case
if not isinstance(names, list) and names is not None:
names = [names]
if isinstance(results, (OptimizeResult, str, Path)):
results = [results]
if names is not None and len(names) != len(results):
raise ValueError("len(results) needs to be equal to len(names).")
# handle dict case
if isinstance(results, dict):
if names is not None:
results_dict = dict(zip(names, list(results.values()), strict=False))
else:
results_dict = results
# unlabeled iterable of results
else:
if names is None:
names = [str(i) for i in range(len(results))]
results_dict = dict(zip(names, results, strict=False))
# convert keys to strings
results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()}
return results_dict
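# A sketch of the harmonization behavior (``res_a`` and ``res_b`` are hypothetical
# OptimizeResult instances):
#
#     _harmonize_inputs_to_dict(res_a, names=None)            # {"0": res_a}
#     _harmonize_inputs_to_dict([res_a, res_b], ["a", "b"])   # {"a": res_a, "b": res_b}
#     _harmonize_inputs_to_dict({"x": res_a}, names=["A"])    # {"A": res_a}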
def _convert_key_to_str(key: Any) -> str:
if inspect.isclass(key) and issubclass(key, Algorithm):
out = str(key.name)
elif isinstance(key, Algorithm):
out = str(key.name)
else:
out = str(key)
return out
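# For example, an Algorithm class or instance used as a dict key is labeled by its
# ``name`` attribute, so a hypothetical ``{ScipyLBFGSB: res}`` entry would appear
# as "scipy_lbfgsb" in the legend rather than as the raw class repr.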
def params_plot(
result,
selector=None,
max_evaluations=None,
template=PLOTLY_TEMPLATE,
show_exploration=False,
):
"""Plot the params history of an optimization.
Args:
        result (Union[OptimizeResult, pathlib.Path, str]): An optimization result
            with collected history, or a path to an optimization log file.
selector (callable): A callable that takes params and returns a subset
of params. If provided, only the selected subset of params is plotted.
        max_evaluations (int): Clip the parameter history after that many entries.
template (str): The template for the figure. Default is "plotly_white".
show_exploration (bool): If True, exploration samples of a multistart
optimization are visualized. Default is False.
Returns:
plotly.graph_objs._figure.Figure: The figure.
"""
# ==================================================================================
# Process inputs
# ==================================================================================
if isinstance(result, OptimizeResult):
data = _retrieve_optimization_data_from_results_object(
result,
stack_multistart=True,
show_exploration=show_exploration,
plot_name="params_plot",
)
start_params = result.start_params
elif isinstance(result, (str, Path)):
data = _retrieve_optimization_data_from_database(
result,
stack_multistart=True,
show_exploration=show_exploration,
)
start_params = data.start_params
else:
raise TypeError("result must be an OptimizeResult or a path to a log file.")
if data.stacked_local_histories is not None:
history = data.stacked_local_histories.params
else:
history = data.history.params
# ==================================================================================
# Create figure
# ==================================================================================
fig = go.Figure()
registry = get_registry(extended=True)
hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T
names = leaf_names(start_params, registry=registry)
if selector is not None:
flat, treedef = tree_flatten(start_params, registry=registry)
helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry)
selected = np.array(tree_just_flatten(selector(helper), registry=registry))
names = [names[i] for i in selected]
hist_arr = hist_arr[selected]
    for name, values in zip(names, hist_arr, strict=False):
        if max_evaluations is not None and len(values) > max_evaluations:
            plot_data = values[:max_evaluations]
        else:
            plot_data = values
trace = go.Scatter(
x=np.arange(len(plot_data)),
y=plot_data,
mode="lines",
name=name,
)
fig.add_trace(trace)
fig.update_layout(
template=template,
xaxis_title_text="No. of criterion evaluations",
yaxis_title_text="Parameter value",
legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
)
return fig
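# Example usage (a minimal sketch; assumes dict-valued params and the same toy
# setup as the criterion_plot example above):
#
#     res = om.minimize(
#         fun=lambda p: p["a"] ** 2 + (p["b"] ** 2).sum(),
#         params={"a": 1.0, "b": np.arange(3.0)},
#         algorithm="scipy_lbfgsb",
#     )
#     fig = params_plot(res, selector=lambda p: p["b"])
#     fig.show()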
@dataclass(frozen=True)
class _PlottingMultistartHistory:
"""Data container for an optimization history and metadata. Contains local histories
in case of multistart optimization.
This dataclass is only used internally.
"""
history: History
name: str | None
start_params: PyTree
is_multistart: bool
local_histories: list[History] | list[IterationHistory] | None
stacked_local_histories: History | None
def _retrieve_optimization_data(
results: dict[str, OptimizeResult | str | Path],
stack_multistart: bool,
show_exploration: bool,
) -> list[_PlottingMultistartHistory]:
"""Retrieve data for criterion plot from results (OptimizeResult or database).
Args:
results: A dict of optimization results with collected history.
The key is used as the name in a legend.
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
Returns:
A list of objects containing the history, metadata, and local histories of each
optimization result.
"""
data = []
for name, res in results.items():
if isinstance(res, OptimizeResult):
_data = _retrieve_optimization_data_from_results_object(
res=res,
stack_multistart=stack_multistart,
show_exploration=show_exploration,
plot_name="criterion_plot",
res_name=name,
)
elif isinstance(res, (str, Path)):
_data = _retrieve_optimization_data_from_database(
res=res,
stack_multistart=stack_multistart,
show_exploration=show_exploration,
res_name=name,
)
else:
msg = (
"results must be (or contain) an OptimizeResult or a path to a log "
f"file, but is type {type(res)}."
)
raise TypeError(msg)
data.append(_data)
return data
def _retrieve_optimization_data_from_results_object(
res: OptimizeResult,
stack_multistart: bool,
show_exploration: bool,
plot_name: str,
res_name: str | None = None,
) -> _PlottingMultistartHistory:
"""Retrieve optimization data from results object.
Args:
res: An optimization results object.
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
plot_name: Name of the plotting function that calls this function. Used for
raising errors.
res_name: Name of the result.
Returns:
A data object containing the history, metadata, and local histories of the
optimization result.
"""
if res.history is None:
msg = f"{plot_name} requires an optimize result with history. Enable history "
"collection by setting collect_history=True when calling maximize or minimize."
raise ValueError(msg)
if res.multistart_info:
local_histories = [
opt.history
for opt in res.multistart_info.local_optima
if opt.history is not None
]
if stack_multistart:
stacked = _get_stacked_local_histories(local_histories, res.direction)
if show_exploration:
fun = res.multistart_info.exploration_results[::-1] + stacked.fun
params = res.multistart_info.exploration_sample[::-1] + stacked.params
stacked = History(
direction=stacked.direction,
fun=fun,
params=params,
# TODO: This needs to be fixed
start_time=len(fun) * [None], # type: ignore
stop_time=len(fun) * [None], # type: ignore
batches=len(fun) * [None], # type: ignore
task=len(fun) * [None], # type: ignore
)
else:
stacked = None
else:
local_histories = None
stacked = None
data = _PlottingMultistartHistory(
history=res.history,
name=res_name,
start_params=res.start_params,
is_multistart=res.multistart_info is not None,
local_histories=local_histories,
stacked_local_histories=stacked,
)
return data
def _retrieve_optimization_data_from_database(
res: str | Path,
stack_multistart: bool,
show_exploration: bool,
res_name: str | None = None,
) -> _PlottingMultistartHistory:
"""Retrieve optimization data from a database.
Args:
res: A path to an optimization database.
stack_multistart: Whether to combine multistart histories into a single history.
Default is False.
show_exploration: If True, exploration samples of a multistart optimization are
visualized. Default is False.
res_name: Name of the result.
Returns:
A data object containing the history, metadata, and local histories of the
optimization result.
"""
reader: LogReader = LogReader.from_options(SQLiteLogOptions(res))
_problem_table = reader.problem_df
direction = _problem_table["direction"].tolist()[-1]
multistart_history = reader.read_multistart_history(direction)
_history = multistart_history.history
local_histories = multistart_history.local_histories
exploration = multistart_history.exploration
if stack_multistart and local_histories is not None:
stacked = _get_stacked_local_histories(local_histories, direction, _history)
if show_exploration:
stacked["params"] = exploration["params"][::-1] + stacked["params"] # type: ignore
stacked["criterion"] = exploration["criterion"][::-1] + stacked["criterion"] # type: ignore
else:
stacked = None
history = History(
direction=direction,
fun=_history["fun"],
params=_history["params"],
start_time=_history["time"],
# TODO (@janosg): Retrieve `stop_time` from `hist` once it is available.
# https://github.com/optimagic-dev/optimagic/pull/553
stop_time=len(_history["fun"]) * [None], # type: ignore
task=len(_history["fun"]) * [None], # type: ignore
batches=list(range(len(_history["fun"]))),
)
data = _PlottingMultistartHistory(
history=history,
name=res_name,
start_params=reader.read_start_params(),
is_multistart=local_histories is not None,
local_histories=local_histories,
stacked_local_histories=stacked,
)
return data
def _get_stacked_local_histories(
local_histories: list[History] | list[IterationHistory],
direction: Any,
history: History | IterationHistory | None = None,
) -> History:
"""Stack local histories.
    Local histories is a list of history objects that share a common structure. Their
    fun, params, and time entries are concatenated into a single stacked history.
    When the data is read from the database, the best overall history is appended at
    the end.
"""
stacked: dict[str, list[Any]] = {"criterion": [], "params": [], "runtime": []}
for hist in local_histories:
stacked["criterion"].extend(hist.fun)
stacked["params"].extend(hist.params)
stacked["runtime"].extend(hist.time)
    # append the additional history if necessary
if history is not None:
stacked["criterion"].extend(history.fun)
stacked["params"].extend(history.params)
stacked["runtime"].extend(history.time)
return History(
direction=direction,
fun=stacked["criterion"],
params=stacked["params"],
start_time=stacked["runtime"],
# TODO (@janosg): Retrieve `stop_time` from `hist` once it is available for the
# IterationHistory.
# https://github.com/optimagic-dev/optimagic/pull/553
stop_time=len(stacked["criterion"]) * [None], # type: ignore
task=len(stacked["criterion"]) * [None], # type: ignore
batches=list(range(len(stacked["criterion"]))),
)
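# A sketch of the stacking behavior (``h1`` and ``h2`` are hypothetical local
# histories):
#
#     stacked = _get_stacked_local_histories([h1, h2], direction="minimize")
#     # stacked.fun is list(h1.fun) + list(h2.fun); batches are renumbered
#     # consecutively and the unavailable stop_time/task entries are None.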
@dataclass(frozen=True)
class LineData:
"""Data of a single line.
Attributes:
x: The x-coordinates of the points.
y: The y-coordinates of the points.
color: The color of the line. Default is None.
name: The name of the line. Default is None.
show_in_legend: Whether to show the line in the legend. Default is True.
"""
x: np.ndarray
y: np.ndarray
color: str | None = None
name: str | None = None
show_in_legend: bool = True
def _extract_criterion_plot_lines(
data: list[_PlottingMultistartHistory],
max_evaluations: int | None,
palette_cycle: "itertools.cycle[str]",
stack_multistart: bool,
monotone: bool,
) -> tuple[list[LineData], list[LineData]]:
"""Extract lines for criterion plot from data.
Args:
data: Data retrieved from results or database.
max_evaluations: Clip the criterion history after that many entries.
palette_cycle: Cycle of colors for plotting.
stack_multistart: Whether to combine multistart histories into a single
history. Default is False.
monotone: If True, the criterion plot becomes monotone in the sense that at each
iteration the current best criterion value is displayed.
Returns:
Tuple containing
- lines: Main optimization paths.
- multistart_lines: Multistart optimization paths.
"""
fun_or_monotone_fun = "monotone_fun" if monotone else "fun"
# Collect multistart optimization paths
multistart_lines: list[LineData] = []
plot_multistart = len(data) == 1 and data[0].is_multistart and not stack_multistart
if plot_multistart and data[0].local_histories:
for i, local_history in enumerate(data[0].local_histories):
history = getattr(local_history, fun_or_monotone_fun)
if max_evaluations is not None and len(history) > max_evaluations:
history = history[:max_evaluations]
line_data = LineData(
x=np.arange(len(history)),
y=history,
color="#bab0ac",
name=str(i),
show_in_legend=False,
)
multistart_lines.append(line_data)
# Collect main optimization paths
lines: list[LineData] = []
for _data in data:
if stack_multistart and _data.stacked_local_histories is not None:
_history = _data.stacked_local_histories
else:
_history = _data.history
history = getattr(_history, fun_or_monotone_fun)
if max_evaluations is not None and len(history) > max_evaluations:
history = history[:max_evaluations]
_color = next(palette_cycle)
if not isinstance(_color, str):
msg = "highlight_palette needs to be a string or list of strings, but its "
f"entry is of type {type(_color)}."
raise TypeError(msg)
line_data = LineData(
x=np.arange(len(history)),
y=history,
color=_color,
name="best result" if plot_multistart else _data.name,
show_in_legend=not plot_multistart,
)
lines.append(line_data)
return lines, multistart_lines
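# Note on ``monotone``: ``History.monotone_fun`` is expected to hold the running
# best criterion value, which is what makes the selected lines monotone. A sketch
# of the equivalent numpy computation for a minimization:
#
#     monotone_fun = np.minimum.accumulate(fun)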
@dataclass(frozen=True)
class PlotConfig:
"""Configuration settings for figure.
Attributes:
template: The template for the figure.
legend: Configuration for the legend.
"""
template: str
legend: dict[str, Any]
def _plotly_line_plot(lines: list[LineData], plot_config: PlotConfig) -> go.Figure:
"""Create a plotly line plot from the given lines and plot configuration.
Args:
lines: Data for lines to be plotted.
plot_config: Configuration for the plot.
Returns:
The figure object containing the lines.
"""
fig = go.Figure()
for line in lines:
trace = go.Scatter(
x=line.x,
y=line.y,
name=line.name,
mode="lines",
line_color=line.color,
showlegend=line.show_in_legend,
connectgaps=True,
)
fig.add_trace(trace)
fig.update_layout(
template=plot_config.template,
xaxis_title_text="No. of criterion evaluations",
yaxis_title_text="Criterion value",
legend=plot_config.legend,
)
return fig
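# A minimal sketch of how the low-level plotting helpers compose:
#
#     lines = [LineData(x=np.arange(3), y=np.array([3.0, 1.5, 1.0]), name="demo")]
#     config = PlotConfig(
#         template=PLOTLY_TEMPLATE,
#         legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95},
#     )
#     _plotly_line_plot(lines, config).show()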