Source code for optimagic.visualization.history_plots

import inspect
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal

import numpy as np
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten

from optimagic.config import DEFAULT_PALETTE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history import History
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
from optimagic.typing import IterationHistory, PyTree
from optimagic.visualization.backends import line_plot
from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle

BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = {
    "plotly": {
        "yanchor": "top",
        "xanchor": "right",
        "y": 0.95,
        "x": 0.95,
    },
    "matplotlib": {
        "loc": "upper right",
    },
}


ResultOrPath = OptimizeResult | str | Path


def criterion_plot(
    results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
    names: list[str] | str | None = None,
    max_evaluations: int | None = None,
    backend: Literal["plotly", "matplotlib"] = "plotly",
    template: str | None = None,
    palette: list[str] | str = DEFAULT_PALETTE,
    stack_multistart: bool = False,
    monotone: bool = False,
    show_exploration: bool = False,
) -> Any:
    """Plot the criterion history of an optimization.

    Args:
        results: An optimization result (or list of, or dict of results) with
            collected history, or path(s) to it. If dict, then the key is used as
            the name in the legend.
        names: Name(s) used in the legend. Must have the same length as results.
            If not given, the dict keys or positional indices are used.
        max_evaluations: Clip the criterion history after that many entries.
        backend: The backend to use for plotting. Default is "plotly".
        template: The template for the figure. If not specified, the default
            template of the backend is used.
        palette: The coloring palette for traces. Default is the D3 qualitative
            palette.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        monotone: If True, the criterion plot becomes monotone in the sense that
            at each iteration the current best criterion value is displayed.
            Default is False.
        show_exploration: If True, exploration samples of a multistart
            optimization are visualized. Default is False.

    Returns:
        The figure object containing the criterion plot.

    """
    # ==================================================================================
    # Process inputs

    palette_cycle = get_palette_cycle(palette)
    dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names)

    # ==================================================================================
    # Extract backend-agnostic plotting data from results

    list_of_optimize_data = _retrieve_optimization_data_from_results(
        results=dict_of_optimize_results_or_paths,
        stack_multistart=stack_multistart,
        show_exploration=show_exploration,
        plot_name="criterion_plot",
    )

    lines, multistart_lines = _extract_criterion_plot_lines(
        data=list_of_optimize_data,
        max_evaluations=max_evaluations,
        palette_cycle=palette_cycle,
        stack_multistart=stack_multistart,
        monotone=monotone,
    )

    # ==================================================================================
    # Generate the figure

    fig = line_plot(
        lines=multistart_lines + lines,
        backend=backend,
        xlabel="No. of criterion evaluations",
        ylabel="Criterion value",
        template=template,
        legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None),
    )

    return fig
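

# The sketch below is an illustrative usage example, not part of the optimagic API.
# It assumes that history collection is enabled (collect_history=True, the default)
# and uses a toy sphere function; the algorithm names are regular optimagic options.
def _example_criterion_plot_usage():  # pragma: no cover - illustration only
    # Local import to avoid a circular import at module load time.
    import optimagic as om

    res_a = om.minimize(
        fun=lambda x: x @ x, params=np.arange(3.0), algorithm="scipy_lbfgsb"
    )
    res_b = om.minimize(
        fun=lambda x: x @ x, params=np.arange(3.0), algorithm="scipy_neldermead"
    )

    # Dict keys become legend entries; monotone=True plots the running best value.
    return criterion_plot({"lbfgsb": res_a, "neldermead": res_b}, monotone=True)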


def _harmonize_inputs_to_dict(
    results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
    names: list[str] | str | None,
) -> dict[str, ResultOrPath]:
    """Convert all valid inputs for results and names to dict[str, OptimizeResult]."""
    # convert scalar case to list case
    if not isinstance(names, list) and names is not None:
        names = [names]

    if isinstance(results, (OptimizeResult, str, Path)):
        results = [results]

    if names is not None and len(names) != len(results):
        raise ValueError("len(results) needs to be equal to len(names).")

    # handle dict case
    if isinstance(results, dict):
        if names is not None:
            results_dict = dict(zip(names, list(results.values()), strict=False))
        else:
            results_dict = results
    # unlabeled iterable of results
    else:
        if names is None:
            names = [str(i) for i in range(len(results))]
        results_dict = dict(zip(names, results, strict=False))

    # convert keys to strings
    results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()}

    return results_dict


def _convert_key_to_str(key: Any) -> str:
    if inspect.isclass(key) and issubclass(key, Algorithm):
        out = str(key.name)
    elif isinstance(key, Algorithm):
        out = str(key.name)
    else:
        out = str(key)
    return out
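

# Illustration of the harmonization rules implemented above (hypothetical values):
#
#     _harmonize_inputs_to_dict(res, names=None)               -> {"0": res}
#     _harmonize_inputs_to_dict([res_a, res_b], names=None)    -> {"0": res_a, "1": res_b}
#     _harmonize_inputs_to_dict({"fast": res}, names=None)     -> {"fast": res}
#     _harmonize_inputs_to_dict({"fast": res}, names=["slow"]) -> {"slow": res}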


def params_plot(
    result: ResultOrPath,
    selector: Callable[[PyTree], PyTree] | None = None,
    max_evaluations: int | None = None,
    backend: Literal["plotly", "matplotlib"] = "plotly",
    template: str | None = None,
    palette: list[str] | str = DEFAULT_PALETTE,
    show_exploration: bool = False,
) -> Any:
    """Plot the params history of an optimization.

    Args:
        result: An optimization result with collected history, or path to it.
        selector: A callable that takes params and returns a subset of params. If
            provided, only the selected subset of params is plotted.
        max_evaluations: Clip the params history after that many criterion
            evaluations.
        backend: The backend to use for plotting. Default is "plotly".
        template: The template for the figure. If not specified, the default
            template of the backend is used.
        palette: The coloring palette for traces. Default is the D3 qualitative
            palette.
        show_exploration: If True, exploration samples of a multistart
            optimization are visualized. Default is False.

    Returns:
        The figure object containing the params plot.

    """
    # ==================================================================================
    # Process inputs

    palette_cycle = get_palette_cycle(palette)

    # ==================================================================================
    # Extract backend-agnostic plotting data from results

    optimize_data = _retrieve_optimization_data_from_single_result(
        result=result,
        stack_multistart=True,
        show_exploration=show_exploration,
        plot_name="params_plot",
    )

    lines = _extract_params_plot_lines(
        data=optimize_data,
        selector=selector,
        max_evaluations=max_evaluations,
        palette_cycle=palette_cycle,
    )

    # ==================================================================================
    # Generate the figure

    fig = line_plot(
        lines=lines,
        backend=backend,
        xlabel="No. of criterion evaluations",
        ylabel="Parameter value",
        template=template,
        legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None),
    )

    return fig
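

# Illustrative usage sketch for params_plot, not part of the optimagic API. The
# pytree params and the selector are made-up values for the example; the selector
# restricts the plot to the "a" entries of the params.
def _example_params_plot_usage():  # pragma: no cover - illustration only
    # Local import to avoid a circular import at module load time.
    import optimagic as om

    res = om.minimize(
        fun=lambda p: p["a"] @ p["a"] + (p["b"] ** 2).sum(),
        params={"a": np.arange(3.0), "b": np.ones(2)},
        algorithm="scipy_lbfgsb",
    )
    return params_plot(res, selector=lambda p: p["a"])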


@dataclass(frozen=True)
class _PlottingMultistartHistory:
    """Data container for an optimization history and metadata.

    Contains local histories in case of multistart optimization. This dataclass is
    only used internally.

    """

    history: History
    name: str | None
    start_params: PyTree
    is_multistart: bool
    local_histories: list[History] | list[IterationHistory] | None
    stacked_local_histories: History | None


def _retrieve_optimization_data_from_results(
    results: dict[str, ResultOrPath],
    stack_multistart: bool,
    show_exploration: bool,
    plot_name: str,
) -> list[_PlottingMultistartHistory]:
    # Retrieves data from multiple results by iterating over the results dictionary
    # and calling the single result retrieval function.
    data = []
    for name, res in results.items():
        _data = _retrieve_optimization_data_from_single_result(
            result=res,
            stack_multistart=stack_multistart,
            show_exploration=show_exploration,
            plot_name=plot_name,
            res_name=name,
        )
        data.append(_data)
    return data


def _retrieve_optimization_data_from_single_result(
    result: ResultOrPath,
    stack_multistart: bool,
    show_exploration: bool,
    plot_name: str,
    res_name: str | None = None,
) -> _PlottingMultistartHistory:
    """Retrieve data from a single result (OptimizeResult or database).

    Args:
        result: An optimization result with collected history, or path to it.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        show_exploration: If True, exploration samples of a multistart
            optimization are visualized. Default is False.
        plot_name: Name of the plotting function that calls this function. Used
            for raising errors.
        res_name: Name of the result.

    Returns:
        A data object containing the history, metadata, and local histories of the
        optimization result.

    """
    if isinstance(result, OptimizeResult):
        data = _retrieve_optimization_data_from_result_object(
            res=result,
            stack_multistart=stack_multistart,
            show_exploration=show_exploration,
            plot_name=plot_name,
            res_name=res_name,
        )
    elif isinstance(result, (str, Path)):
        data = _retrieve_optimization_data_from_database(
            res=result,
            stack_multistart=stack_multistart,
            show_exploration=show_exploration,
            res_name=res_name,
        )
    else:
        msg = (
            "result must be an OptimizeResult or a path to a log file, "
            f"but is type {type(result)}."
        )
        raise TypeError(msg)

    return data


def _retrieve_optimization_data_from_result_object(
    res: OptimizeResult,
    stack_multistart: bool,
    show_exploration: bool,
    plot_name: str,
    res_name: str | None = None,
) -> _PlottingMultistartHistory:
    """Retrieve optimization data from result object.

    Args:
        res: An optimization result object.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        show_exploration: If True, exploration samples of a multistart
            optimization are visualized. Default is False.
        plot_name: Name of the plotting function that calls this function. Used
            for raising errors.
        res_name: Name of the result.

    Returns:
        A data object containing the history, metadata, and local histories of the
        optimization result.

    """
    if res.history is None:
        msg = (
            f"{plot_name} requires an optimize result with history. Enable history "
            "collection by setting collect_history=True when calling maximize or "
            "minimize."
        )
        raise ValueError(msg)

    if res.multistart_info:
        local_histories = [
            opt.history
            for opt in res.multistart_info.local_optima
            if opt.history is not None
        ]

        if stack_multistart:
            stacked = _get_stacked_local_histories(local_histories, res.direction)
            if show_exploration:
                fun = res.multistart_info.exploration_results[::-1] + stacked.fun
                params = res.multistart_info.exploration_sample[::-1] + stacked.params
                stacked = History(
                    direction=stacked.direction,
                    fun=fun,
                    params=params,
                    # TODO: This needs to be fixed
                    start_time=len(fun) * [None],  # type: ignore
                    stop_time=len(fun) * [None],  # type: ignore
                    batches=len(fun) * [None],  # type: ignore
                    task=len(fun) * [None],  # type: ignore
                )
        else:
            stacked = None
    else:
        local_histories = None
        stacked = None

    data = _PlottingMultistartHistory(
        history=res.history,
        name=res_name,
        start_params=res.start_params,
        is_multistart=res.multistart_info is not None,
        local_histories=local_histories,
        stacked_local_histories=stacked,
    )
    return data


def _retrieve_optimization_data_from_database(
    res: str | Path,
    stack_multistart: bool,
    show_exploration: bool,
    res_name: str | None = None,
) -> _PlottingMultistartHistory:
    """Retrieve optimization data from a database.

    Args:
        res: A path to an optimization database.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        show_exploration: If True, exploration samples of a multistart
            optimization are visualized. Default is False.
        res_name: Name of the result.

    Returns:
        A data object containing the history, metadata, and local histories of the
        optimization result.

    """
    reader: LogReader = LogReader.from_options(SQLiteLogOptions(res))
    _problem_table = reader.problem_df
    direction = _problem_table["direction"].tolist()[-1]

    multistart_history = reader.read_multistart_history(direction)
    _history = multistart_history.history
    local_histories = multistart_history.local_histories
    exploration = multistart_history.exploration

    if stack_multistart and local_histories is not None:
        stacked = _get_stacked_local_histories(local_histories, direction, _history)
        if show_exploration:
            stacked["params"] = exploration["params"][::-1] + stacked["params"]  # type: ignore
            stacked["criterion"] = exploration["criterion"][::-1] + stacked["criterion"]  # type: ignore
    else:
        stacked = None

    history = History(
        direction=direction,
        fun=_history["fun"],
        params=_history["params"],
        start_time=_history["time"],
        # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available.
        # https://github.com/optimagic-dev/optimagic/pull/553
        stop_time=len(_history["fun"]) * [None],  # type: ignore
        task=len(_history["fun"]) * [None],  # type: ignore
        batches=list(range(len(_history["fun"]))),
    )

    data = _PlottingMultistartHistory(
        history=history,
        name=res_name,
        start_params=reader.read_start_params(),
        is_multistart=local_histories is not None,
        local_histories=local_histories,
        stacked_local_histories=stacked,
    )
    return data


def _get_stacked_local_histories(
    local_histories: list[History] | list[IterationHistory],
    direction: Any,
    history: History | IterationHistory | None = None,
) -> History:
    """Stack local histories.

    Local histories is a list of history objects with the same structure. We
    transform this into a dictionary of lists. Finally, when the data is read from
    the database we append the best history at the end.

""" stacked: dict[str, list[Any]] = {"criterion": [], "params": [], "runtime": []} for hist in local_histories: stacked["criterion"].extend(hist.fun) stacked["params"].extend(hist.params) stacked["runtime"].extend(hist.time) # append additional history is necessary if history is not None: stacked["criterion"].extend(history.fun) stacked["params"].extend(history.params) stacked["runtime"].extend(history.time) return History( direction=direction, fun=stacked["criterion"], params=stacked["params"], start_time=stacked["runtime"], # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available for the # IterationHistory. # https://github.com/optimagic-dev/optimagic/pull/553 stop_time=len(stacked["criterion"]) * [None], # type: ignore task=len(stacked["criterion"]) * [None], # type: ignore batches=list(range(len(stacked["criterion"]))), ) def _extract_criterion_plot_lines( data: list[_PlottingMultistartHistory], max_evaluations: int | None, palette_cycle: "itertools.cycle[str]", stack_multistart: bool, monotone: bool, ) -> tuple[list[LineData], list[LineData]]: """Extract lines for criterion plot from data. Args: data: Data retrieved from results or database. max_evaluations: Clip the criterion history after that many entries. palette_cycle: Cycle of colors for plotting. stack_multistart: Whether to combine multistart histories into a single history. Default is False. monotone: If True, the criterion plot becomes monotone in the sense that at each iteration the current best criterion value is displayed. Returns: Tuple containing - lines: Main optimization paths. - multistart_lines: Multistart optimization paths. """ fun_or_monotone_fun = "monotone_fun" if monotone else "fun" # Collect multistart optimization paths multistart_lines: list[LineData] = [] plot_multistart = len(data) == 1 and data[0].is_multistart and not stack_multistart if plot_multistart and data[0].local_histories: for i, local_history in enumerate(data[0].local_histories): history = getattr(local_history, fun_or_monotone_fun) if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] line_data = LineData( x=np.arange(len(history)), y=history, color="#bab0ac", name=str(i), show_in_legend=False, ) multistart_lines.append(line_data) # Collect main optimization paths lines: list[LineData] = [] for _data in data: if stack_multistart and _data.stacked_local_histories is not None: _history = _data.stacked_local_histories else: _history = _data.history history = getattr(_history, fun_or_monotone_fun) if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] line_data = LineData( x=np.arange(len(history)), y=history, color=next(palette_cycle), name="best result" if plot_multistart else _data.name, show_in_legend=not plot_multistart, ) lines.append(line_data) return lines, multistart_lines def _extract_params_plot_lines( data: _PlottingMultistartHistory, selector: Callable[[PyTree], PyTree] | None, max_evaluations: int | None, palette_cycle: "itertools.cycle[str]", ) -> list[LineData]: """Extract lines for params plot from data. Args: data: Data retrieved from results or database. selector: A callable that takes params and returns a subset of params. If provided, only the selected subset of params is plotted. max_evaluations: Clip the criterion history after that many entries. palette_cycle: Cycle of colors for plotting. Returns: lines: Parameter histories. 
""" if data.stacked_local_histories is not None: history = data.stacked_local_histories.params else: history = data.history.params start_params = data.start_params registry = get_registry(extended=True) hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T names = leaf_names(start_params, registry=registry) if selector is not None: flat, treedef = tree_flatten(start_params, registry=registry) helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry) selected = np.array(tree_just_flatten(selector(helper), registry=registry)) names = [names[i] for i in selected] hist_arr = hist_arr[selected] lines: list[LineData] = [] for name, _data in zip(names, hist_arr, strict=False): if max_evaluations is not None and len(_data) > max_evaluations: plot_data = _data[:max_evaluations] else: plot_data = _data line_data = LineData( x=np.arange(len(plot_data)), y=plot_data, color=next(palette_cycle), name=name, show_in_legend=True, ) lines.append(line_data) return lines