Source code for ax.service.utils.best_point_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.core.experiment import Experiment
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.service.utils.best_point import (
    get_best_by_raw_objective_with_trial_index,
    get_best_parameters_from_model_predictions_with_trial_index,
    get_pareto_optimal_parameters,
)
from pyre_extensions import none_throws

# Arm-name constant for a baseline arm. It is not referenced by the functions
# visible in this module; presumably it is imported by callers elsewhere --
# TODO confirm usage.
BASELINE_ARM_NAME = "baseline_arm"


def get_best_trial_indices(
    experiment: Experiment,
    optimization_config: OptimizationConfig,
    generation_strategy: GenerationStrategy | None = None,
    trial_indices: list[int] | None = None,
    use_model_predictions: bool = False,
) -> list[int]:
    """Return the index (or indices) of the best trial(s) on an experiment.

    The lookup used depends on the kind of optimization problem:

    * Multi-objective (``optimization_config.is_moo_problem`` is true): the
      keys of the mapping returned by ``get_pareto_optimal_parameters``.
    * Single-objective: the first element of the result of either
      ``get_best_parameters_from_model_predictions_with_trial_index`` (when
      ``use_model_predictions`` is true) or
      ``get_best_by_raw_objective_with_trial_index`` (otherwise).

    Note:
        The best point methods used here only consider in-sample points, i.e.
        arms that have already been run on the experiment. Parameters that
        have not been tried yet are never returned.

    Args:
        experiment: The experiment to get best trials from.
        optimization_config: The optimization config for the experiment.
        generation_strategy: The generation strategy. It is unwrapped with
            ``none_throws`` on the multi-objective path and on the
            model-prediction path, so it must not be ``None`` there.
        trial_indices: The trial indices to consider when finding the best
            trials. This should be pre-filtered by trial status. If ``None``,
            all trials are considered.
        use_model_predictions: If true, use model predictions rather than raw
            observations to pick the best trial(s).

    Returns:
        A list of trial indices. For single-objective problems it holds at
        most one index (it is empty when the underlying lookup returns
        ``None``). For multi-objective problems it holds every Pareto optimal
        trial index.
    """
    # Multi-objective: every trial on the Pareto frontier counts as "best".
    if optimization_config.is_moo_problem:
        frontier = get_pareto_optimal_parameters(
            experiment=experiment,
            generation_strategy=none_throws(generation_strategy),
            optimization_config=optimization_config,
            trial_indices=trial_indices,
            use_model_predictions=use_model_predictions,
        )
        # Iterating the mapping yields its keys, which are the trial indices.
        return list(frontier)

    # Single-objective: there is at most one best trial.
    if use_model_predictions:
        # The adapter is read off the generation strategy, so the strategy is
        # only required on this branch of the single-objective path.
        winner = get_best_parameters_from_model_predictions_with_trial_index(
            experiment=experiment,
            adapter=none_throws(generation_strategy).adapter,
            optimization_config=optimization_config,
            trial_indices=trial_indices,
        )
    else:
        winner = get_best_by_raw_objective_with_trial_index(
            experiment=experiment,
            optimization_config=optimization_config,
            trial_indices=trial_indices,
        )

    if winner is None:
        return []
    # The trial index is the first element of the returned tuple.
    return [winner[0]]


def select_baseline_name_default_first_trial(
    experiment: Experiment, baseline_arm_name: str | None
) -> tuple[str, bool]:
    """Choose a baseline arm from the arms on the experiment.

    The rules are tried in order and the first one that applies wins:

    1. If ``baseline_arm_name`` is provided (non-empty), it must be a key of
       ``experiment.arms_by_name``; that name is returned.
    2. If ``experiment.status_quo`` is set and its name is a key of
       ``experiment.arms_by_name``, the status quo's name is returned.
    3. If the trial at index ``0`` exists, has at least one arm, and that
       first arm's name is a key of ``experiment.arms_by_name``, that arm's
       name is returned.
    4. Otherwise an error is raised.

    Args:
        experiment: The experiment whose arms are searched.
        baseline_arm_name: Name of the arm to use as the baseline. ``None``
            or an empty string means "pick one automatically".

    Returns:
        A tuple ``(name, from_first_trial)`` where ``name`` is the selected
        baseline arm name and ``from_first_trial`` is ``True`` only when the
        baseline was taken from the first arm of the first trial (rule 3).

    Raises:
        ValueError: If ``baseline_arm_name`` is provided but no arm with that
            name exists on the experiment.
        UserInputError: If none of rules 1-3 yields a baseline arm.
    """
    arms_dict = experiment.arms_by_name

    # Rule 1: an explicitly requested baseline must exist on the experiment.
    if baseline_arm_name:
        if baseline_arm_name not in arms_dict:
            raise ValueError(f"Arm by name {baseline_arm_name=} not found.")
        return baseline_arm_name, False

    # Rule 2: fall back to the status quo arm. The property is read once and
    # bound to a local so the truthiness check also narrows its optional type.
    status_quo = experiment.status_quo
    if status_quo:
        status_quo_name = status_quo.name
        if status_quo_name in arms_dict:
            return status_quo_name, False

    # Rule 3: fall back to the first arm of trial 0. ``.get`` is used so that
    # an experiment that has trials but none at index 0 reaches the
    # descriptive error below instead of leaking a bare ``KeyError``.
    first_trial = experiment.trials.get(0)
    if first_trial is not None and first_trial.arms:
        first_arm_name = first_trial.arms[0].name
        if first_arm_name in arms_dict:
            return first_arm_name, True

    # Rule 4: nothing usable was found.
    raise UserInputError("Could not find valid baseline arm.")