Source code for ax.api.client

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
from collections.abc import Iterable, Sequence
from logging import Logger
from typing import Any, cast, Literal, Self

import numpy as np
import pandas as pd
from ax.analysis.analysis import Analysis, display_cards
from ax.analysis.overview import OverviewAnalysis
from ax.analysis.summary import Summary
from ax.api.configs import (
    ChoiceParameterConfig,
    DerivedParameterConfig,
    RangeParameterConfig,
    StorageConfig,
)
from ax.api.protocols.metric import IMetric
from ax.api.protocols.runner import IRunner
from ax.api.types import TOutcome, TParameterization
from ax.api.utils.generation_strategy_dispatch import choose_generation_strategy
from ax.api.utils.instantiation.from_string import optimization_config_from_string
from ax.api.utils.instantiation.from_struct import experiment_from_struct
from ax.api.utils.storage import db_settings_from_storage_config
from ax.api.utils.structs import ExperimentStruct, GenerationStrategyDispatchStruct
from ax.core.analysis_card import AnalysisCardBase
from ax.core.arm import Arm
from ax.core.experiment import Experiment
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus  # Used as a return type
from ax.core.types import TParameterization as CoreTParameterization
from ax.early_stopping.strategies import (
    BaseEarlyStoppingStrategy,
    PercentileEarlyStoppingStrategy,
)
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.storage.json_store.decoder import (
    generation_strategy_from_json,
    object_from_json,
)
from ax.storage.json_store.encoder import object_to_json
from ax.storage.json_store.registry import (
    CORE_CLASS_DECODER_REGISTRY,
    CORE_CLASS_ENCODER_REGISTRY,
    CORE_DECODER_REGISTRY,
    CORE_ENCODER_REGISTRY,
)
from ax.storage.sqa_store.with_db_settings_base import WithDBSettingsBase
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.random import with_rng_seed
from pyre_extensions import assert_is_instance, none_throws

logger: Logger = get_logger(__name__)
ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES: int = 6


[docs] class Client(WithDBSettingsBase): _maybe_experiment: Experiment | None = None _maybe_generation_strategy: GenerationStrategy | None = None _maybe_early_stopping_strategy: BaseEarlyStoppingStrategy | None = None def __init__( self, storage_config: StorageConfig | None = None, random_seed: int | None = None, ) -> None: """ Initialize a ``Client``, which manages state across the lifecycle of an experiment. Args: storage_config: Configuration for saving to and loading from a database. If elided the experiment will not automatically be saved to a database. random_seed: An optional integer to set the random seed for reproducibility of the experiment's results. If not provided, the random seed will not be set, leading to potentially different results on different runs. """ super().__init__( # Initialize WithDBSettingsBase db_settings=db_settings_from_storage_config(storage_config=storage_config) if storage_config is not None else None, ) self._storage_config = storage_config self._random_seed = random_seed # -------------------- Section 1: Configure -------------------------------------
[docs] def configure_experiment( self, parameters: Sequence[ RangeParameterConfig | ChoiceParameterConfig | DerivedParameterConfig ], parameter_constraints: Sequence[str] | None = None, name: str | None = None, description: str | None = None, experiment_type: str | None = None, owner: str | None = None, ) -> None: """ Given an ``ExperimentConfig``, construct the Ax ``Experiment`` object. Note that validation occurs at time of config instantiation, not at ``configure_experiment``. This method only constitutes defining the search space and misc. metadata like name, description, and owners. Saves to database on completion if ``storage_config`` is present. """ if self._maybe_experiment is not None: raise UnsupportedError( "Experiment already configured. Please create a new Client if you " "would like a new experiment." ) experiment_struct = ExperimentStruct( parameters=[*parameters], parameter_constraints=[*parameter_constraints] if parameter_constraints else [], name=name, description=description, experiment_type=experiment_type, owner=owner, ) self._maybe_experiment = experiment_from_struct(struct=experiment_struct) self._save_experiment_to_db_if_possible(experiment=self._experiment)
[docs] def configure_optimization( self, objective: str, outcome_constraints: Sequence[str] | None = None, pruning_target_parameterization: TParameterization | None = None, ) -> None: """ Configures the goals of the optimization by setting the ``OptimizationConfig``. ``Metrics`` referenced here by their name will be moved from the Experiment's ``tracking_metrics`` if they were were already present (i.e. they were attached via ``configure_metrics``) or added as base ``Metrics``. Args: objective: Objective is a string and allows us to express single, scalarized, and multi-objective goals. Ex: "loss", "ne1 + 2 * ne2", "-ne, qps" outcome_constraints: Outcome constraints are also strings and allow us to express a desire to have a metric clear a threshold but not be further optimized. These constraints are expressed as inequalities. Ex: "qps >= 100", "0.5 * ne1 + 0.5 * ne2 >= 0.95". To indicate a relative constraint multiply your bound by "baseline" Ex: "qps >= 0.95 * baseline" will constrain such that the QPS is at least 95% of the baseline arm's QPS. Note that scalarized outcome constraints cannot be relative. pruning_target_parameterization: Parameterization containing the target values for irrelevant parameters. The target values are used to prune irrelevant parameters from candidates generated via Bayesian optimization: when Ax considers arms to suggest for the next trial, it will seek to have the proposed arms differ from the target arm as little as possible, without a loss in optimization performance. If not specified and a status_quo is set on the experiment, it will be used as the pruning target. Must be a valid member of the search space. Saves to database on completion if ``storage_config`` is present. """ # Validate and convert the pruning_target_parameterization to an Arm if # provided pruning_target_arm: Arm | None = None if pruning_target_parameterization is not None: self._experiment.search_space.validate_membership( parameters=cast(CoreTParameterization, pruning_target_parameterization) ) pruning_target_arm = Arm( parameters=pruning_target_parameterization, name="pruning_target" ) old_metrics = self._experiment.metrics optimization_config = optimization_config_from_string( objective_str=objective, outcome_constraint_strs=outcome_constraints, ) # Set the pruning_target_parameterization on the optimization config if # provided if pruning_target_arm is not None: optimization_config.pruning_target_parameterization = pruning_target_arm # Register any new metrics from the optimization config on the experiment # before setting the optimization config, which validates that all # referenced metrics exist on the experiment. # Determine lower_is_better for objective metrics from weights. obj = optimization_config.objective objective_lower_is_better: dict[str, bool] = {} # metric_weights returns (signature, weight) tuples; use metric_names # to key by name since downstream lookups use metric names. obj_names = obj.metric_names obj_weights = [w for _, w in obj.metric_weights] for name, weight in zip(obj_names, obj_weights): objective_lower_is_better[name] = weight < 0 # Register objective metrics first (preserving expression order), # then constraint metrics. all_metric_names: list[str] = list(obj.metric_names) for oc in optimization_config.all_constraints: for name in oc.metric_names: if name not in all_metric_names: all_metric_names.append(name) for metric_name in all_metric_names: if metric_name not in self._experiment.metrics: lower_is_better = objective_lower_is_better.get(metric_name, None) self._experiment.add_tracking_metric( MapMetric(name=metric_name, lower_is_better=lower_is_better) ) self._experiment.optimization_config = optimization_config self._set_metrics(metrics=list(old_metrics.values())) self._save_experiment_to_db_if_possible(experiment=self._experiment)
[docs] def configure_generation_strategy( self, method: Literal["quality", "fast", "random_search"] = "fast", # Initialization options initialization_budget: int | None = None, initialization_random_seed: int | None = None, initialize_with_center: bool = True, use_existing_trials_for_initialization: bool = True, min_observed_initialization_trials: int | None = None, allow_exceeding_initialization_budget: bool = False, # Misc options torch_device: str | None = None, simplify_parameter_changes: bool = False, ) -> None: """ Optional method to configure the way candidate parameterizations are generated during the optimization; if not called a default ``GenerationStrategy`` will be used. Saves to database on completion if ``storage_config`` is present. Args: method: The method to use for generating candidates. Options are: - "fast": Uses Bayesian optimization, configured specifically for the current experiment. - "random_search": Uses random search. initialization_budget: Number of initialization trials. If None, will be automatically determined based on the search space. initialization_random_seed: Random seed for initialization. If None, no seed will be set. initialize_with_center: Whether to include the center of the search space in the initialization trials. use_existing_trials_for_initialization: Whether to use existing trials for initialization. min_observed_initialization_trials: Minimum number of observed init trials required before moving to the next generation step. allow_exceeding_initialization_budget: Whether to allow exceeding the initialization budget if more trials are needed. torch_device: The torch device to use for model fitting. If None, will use the default device. simplify_parameter_changes: Whether to use BONSAI [Daulton2026bonsai]_ to simplify parameter changes in arms generated via Bayesian Optimization by pruning irrelevant parameter changes. """ generation_strategy = self._choose_generation_strategy( method=method, initialization_budget=initialization_budget, initialization_random_seed=initialization_random_seed, initialize_with_center=initialize_with_center, use_existing_trials_for_initialization=use_existing_trials_for_initialization, # noqa[E501] min_observed_initialization_trials=min_observed_initialization_trials, allow_exceeding_initialization_budget=allow_exceeding_initialization_budget, torch_device=torch_device, simplify_parameter_changes=simplify_parameter_changes, ) self.set_generation_strategy(generation_strategy=generation_strategy)
# -------------------- Section 1.1: Configure Automation ------------------------
[docs] def configure_runner(self, runner: IRunner) -> None: """ Attaches a ``Runner`` to the ``Experiment``, to be used for automating trial deployment when using ``run_n_trials``. Saves to database on completion if ``storage_config`` is present. """ self._set_runner(runner=runner)
[docs] def configure_metrics(self, metrics: Sequence[IMetric]) -> None: """ Attach a ``Metric`` with logic for autmating fetching of a given metric by replacing its instance with the provided ``Metric`` from metrics sequence input, or adds the ``Metric`` provided to the ``Experiment`` as a tracking metric if that metric was not already present. """ self._set_metrics(metrics=metrics)
[docs] def configure_tracking_metrics(self, metric_names: Sequence[str]) -> None: """ Add tracking metrics to the ``Experiment`` by name. Tracking metrics are metrics that are recorded during the experiment but are not used as part of the ``OptimizationConfig`` (i.e., they are not objectives or outcome constraints). Use this method to declare metrics that you want to track alongside your optimization objectives. If any of the metrics are already defined on the experiment, they will be skipped with a warning. Args: metric_names: Names of metrics to be added as tracking metrics. Saves to database on completion if ``storage_config`` is present. """ for metric_name in metric_names: if metric_name in self._experiment.metrics: logger.warning( f"Metric {metric_name} already exists on experiment, skipping." ) continue self._experiment.add_tracking_metric(metric=Metric(name=metric_name)) self._save_experiment_to_db_if_possible(experiment=self._experiment)
# -------------------- Section 1.2: Set (not API) -------------------------------
[docs] def set_experiment(self, experiment: Experiment) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers, power users, and partners. Overwrite the existing ``Experiment`` with the provided ``Experiment``. Saves to database on completion if ``storage_config`` is present. """ self._maybe_experiment = experiment self._save_experiment_to_db_if_possible(experiment=self._experiment)
[docs] def set_optimization_config(self, optimization_config: OptimizationConfig) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers, power users, and partners. Overwrite the existing ``OptimizationConfig`` with the provided ``OptimizationConfig``. Saves to database on completion if ``storage_config`` is present. """ self._experiment.optimization_config = optimization_config self._save_experiment_to_db_if_possible(experiment=self._experiment)
[docs] def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers, power users, and partners. Overwrite the existing ``GenerationStrategy`` with the provided ``GenerationStrategy``. Saves to database on completion if ``storage_config`` is present. """ self._maybe_generation_strategy = generation_strategy self._generation_strategy._experiment = self._experiment self._save_generation_strategy_to_db_if_possible( generation_strategy=self._generation_strategy )
[docs] def set_early_stopping_strategy( self, early_stopping_strategy: BaseEarlyStoppingStrategy ) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers, power users, and partners. Overwrite the existing ``EarlyStoppingStrategy`` with the provided ``EarlyStoppingStrategy``. Saves to database on completion if ``storage_config`` is present. """ self._maybe_early_stopping_strategy = early_stopping_strategy
def _set_runner(self, runner: Runner) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers and power users. Attaches a ``Runner`` to the ``Experiment``. Saves to database on completion if ``storage_config`` is present. """ self._experiment.runner = runner self._update_runner_on_experiment_in_db_if_possible( experiment=self._experiment, runner=runner ) def _set_metrics(self, metrics: Sequence[Metric]) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers, power users, and partners. Attach a ``Metric`` with logic for autmating fetching of a given metric by replacing its instance with the provided ``Metric`` from metrics sequence input, or adds the ``Metric`` provided to the Experiment as a tracking metric if that metric was not already present. Saves to database on completion if ``storage_config`` is present. """ # If an equivalently named Metric already exists on the Experiment, replace it # with the Metric provided. Otherwise, add the Metric to the Experiment as a # tracking metric. for metric in metrics: # Check the optimization config first self._overwrite_metric(metric=metric) self._save_experiment_to_db_if_possible(experiment=self._experiment) # -------------------- Section 2. Conduct Experiment ----------------------------
[docs] def get_next_trials( self, max_trials: int, fixed_parameters: TParameterization | None = None, ) -> dict[int, TParameterization]: """ Create up to ``max_trials`` trials using the ``GenerationStrategy`` (or as many as possible before reaching the maximum parellelism defined by the ``GenerationNode``), attach them to the ``Experiment`` with status RUNNING, and return a mapping from trial index to its parameterization. If a partial parameterization is provided via ``fixed_parameters`` each parameterization will have those parameters set to the provided values. Saves to database on completion if ``storage_config`` is present. Returns: A mapping of trial index to parameterization. """ if self._experiment.optimization_config is None: raise UnsupportedError( "OptimizationConfig not set. Please call configure_optimization before " "generating trials." ) trials: list[Trial] = [] with with_rng_seed(seed=self._random_seed): grs_for_trials = self._generation_strategy_or_choose().gen( experiment=self._experiment, n=1, fixed_features=( ObservationFeatures( parameters=cast(CoreTParameterization, fixed_parameters) ) if fixed_parameters is not None else None ), num_trials=max_trials, ) for trial_grs in grs_for_trials: assert len(trial_grs) == 1 trial = self._experiment.new_trial(generator_run=trial_grs[0]) logger.info( f"Generated new trial {trial.index} with parameters " + str( _round_floats_for_logging( item=none_throws(trial.arm).parameters, decimal_places=ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES, ) ) + f" using GenerationNode {trial_grs[0]._generation_node_name}." ) trial.mark_running(no_runner_required=True) trials.append(trial) if len(trials) < max_trials: logger.warning( f"{max_trials} trials requested but only {len(trials)} could be " "generated." ) # Save GS to db self._save_generation_strategy_to_db_if_possible( generation_strategy=self._generation_strategy ) # Bulk save all trials to the database if possible self._save_or_update_trials_in_db_if_possible( experiment=self._experiment, trials=trials ) return { trial.index: cast(TParameterization, none_throws(trial.arm).parameters) for trial in trials }
[docs] def complete_trial( self, trial_index: int, raw_data: TOutcome | None = None, progression: int | None = None, ) -> TrialStatus: """ Indicate the trial is complete and optionally attach data. The trial is always marked ``COMPLETED`` regardless of which metrics are present. Data availability is tracked separately from trial orchestration status — partial data is still used for modeling. Saves to database on completion if ``storage_config`` is present. """ if raw_data is not None: self.attach_data( trial_index=trial_index, raw_data=raw_data, progression=progression ) self._experiment.trials[trial_index].mark_completed() # Log metric availability for user visibility. if (optimization_config := self._experiment.optimization_config) is not None: trial_data = self._experiment.lookup_data(trial_indices=[trial_index]) missing_metrics = optimization_config.metric_names - { *trial_data.metric_names } if len(missing_metrics) == 0: logger.info(f"Trial {trial_index} marked COMPLETED.") else: logger.warning( f"Trial {trial_index} marked COMPLETED but missing optimization " f"config metrics: {missing_metrics}. " "Partial data will still be used for modeling." ) self._save_or_update_trial_in_db_if_possible( experiment=self._experiment, trial=self._experiment.trials[trial_index] ) return self._experiment.trials[trial_index].status
[docs] def attach_data( self, trial_index: int, raw_data: TOutcome, progression: int | None = None, ) -> None: """ Attach data without indicating the trial is complete. Missing metrics are allowed, and unexpected metric values will be added to the Experiment as tracking metrics. Saves to database on completion if ``storage_config`` is present. """ # Auto-register any metrics present in raw_data but not yet on the # experiment as tracking metrics, matching the docstring contract. extra_metrics = set(raw_data.keys()) - set(self._experiment.metrics.keys()) if extra_metrics: self.configure_tracking_metrics(metric_names=list(extra_metrics)) # If no progression is provided assume the data is not timeseries-like and # set step=NaN data_with_progression = [ (progression if progression is not None else np.nan, raw_data) ] trial = assert_is_instance(self._experiment.trials[trial_index], Trial) trial.update_trial_data(raw_data=data_with_progression) self._save_or_update_trial_in_db_if_possible( experiment=self._experiment, trial=trial )
# -------------------- Section 2.1 Custom trials --------------------------------
[docs] def attach_trial( self, parameters: TParameterization, arm_name: str | None = None ) -> int: """ Attach a single-arm trial to the ``Experiment`` with the provided parameters. The trial will be marked as RUNNING and must be completed manually by the user. Saves to database on completion if ``storage_config`` is present. Returns: The index of the attached trial. """ _, trial_index = self._experiment.attach_trial( parameterizations=[cast(CoreTParameterization, parameters)], arm_names=[arm_name] if arm_name else None, ) self._save_or_update_trial_in_db_if_possible( experiment=self._experiment, trial=self._experiment.trials[trial_index] ) return trial_index
[docs] def attach_baseline( self, parameters: TParameterization, arm_name: str | None = None ) -> int: """ Attaches custom single-arm trial to an ``Experiment`` specifically for use as the baseline or status quo in evaluating relative outcome constraints and improvement over baseline objective value. The trial will be marked as RUNNING and must be completed manually by the user. Returns: The index of the attached trial. Saves to database on completion if ``storage_config`` is present. """ trial_index = self.attach_trial( parameters=parameters, arm_name=arm_name or "baseline", ) self._experiment.status_quo = assert_is_instance( self._experiment.trials[trial_index], Trial ).arm self._save_experiment_to_db_if_possible(experiment=self._experiment) return trial_index
# -------------------- Section 2.2 Early Stopping -------------------------------
[docs] def should_stop_trial_early(self, trial_index: int) -> bool: """ Check if the trial should be stopped early. If True and the user wishes to heed Ax's recommendation the user should manually stop the trial and call ``mark_trial_early_stopped(trial_index)``. The ``EarlyStoppingStrategy`` may be selected automatically or set manually via ``set_early_stopping_strategy``. Returns: Whether the trial should be stopped early. """ es_response = none_throws( self._early_stopping_strategy_or_choose() ).should_stop_trials_early( trial_indices={trial_index}, experiment=self._experiment, current_node=self._generation_strategy_or_choose()._curr, ) if trial_index in es_response: logger.info( f"Trial {trial_index} should be stopped early: " f"{es_response[trial_index]}" ) return True return False
# -------------------- Section 2.3 Marking trial status manually ----------------
[docs] def mark_trial_failed( self, trial_index: int, failed_reason: str | None = None ) -> None: """ Manually mark a trial as FAILED. FAILED trials may be re-suggested by ``get_next_trials``. Saves to database on completion if ``storage_config`` is present. """ self._experiment.trials[trial_index].mark_failed(reason=failed_reason) logger.info(f"Trial {trial_index} marked FAILED.") self._save_or_update_trial_in_db_if_possible( experiment=self._experiment, trial=self._experiment.trials[trial_index] )
[docs] def mark_trial_abandoned(self, trial_index: int) -> None: """ Manually mark a trial as ABANDONED. ABANDONED trials are not able to be re-suggested by ``get_next_trials``. Saves to database on completion if ``storage_config`` is present. """ self._experiment.trials[trial_index].mark_abandoned() logger.info( f"Trial {trial_index} marked ABANDONED. ABANDONED trials are not able to " "be re-suggested by ``get_next_trials``." ) self._save_or_update_trial_in_db_if_possible( experiment=self._experiment, trial=self._experiment.trials[trial_index] )
[docs] def mark_trial_early_stopped(self, trial_index: int) -> None: """ Manually mark a trial as EARLY_STOPPED. This is used when the user has decided (with or without Ax's recommendation) to stop the trial after some data has been attached but before the trial is completed. Note that if data has not been attached for the trial yet users should instead call ``mark_trial_abandoned``. EARLY_STOPPED trials will not be re-suggested by ``get_next_trials``. Saves to database on completion if ``storage_config`` is present. """ self._experiment.trials[trial_index].mark_early_stopped() logger.info(f"Trial {trial_index} marked EARLY_STOPPED.") self._save_or_update_trial_in_db_if_possible( experiment=self._experiment, trial=self._experiment.trials[trial_index] )
[docs] def run_trials( self, max_trials: int, parallelism: int = 1, tolerated_trial_failure_rate: float = 0.5, initial_seconds_between_polls: int = 1, ) -> None: """ Run maximum_trials trials in a loop by creating an ephemeral Orchestrator under the hood using the Experiment, GenerationStrategy, Metrics, and Runner attached to this AxClient along with the provided OrchestrationConfig. Saves to database on completion if ``storage_config`` is present. """ orchestrator = Orchestrator( experiment=self._experiment, generation_strategy=self._generation_strategy_or_choose(), options=OrchestratorOptions( max_pending_trials=parallelism, tolerated_trial_failure_rate=tolerated_trial_failure_rate, init_seconds_between_polls=initial_seconds_between_polls, ), db_settings=db_settings_from_storage_config(self._storage_config) if self._storage_config is not None else None, ) # Note: This Orchestrator call will handle storage internally orchestrator.run_n_trials(max_trials=max_trials)
# -------------------- Section 3. Analyze ---------------------------------------
[docs] def compute_analyses( self, analyses: Sequence[Analysis] | None = None, display: bool = True, ) -> list[AnalysisCardBase]: """ Compute ``AnalysisCards`` (data about the optimization for end-user consumption) using the ``Experiment`` and ``GenerationStrategy``. If no analyses are provided use some heuristic to determine which analyses to run. If some analyses fail, log failure and continue to compute the rest. Note that the Analysis class is NOT part of the API and its methods are subject to change incompatibly between minor versions. Users are encouraged to use the provided analyses or leave this argument as ``None`` to use the default analyses. Saves cards to database on completion if ``storage_config`` is present. Args: analyses: A list of Analysis classes to run. If None Ax will choose which analyses to run based on the state of the experiment. display: Whether to display the AnalysisCards if executed in an interactive environment (e.g. Jupyter). Defaults to True. If not in an interactive environment this setting has no effect. Returns: A list of AnalysisCards. """ analyses = analyses if analyses is not None else [OverviewAnalysis()] # Compute Analyses. If any fails to compute, catch and instead return an # ErrorAnalysisCard which contains the Exception and its associated traceback. cards = [ analysis.compute_or_error_card( experiment=self._experiment, generation_strategy=self._generation_strategy, ) for analysis in analyses ] # Display the AnalysisCards if requested and if the user is in a notebook if display: display_cards(cards=cards) return cards
[docs] def summarize( self, trial_indices: Iterable[int] | None = None, trial_statuses: Sequence[ Literal[ "candidate", "running", "failed", "completed", "abandoned", "early_stopped", "staged", ] ] | None = None, ) -> pd.DataFrame: """ Special convenience method for producing the ``DataFrame`` produced by the ``Summary`` ``Analysis``. This method is a convenient way to inspect the state of the ``Experiment``, but because the shape of the resultant DataFrame can change based on the ``Experiment`` state both users and Ax developers should prefer to use other methods for extracting information from the experiment to consume downstream. The ``DataFrame`` computed will contain one row per arm and the following columns (though empty columns are omitted): - trial_index: The trial index of the arm - arm_name: The name of the arm - trial_status: The status of the trial (e.g. RUNNING, SUCCEEDED, FAILED) - failure_reason: The reason for the failure, if applicable - generation_node: The name of the ``GenerationNode`` that generated the arm - **METADATA: Any metadata associated with the trial, as specified by the Experiment's ``runner.run_metadata_report_keys`` field - **METRIC_NAME: The observed mean of the metric specified, for each metric - **PARAMETER_NAME: The parameter value for the arm, for each parameter Args: trial_indices: If specified, only include these trial indices. trial_status: If specified, only include trials with this status. """ # Convert string literals to TrialStatus enum values enum_trial_statuses = None if trial_statuses is not None: enum_trial_statuses = [ TrialStatus[status.upper()] for status in trial_statuses ] card = Summary( trial_indices=trial_indices, trial_statuses=enum_trial_statuses, omit_empty_columns=True, ).compute( experiment=self._experiment, generation_strategy=self._maybe_generation_strategy, ) return card.df
[docs] def get_best_parameterization( self, use_model_predictions: bool = True ) -> tuple[TParameterization, TOutcome, int, str]: """ Identifies the best parameterization tried in the experiment so far, also called the best in-sample arm. If ``use_model_predictions`` is ``True``, first attempts to do so with the model used in optimization and its corresponding predictions if available. If ``use_model_predictions`` is ``False`` or attempts to use the model fails, falls back to the best raw objective based on the data fetched from the ``Experiment``. Parameterizations which were observed to violate outcome constraints are not eligible to be the best parameterization. Returns: - The parameters predicted to have the best optimization value without violating any outcome constraints. - The metric values for the best parameterization. Uses model prediction if ``use_model_predictions=True``, otherwise returns observed data. - The trial which most recently ran the best parameterization - The name of the best arm (each trial has a unique name associated with each parameterization) """ if self._experiment.optimization_config is None: raise UnsupportedError( "No optimization config has been set. Please configure the " "optimization before calling get_best_parameterization." ) if self._experiment.optimization_config.is_moo_problem: raise UnsupportedError( "The client is currently configured to jointly optimize " f"{self._experiment.optimization_config}. " "Multi-objective optimization does not return a single best " "parameterization -- it returns a Pareto frontier. Please call " "get_pareto_frontier instead." ) if len(self._experiment.trials) < 1: raise UnsupportedError( "No trials have been run yet. Please run at least one trial before " "calling get_best_parameterization." ) # Note: Using BestPointMixin directly instead of inheriting to avoid exposing # unwanted public methods trial_index, parameterization, model_prediction = none_throws( BestPointMixin._get_best_trial( experiment=self._experiment, generation_strategy=self._generation_strategy_or_choose(), use_model_predictions=use_model_predictions, ) ) return cast( tuple[TParameterization, TOutcome, int, str], BestPointMixin._to_best_point_tuple( experiment=self._experiment, trial_index=trial_index, parameterization=parameterization, model_prediction=model_prediction, ), )
[docs] def get_pareto_frontier( self, use_model_predictions: bool = True ) -> list[tuple[TParameterization, TOutcome, int, str]]: """ Identifies the parameterizations which are predicted to efficiently trade-off between all objectives in a multi-objective optimization, also called the in-sample Pareto frontier. Returns: A list of tuples containing: - The parameters predicted to have the best optimization value without violating any outcome constraints. - The metric values for the best parameterization. Uses model prediction if ``use_model_predictions=True``, otherwise returns observed data. - The trial which most recently ran the best parameterization - The name of the best arm (each trial has a unique name associated with each parameterization). """ if self._experiment.optimization_config is None: raise UnsupportedError( "No optimization config has been set. Please configure the " "optimization before calling get_pareto_frontier." ) if not self._experiment.optimization_config.is_moo_problem: raise UnsupportedError( "The client is currently configured to optimize " f"{self._experiment.optimization_config.objective}. " "Single-objective optimization does not return a Pareto frontier -- " "it returns a single best point. Please call " "get_best_parameterization instead." ) if len(self._experiment.trials) < 1: raise UnsupportedError( "No trials have been run yet. Please run at least one trial before " "calling get_pareto_frontier." ) frontier = BestPointMixin._get_pareto_optimal_parameters( experiment=self._experiment, # Requiring true GenerationStrategy here, ideally we will loosen this # in the future generation_strategy=self._generation_strategy, use_model_predictions=use_model_predictions, ) return [ cast( tuple[TParameterization, TOutcome, int, str], BestPointMixin._to_best_point_tuple( experiment=self._experiment, trial_index=trial_index, parameterization=parameterization, model_prediction=model_prediction, ), ) for trial_index, (parameterization, model_prediction) in frontier.items() ]
[docs] def predict( self, points: Sequence[TParameterization], ) -> list[dict[str, tuple[float, float]]]: """ Use the current surrogate model to predict the outcome of the provided list of parameterizations. Returns: A list of mappings from metric name to predicted mean and SEM """ for parameters in points: self._experiment.search_space.check_membership( parameterization=parameters, raise_error=True, check_all_parameters_present=True, ) try: mean, covariance = none_throws(self._generation_strategy.adapter).predict( observation_features=[ ObservationFeatures( parameters=cast(CoreTParameterization, parameters) ) for parameters in points ] ) except (NotImplementedError, AssertionError) as e: raise UnsupportedError( "Predicting with the GenerationStrategy's adapter failed. This " "could be because the current GenerationNode is not predictive -- try " "running more trials to progress to a predictive GenerationNode." ) from e return [ { metric_name: ( mean[metric_name][i], covariance[metric_name][metric_name][i] ** 0.5, ) for metric_name in mean.keys() } for i in range(len(points)) ]
# -------------------- Section 4: Save/Load ------------------------------------- # Note: SQL storage handled automatically during regular usage
[docs] def save_to_json_file(self, filepath: str = "ax_client_snapshot.json") -> None: """ Save a JSON-serialized snapshot of this ``Client``'s settings and state to a .json file by the given path. """ with open(filepath, "w+") as file: file.write(json.dumps(self._to_json_snapshot())) logger.debug( f"Saved JSON-serialized state of optimization to `{filepath}`." )
[docs] @classmethod def load_from_json_file( cls, filepath: str = "ax_client_snapshot.json", storage_config: StorageConfig | None = None, ) -> Self: """ Restore a ``Client`` and its state from a JSON-serialized snapshot, residing in a .json file by the given path. Returns: The restored ``Client``. """ with open(filepath) as file: return cls._from_json_snapshot( snapshot=json.loads(file.read()), storage_config=storage_config )
[docs] @classmethod def load_from_database( cls, experiment_name: str, storage_config: StorageConfig | None = None, ) -> Self: """ Restore an ``Client`` and its state from database by the given name. Returns: The restored ``Client``. """ db_settings_base = WithDBSettingsBase( db_settings=db_settings_from_storage_config(storage_config=storage_config) if storage_config is not None else None ) maybe_experiment, maybe_generation_strategy = ( db_settings_base._load_experiment_and_generation_strategy( experiment_name=experiment_name ) ) if (experiment := maybe_experiment) is None: raise ObjectNotFoundError( f"Experiment {experiment_name} not found in database. Please check " "its name is correct, check your StorageConfig is correct, or create " "a new experiment." ) client = cls(storage_config=storage_config) client.set_experiment(experiment=experiment) if maybe_generation_strategy is not None: client.set_generation_strategy( generation_strategy=maybe_generation_strategy ) return client
# -------------------- Section 5: Private Methods ------------------------------- # -------------------- Section 5.1: Getters and defaults ------------------------ @property def _experiment(self) -> Experiment: return none_throws( self._maybe_experiment, ( "Experiment not set. Please call configure_experiment or load an " "experiment before utilizing any other methods on the Client." ), ) @property def _generation_strategy(self) -> GenerationStrategy: return none_throws( self._maybe_generation_strategy, ( "GenerationStrategy not set. Please call " "configure_generation_strategy, load a GenerationStrategy, or call " "get_next_trials or run_trials to automatically choose a " "GenerationStrategy before utilizing any other methods on the Client " "which require one." ), ) @property def _early_stopping_strategy(self) -> BaseEarlyStoppingStrategy: return none_throws( self._maybe_early_stopping_strategy, "Early stopping strategy not set. Please set an early stopping strategy " "before calling should_stop_trial_early.", ) def _generation_strategy_or_choose( self, ) -> GenerationStrategy: """ If a GenerationStrategy is not set, choose a default one (save to database) and return it. """ try: return self._generation_strategy except AssertionError: self.configure_generation_strategy() return self._generation_strategy def _early_stopping_strategy_or_choose( self, ) -> BaseEarlyStoppingStrategy: """ If an EarlyStoppingStrategy is not set choose a default one and return it. """ try: return self._early_stopping_strategy except AssertionError: # PercetinleEarlyStoppingStrategy may or may not have sensible defaults at # current moment -- we will need to be critical of these settings during # benchmarking self.set_early_stopping_strategy( early_stopping_strategy=PercentileEarlyStoppingStrategy() ) return self._early_stopping_strategy def _choose_generation_strategy( self, method: Literal["quality", "fast", "random_search"] = "fast", # Initialization options initialization_budget: int | None = None, initialization_random_seed: int | None = None, initialize_with_center: bool = True, use_existing_trials_for_initialization: bool = True, min_observed_initialization_trials: int | None = None, allow_exceeding_initialization_budget: bool = False, # Misc options torch_device: str | None = None, simplify_parameter_changes: bool = False, ) -> GenerationStrategy: """ Choose a generation strategy based on the provided method and options. Args: method: The method to use for generating candidates. Options are: - "fast": Uses Bayesian optimization, configured specifically for the current experiment. - "random_search": Uses random search. initialization_budget: Number of initialization trials. If None, will be automatically determined based on the search space. initialization_random_seed: Random seed for initialization. If None, no seed will be set. initialize_with_center: Whether to include the center of the search space in the initialization trials. use_existing_trials_for_initialization: Whether to use existing trials for initialization. min_observed_initialization_trials: Minimum number of observed init trials required before moving to the next generation step. allow_exceeding_initialization_budget: Whether to allow exceeding the initialization budget if more trials are needed. torch_device: The torch device to use for model fitting. If None, will use the default device. simplify_parameter_changes: Whether to use BONSAI [Daulton2026bonsai]_ to simplify parameter changes in arms generated via Bayesian Optimization by pruning irrelevant parameter changes. Returns: A GenerationStrategy instance configured according to the specified options. """ generation_strategy = choose_generation_strategy( struct=GenerationStrategyDispatchStruct( method=method, initialization_budget=initialization_budget, initialization_random_seed=initialization_random_seed, initialize_with_center=initialize_with_center, use_existing_trials_for_initialization=( use_existing_trials_for_initialization ), min_observed_initialization_trials=min_observed_initialization_trials, allow_exceeding_initialization_budget=( allow_exceeding_initialization_budget ), torch_device=torch_device, simplify_parameter_changes=simplify_parameter_changes, ) ) logger.info( f"{generation_strategy} chosen based on user input and problem structure." ) return generation_strategy # -------------------- Section 5.2: Metric configuration -------------------------- def _overwrite_metric(self, metric: Metric) -> None: """ Overwrite an existing Metric on the Experiment with the provided Metric if they share the same name. If not Metric with the same name exists, add the Metric as a tracking metric. Note that this method does not save the Experiment to the database (this is handled in self._set_metrics). """ if metric.name in self._experiment.metrics: self._experiment.update_metric(metric) else: self._experiment.add_tracking_metric(metric=metric) logger.warning( f"Metric {metric} not found on experiment, added as a new " "tracking metric." ) # -------------------- Section 5.3: Storage utilies ------------------------------- def _to_json_snapshot(self) -> dict[str, Any]: """Serialize this `AxClient` to JSON to be able to interrupt and restart optimization and save it to file by the provided path. Returns: A JSON-safe dict representation of this `AxClient`. """ # If the user has supplied custom encoder registries, use them. Otherwise use # the core encoder registries. if ( self._storage_config is not None and self._storage_config.registry_bundle is not None ): encoder_registry = ( self._storage_config.registry_bundle.sqa_config.json_encoder_registry ) class_encoder_registry = self._storage_config.registry_bundle.sqa_config.json_class_encoder_registry # noqa: E501 else: encoder_registry = CORE_ENCODER_REGISTRY class_encoder_registry = CORE_CLASS_ENCODER_REGISTRY return { "_type": self.__class__.__name__, "experiment": object_to_json( self._experiment, encoder_registry=encoder_registry, class_encoder_registry=class_encoder_registry, ), "generation_strategy": object_to_json( self._generation_strategy, encoder_registry=encoder_registry, class_encoder_registry=class_encoder_registry, ) if self._maybe_generation_strategy is not None else None, } @classmethod def _from_json_snapshot( cls, snapshot: dict[str, Any], storage_config: StorageConfig | None = None, ) -> Self: # If the user has supplied custom encoder registries, use them. Otherwise use # the core encoder registries. if storage_config is not None and storage_config.registry_bundle is not None: decoder_registry = ( storage_config.registry_bundle.sqa_config.json_decoder_registry ) class_decoder_registry = ( storage_config.registry_bundle.sqa_config.json_class_decoder_registry ) else: decoder_registry = CORE_DECODER_REGISTRY class_decoder_registry = CORE_CLASS_DECODER_REGISTRY # Decode the experiment, and generation strategy if present experiment = object_from_json( object_json=snapshot["experiment"], decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) generation_strategy = ( generation_strategy_from_json( generation_strategy_json=snapshot["generation_strategy"], experiment=experiment, decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) if snapshot.get("generation_strategy") is not None else None ) client = cls(storage_config=storage_config) client.set_experiment(experiment=experiment) if generation_strategy is not None: client.set_generation_strategy(generation_strategy=generation_strategy) return client