ax.utils
Common
Base
Constants
- class ax.utils.common.constants.Keys(*values)
Bases:
StrEnum
Enum of reserved keys in options dicts, etc., alphabetized.
NOTE: Useful for keys in dicts that correspond to kwargs to classes or functions and/or are used in multiple places.
- ACQF_KWARGS = 'acquisition_function_kwargs'
- AX_ACQUISITION_KWARGS = 'ax_acquisition_kwargs'
- BATCH_INIT_CONDITIONS = 'batch_initial_conditions'
- CANDIDATE_SET = 'candidate_set'
- CANDIDATE_SIZE = 'candidate_size'
- COST_AWARE_UTILITY = 'cost_aware_utility'
- COST_INTERCEPT = 'cost_intercept'
- CURRENT_VALUE = 'current_value'
- DEFAULT_OBJECTIVE_NAME = 'objective'
- EXPAND = 'expand'
- EXPECTED_ACQF_VAL = 'expected_acquisition_value'
- EXPERIMENT_TOTAL_CONCURRENT_ARMS = 'total_concurrent_arms'
- FACTORIAL_PLUS_EMPIRICAL_BAYES_THOMPSON_SAMPLING = 'FACTORIAL + EMPIRICAL_BAYES_THOMPSON_SAMPLING'
- FIDELITY_FEATURES = 'fidelity_features'
- FIDELITY_WEIGHTS = 'fidelity_weights'
- FRAC_RANDOM = 'frac_random'
- FULL_PARAMETERIZATION = 'full_parameterization'
- IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = 'immutable_search_space_and_opt_config'
- LILO_INPUT_HASH = 'lilo_input_hash'
- LILO_LABELING = 'lilo_labeling'
- LLM_MESSAGES = 'llm_messages'
- LONG_RUN = 'long_run'
- MAXIMIZE = 'maximize'
- METADATA = 'metadata'
- METRIC_NAMES = 'metric_names'
- NUM_FANTASIES = 'num_fantasies'
- NUM_INNER_RESTARTS = 'num_inner_restarts'
- NUM_RESTARTS = 'num_restarts'
- NUM_TRACE_OBSERVATIONS = 'num_trace_observations'
- OPTIMIZER_KWARGS = 'optimizer_kwargs'
- PAIRWISE_PREFERENCE_QUERY = 'pairwise_pref_query'
- PREFERENCE_DATA = 'preference_data'
- PROJECT = 'project'
- QMC = 'qmc'
- RAW_INNER_SAMPLES = 'raw_inner_samples'
- RAW_SAMPLES = 'raw_samples'
- RESUMED_FROM_STORAGE_TS = 'resumed_from_storage_timestamps'
- SAMPLER = 'sampler'
- SEED_INNER = 'seed_inner'
- SEQUENTIAL = 'sequential'
- SHORT_RUN = 'short_run'
- STATE_DICT = 'state_dict'
- SUBCLASS = 'subclass'
- SUBSET_MODEL = 'subset_model'
- TASK_FEATURES = 'task_features'
- TASK_FEATURE_NAME = 'task_feature'
- TRIAL_COMPLETION_TIMESTAMP = 'trial_completion_timestamp'
- UNKNOWN_GENERATION_NODE = 'unknown_gen_node'
- UNNAMED_ARM = 'unnamed_arm'
- WARMSTART_TRIAL_MODEL_KEY = 'generation_model_key'
- WARM_START_REFITTING = 'warm_start_refitting'
- X_BASELINE = 'X_baseline'
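Because Keys members are string-valued, they can key options dicts while still matching lookups by plain string. Below is a minimal sketch that uses a hypothetical three-member KeysSketch built with a (str, Enum) mixin. Like StrEnum, this mixin compares and hashes as the member's string value, and it lets the example run without Ax installed.

```python
from enum import Enum


class KeysSketch(str, Enum):
    """Hypothetical subset of Keys; the real enum has many more members."""

    NUM_RESTARTS = "num_restarts"
    OPTIMIZER_KWARGS = "optimizer_kwargs"
    RAW_SAMPLES = "raw_samples"


# Build an options dict keyed by enum members instead of bare strings.
options = {
    KeysSketch.OPTIMIZER_KWARGS: {
        KeysSketch.NUM_RESTARTS: 10,
        KeysSketch.RAW_SAMPLES: 512,
    }
}

# Members hash and compare like their string values, so code that
# still uses plain strings can look the same entries up.
optimizer_kwargs = options["optimizer_kwargs"]
```

Centralizing the strings this way means a typo in a key raises AttributeError at the call site, instead of silently creating a new dict entry.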
- ax.utils.common.constants.is_preference_metric(metric_name: str) → bool
Check if a metric uses preference/comparison-pair semantics.
Preference metrics use latent utility models (e.g., PairwiseGP) rather than standard per-arm regression. Some analyses work on preference metrics (e.g., sensitivity analysis on the utility function), while others assume regression semantics and should exclude them (e.g., cross-validation, arm effects).
Extend this check when integrating new preference models (e.g., VariationalTopChoiceGP).
Decorator
- class ax.utils.common.decorator.ClassDecorator
Bases:
ABC
Template for making a decorator work as a class-level decorator. Such a decorator should extend ClassDecorator and must implement __init__ and decorate_callable. See disable_logger.decorate_callable for an example. decorate_callable should call self._call_func() instead of calling func directly, so that static functions are handled.
Note: _call_func is still imperfect, and unit tests should be used to ensure everything is working properly. There is a lot of complexity in detecting classmethods and staticmethods and removing the self argument in the right situations. For best results, always use keyword args in the decorated class.
DECORATE_PRIVATE can be set to determine whether private methods should be decorated. In the case of a logging decorator, you may only want to decorate things the user calls. But in the case of a disable logging decorator, you may want to decorate everything to ensure no logs escape.
- DECORATE_PRIVATE = True
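A minimal sketch of the class-level decorator pattern, under a simplified design. The hypothetical ClassDecoratorSketch applies decorate_callable to every non-dunder method of a class and skips single-underscore methods unless DECORATE_PRIVATE is set. It also unwraps staticmethod and classmethod objects before wrapping them, instead of relying on a _call_func helper as the Ax class does. The count_calls subclass is likewise hypothetical.

```python
import functools
from typing import Any, Callable


class ClassDecoratorSketch:
    """Hypothetical, simplified stand-in for ClassDecorator."""

    DECORATE_PRIVATE = True

    def __call__(self, target: Any) -> Any:
        # Applied to a class: decorate its methods. Applied to a function: wrap it.
        if isinstance(target, type):
            return self.decorate_class(target)
        return self.decorate_callable(target)

    def decorate_class(self, klass: type) -> type:
        for name, attr in list(vars(klass).items()):
            if name.startswith("__"):
                continue  # leave dunder methods and class metadata alone
            if name.startswith("_") and not self.DECORATE_PRIVATE:
                continue
            if isinstance(attr, (staticmethod, classmethod)):
                # Wrap the underlying function, then restore the descriptor type.
                wrapped = self.decorate_callable(attr.__func__)
                setattr(klass, name, type(attr)(wrapped))
            elif callable(attr):
                setattr(klass, name, self.decorate_callable(attr))
        return klass

    def decorate_callable(self, func: Callable[..., Any]) -> Callable[..., Any]:
        raise NotImplementedError


class count_calls(ClassDecoratorSketch):
    """Hypothetical example decorator that records the name of each call."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def decorate_callable(self, func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            self.calls.append(func.__name__)
            return func(*args, **kwargs)

        return wrapper


counter = count_calls()


@counter
class Greeter:
    def hello(self, name: str) -> str:
        return f"hello {name}"

    @staticmethod
    def shout(text: str) -> str:
        return text.upper()


greeting = Greeter().hello(name="ax")
shouted = Greeter.shout(text="hi")
```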
Deprecation
Docutils
Support functions for Sphinx et al.
- ax.utils.common.docutils.copy_doc(src: Callable[[...], Any]) → Callable[[_T], _T]
A decorator that copies the docstring of another object.
Since Sphinx actually loads the Python modules to grab the docstrings, this works with both Sphinx and the help function.
    class Cat(Mammal):
        @property
        @copy_doc(Mammal.is_feline)
        def is_feline(self) -> bool:
            ...
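The decorator can be very small. Here is a sketch that assumes copying __doc__ is all that is needed. copy_doc_sketch, Mammal, and Cat are illustrative names, not the Ax implementation.

```python
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")


def copy_doc_sketch(src: Callable[..., Any]) -> Callable[[_T], _T]:
    def decorator(dst: _T) -> _T:
        # help() and Sphinx both read __doc__, so copying it is enough.
        dst.__doc__ = src.__doc__
        return dst

    return decorator


class Mammal:
    @property
    def is_feline(self) -> bool:
        """Whether the animal is a feline."""
        return False


class Cat(Mammal):
    @property
    @copy_doc_sketch(Mammal.is_feline)  # applied first, so property sees the copied doc
    def is_feline(self) -> bool:
        return True
```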
Equality
- ax.utils.common.equality.dataframe_equals(df1: DataFrame, df2: DataFrame) → bool
Compare equality of two pandas dataframes.
- ax.utils.common.equality.datetime_equals(dt1: datetime | None, dt2: datetime | None) → bool
Compare equality of two datetimes, up to a difference of one second.
- ax.utils.common.equality.equality_typechecker(eq_func: Callable) → Callable
A decorator to wrap all __eq__ methods to ensure that the inputs are of the right type.
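A sketch of one plausible way such a type check can work. It assumes that an object of a different class is simply reported as unequal, so the wrapped __eq__ can access the other object's fields safely. equality_typechecker_sketch and ArmSketch are hypothetical names.

```python
from typing import Any, Callable


def equality_typechecker_sketch(
    eq_func: Callable[[Any, Any], bool],
) -> Callable[[Any, Any], bool]:
    def _type_safe_eq(self: Any, other: Any) -> bool:
        # Objects of a different class are never equal; skip the field check.
        if not isinstance(other, self.__class__):
            return False
        return eq_func(self, other)

    return _type_safe_eq


class ArmSketch:
    """Hypothetical value class used to exercise the decorator."""

    def __init__(self, name: str) -> None:
        self.name = name

    @equality_typechecker_sketch
    def __eq__(self, other: "ArmSketch") -> bool:
        # Safe to read other.name: the wrapper has already checked the type.
        return self.name == other.name
```

Without the wrapper, comparing an ArmSketch to a string would raise AttributeError when it accesses other.name.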
- ax.utils.common.equality.is_ax_equal(one_val: Any, other_val: Any) → bool
Check for equality of two values, handling lists, dicts, dfs, floats, dates, and numpy arrays. This method and same_elements function as a recursive unit.
Some special cases:
  - For datetime objects, equality is checked up to a tolerance of one second.
  - For floats, np.isclose is used to check for almost-equality.
  - For lists (and dict values), same_elements is used. This ignores the ordering of the elements and checks that the two lists are subsets of each other (under the assumption that there are no duplicates).
If the objects don’t fall into any of the special cases, we use a simple equality check and cast the output to a boolean. If the comparison or cast fails, we return False. For example, comparing a float with a numpy array that has multiple elements returns False.
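The special cases above can be sketched as a pair of mutually recursive functions. This is a simplified illustration with hypothetical names. It substitutes math.isclose with NumPy's default tolerances (rtol=1e-05, atol=1e-08) for np.isclose and omits the DataFrame and array branches, so it runs with the standard library only.

```python
import math
from datetime import datetime, timedelta
from typing import Any


def _same_elements(list1: list[Any], list2: list[Any]) -> bool:
    # Order-insensitive comparison; assumes neither list has duplicates.
    if len(list1) != len(list2):
        return False
    return all(any(is_ax_equal_sketch(a, b) for b in list2) for a in list1) and all(
        any(is_ax_equal_sketch(b, a) for a in list1) for b in list2
    )


def is_ax_equal_sketch(one_val: Any, other_val: Any) -> bool:
    if isinstance(one_val, datetime) and isinstance(other_val, datetime):
        # Datetimes count as equal up to a one-second tolerance.
        return abs(one_val - other_val) <= timedelta(seconds=1)
    if isinstance(one_val, float) and isinstance(other_val, float):
        # Stand-in for np.isclose, using NumPy's default tolerances.
        return math.isclose(one_val, other_val, rel_tol=1e-5, abs_tol=1e-8)
    if isinstance(one_val, list) and isinstance(other_val, list):
        return _same_elements(one_val, other_val)
    if isinstance(one_val, dict) and isinstance(other_val, dict):
        # Same keys, and the values compared as an unordered collection.
        return sorted(one_val) == sorted(other_val) and _same_elements(
            list(one_val.values()), list(other_val.values())
        )
    try:
        return bool(one_val == other_val)
    except Exception:
        # Comparison or cast to bool failed: treat as unequal.
        return False


class RaisesOnEq:
    """Toy object whose comparison fails, to exercise the fallback branch."""

    def __eq__(self, other: Any) -> bool:
        raise ValueError("cannot compare")
```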
- ax.utils.common.equality.object_attribute_dicts_equal(one_dict: dict[str, Any], other_dict: dict[str, Any], skip_db_id_check: bool = False) → bool
Utility to check if all items in attribute dicts of two Ax objects are the same.
NOTE: Special-cases some Ax object attributes, like “_experiment”, where full equality is hard to check.
- Parameters:
one_dict – First object’s attribute dict (obj.__dict__).
other_dict – Second object’s attribute dict (obj.__dict__).
skip_db_id_check – If True, will exclude the db_id attributes from the equality check. Useful for ensuring that all attributes of an object are equal except the ids with which one or both of them were saved to the database (e.g. when comparing an object before it was saved to the version reloaded from the DB).
- ax.utils.common.equality.object_attribute_dicts_find_unequal_fields(one_dict: dict[str, Any], other_dict: dict[str, Any], fast_return: bool = True, skip_db_id_check: bool = False) → tuple[dict[str, tuple[Any, Any]], dict[str, tuple[Any, Any]]]
Utility for finding out what attributes of two objects’ attribute dicts are unequal.
- Parameters:
one_dict – First object’s attribute dict (obj.__dict__).
other_dict – Second object’s attribute dict (obj.__dict__).
fast_return – Boolean representing whether to return as soon as a single unequal attribute is found, or to iterate over all attributes and collect all unequal ones.
skip_db_id_check – If True, will exclude the db_id attributes from the equality check. Useful for ensuring that all attributes of an object are equal except the ids with which one or both of them were saved to the database (e.g. when comparing an object before it was saved to the version reloaded from the DB).
- Returns:
Two dictionaries:
  - attribute name to attribute values of unequal type (as a tuple),
  - attribute name to attribute values of unequal value (as a tuple).
- Return type:
tuple[dict[str, tuple[Any, Any]], dict[str, tuple[Any, Any]]]
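A sketch of the two-dictionary split, with hypothetical names. Assumptions: the database id lives under a key named "db_id" or "_db_id", and keys present only in other_dict are ignored. The Ax version also special-cases attributes such as "_experiment", which is omitted here.

```python
from typing import Any


def find_unequal_fields_sketch(
    one_dict: dict[str, Any],
    other_dict: dict[str, Any],
    fast_return: bool = True,
    skip_db_id_check: bool = False,
) -> tuple[dict[str, tuple[Any, Any]], dict[str, tuple[Any, Any]]]:
    unequal_type: dict[str, tuple[Any, Any]] = {}
    unequal_value: dict[str, tuple[Any, Any]] = {}
    for field, one_val in one_dict.items():
        # Assumed key names for the database id.
        if skip_db_id_check and field in ("db_id", "_db_id"):
            continue
        other_val = other_dict.get(field)
        if type(one_val) is not type(other_val):
            unequal_type[field] = (one_val, other_val)
        elif one_val != other_val:
            unequal_value[field] = (one_val, other_val)
        else:
            continue  # equal: keep scanning
        if fast_return:
            break  # stop at the first mismatch
    return unequal_type, unequal_value


one = {"name": "exp", "_db_id": 1, "n": 3}
other = {"name": "exp", "_db_id": 2, "n": 3.0}
types_only = find_unequal_fields_sketch(one, other, fast_return=False, skip_db_id_check=True)
first_only = find_unequal_fields_sketch(one, other)
```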
- ax.utils.common.equality.same_elements(list1: list[Any], list2: list[Any]) → bool
Compare equality of two lists of core Ax objects.
- Assumptions:
  - The contents of each list are types that implement __eq__.
  - The lists do not contain duplicates.
Checking equality is then the same as checking that the lists are the same length and that each is a subset of the other.
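The length-plus-mutual-subset check can be sketched in a few lines. It relies only on __eq__ (through the in operator), not on hashing. The last example shows why the no-duplicates assumption matters: two different multisets pass the check. same_elements_sketch is a hypothetical name.

```python
from typing import Any


def same_elements_sketch(list1: list[Any], list2: list[Any]) -> bool:
    # Same length, and every element of each list appears in the other.
    if len(list1) != len(list2):
        return False
    return all(x in list2 for x in list1) and all(y in list1 for y in list2)


reordered = same_elements_sketch([1, 2, 3], [3, 1, 2])
different_lengths = same_elements_sketch([1, 2], [1, 2, 3])
# Duplicates break the check: these lists differ but still pass.
fooled_by_duplicates = same_elements_sketch([1, 1, 2], [1, 2, 2])
```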
Executils
- ax.utils.common.executils.allowed_to_fail(allowed_error_types: tuple[type[Exception], ...] | None = None, handler: Callable[[Exception], None] | None = None, logger: Logger | None = None) → Iterator[list[Exception | None]]
A context manager to catch and handle exceptions of a specific type.
- Parameters:
allowed_error_types (Tuple[Type[Exception], …]) – Types of exceptions to catch.
handler – An optional function to handle the exception. If not provided, the exception will be ignored. Defaults to None.
logger (Logger, optional) – An optional logger to log the exception.
- ax.utils.common.executils.execute_with_timeout(partial_func: Callable[[...], T], timeout: float) → T
Execute a function in a thread that we can abandon if it takes too long. The thread cannot actually be terminated, so it will keep executing after the timeout, but not on the main thread.
- Parameters:
partial_func – A partial function to execute. This should either be a function that takes no arguments, or a functools.partial function with all arguments bound.
timeout – The timeout in seconds.
- Returns:
The return value of the partial function when called.
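A sketch of the abandon-the-thread approach using concurrent.futures. On timeout it raises concurrent.futures.TimeoutError; the exception that Ax raises may differ. The helper name is hypothetical.

```python
import concurrent.futures
import functools
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def execute_with_timeout_sketch(partial_func: Callable[[], T], timeout: float) -> T:
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(partial_func)
    try:
        # Raises concurrent.futures.TimeoutError if the call is not done in time.
        return future.result(timeout=timeout)
    finally:
        # Don't block on the worker: a timed-out call keeps running in the
        # background, since Python threads cannot be killed.
        executor.shutdown(wait=False)


answer = execute_with_timeout_sketch(functools.partial(pow, 2, 10), timeout=1.0)

try:
    execute_with_timeout_sketch(functools.partial(time.sleep, 0.2), timeout=0.01)
    timed_out = False
except concurrent.futures.TimeoutError:
    timed_out = True
```

Because the worker is abandoned rather than stopped, any side effects of a slow call can still happen after the caller has moved on.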
- ax.utils.common.executils.handle_exceptions_in_retries(no_retry_exceptions: tuple[type[Exception], ...], retry_exceptions: tuple[type[Exception], ...], suppress_errors: bool, check_message_contains: list[str] | None, last_retry: bool, logger: Logger | None, wrap_error_message_in: str | None) → Generator[None, None, None]
- ax.utils.common.executils.retry_on_exception(exception_types: tuple[type[Exception], ...] | None = None, no_retry_on_exception_types: tuple[type[Exception], ...] | None = None, check_message_contains: list[str] | None = None, retries: int = 3, suppress_all_errors: bool = False, logger: Logger | None = None, default_return_on_suppression: Any | None = None, wrap_error_message_in: str | None = None, initial_wait_seconds: int | None = None) → Any | None
A decorator for instance methods or standalone functions that makes them retry on failure, and that lets you specify which types of exceptions the function should and should not retry on.
NOTE: If the argument suppress_all_errors is supplied and set to True, the error will be suppressed and the default value returned.
- Parameters:
exception_types – A tuple of exception types to catch in the decorated function. If none is provided, the base class Exception will be used.
no_retry_on_exception_types – Exception types to treat as non-retryable, even if a supertype of theirs appears in exception_types. If no exception_types are specified, these are the only exceptions that are not retried on.
check_message_contains – A list of strings, against which to match error messages. If the error message contains any one of these strings, the exception will cause a retry. NOTE: This argument works in addition to exception_types; if those are specified, only the specified types of exceptions will be caught and retried on if they contain the strings provided as check_message_contains.
retries – Number of retries to perform.
suppress_all_errors – If True, after all the retries are exhausted, the error will still be suppressed and default_return_on_suppression will be returned from the function. NOTE: If using this argument, the decorated function may not actually get fully executed if it consistently raises an exception.
logger – A handle for the logger to be used.
default_return_on_suppression – If the error is suppressed after all the retries, then this default value will be returned from the function. Defaults to None.
wrap_error_message_in – If raising the error message after all the retries, a string wrapper for the error message (useful for making error messages more user-friendly). NOTE: Format of resulting error will be: “<wrap_error_message_in>: <original_error_type>: <original_error_msg>”, with the stack trace of the original message.
initial_wait_seconds – Initial length of time to wait between failures, doubled after each failure up to a maximum of 10 minutes. If unspecified then there is no wait between retries.
FuncEnum
- class ax.utils.common.func_enum.FuncEnum(new_class_name, /, names, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases:
Enum
A base class for all enums with the following structure: string values that map to names of functions, which reside in the same module as the enum.
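A sketch of the value-names-a-function pattern. It assumes members are made callable and that the function is looked up by name in the enum's own module, here through globals(). How FuncEnum actually resolves and invokes the function is not described above. MathOp, double, and square are hypothetical.

```python
from enum import Enum


def double(x: int) -> int:
    return 2 * x


def square(x: int) -> int:
    return x * x


class MathOp(Enum):
    """Hypothetical FuncEnum-style enum: each value names a function in this module."""

    DOUBLE = "double"
    SQUARE = "square"

    def __call__(self, *args, **kwargs):
        # globals() is the namespace of the module where this enum is defined.
        return globals()[self.value](*args, **kwargs)
```

Storing function names as strings keeps the enum values serializable, which is a common reason for this pattern.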
Kwargs
- ax.utils.common.kwargs.consolidate_kwargs(kwargs_iterable: Iterable[Mapping[str, Any] | None], keywords: Iterable[str]) → dict[str, Any]
Combine an iterable of kwargs into a single dict of kwargs. When a key appears more than once, the value that appears later in the iterable takes priority over earlier ones, and only kwargs whose keys are listed in keywords are used. This allows combining somewhat redundant sets of kwargs where, for instance, a user-set kwarg needs to override a default kwarg.
>>> consolidate_kwargs(
...     kwargs_iterable=[{'a': 1, 'b': 2}, {'b': 3, 'c': 4, 'd': 5}],
...     keywords=['a', 'b', 'd']
... )
{'a': 1, 'b': 3, 'd': 5}
- ax.utils.common.kwargs.filter_kwargs(function: Callable, **kwargs: Any) → Any
Filter out kwargs that are not applicable for a given function. Return a copy of the given kwargs dict containing only the kwargs the function accepts.
- ax.utils.common.kwargs.get_function_argument_names(function: Callable, omit: list[str] | None = None) → list[str]
Extract parameter names from function signature.
- ax.utils.common.kwargs.get_function_default_arguments(function: Callable) → dict[str, Any]
Extract default arguments from function signature.
- ax.utils.common.kwargs.warn_on_kwargs(callable_with_kwargs: Callable, **kwargs: Any) → None
Log a warning when a decoder function receives unexpected kwargs.
NOTE: This mainly caters to the use case where an older version of Ax is used to decode objects serialized to JSON by a newer version of Ax (and therefore potentially containing new fields). In that case, the decoding function should not fail when encountering those additional fields, but rather ignore them and log a warning using this function.
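A sketch of the forward-compatible decoding pattern. The decoder accepts **kwargs so that fields added by a newer version do not cause a TypeError, then reports them. warn_on_kwargs_sketch, metric_from_json, the "new_field" key, and the log message wording are all hypothetical.

```python
import logging
from typing import Any, Callable

logger = logging.getLogger("ax_sketch.decoder")


def warn_on_kwargs_sketch(callable_with_kwargs: Callable, **kwargs: Any) -> None:
    if kwargs:
        name = getattr(callable_with_kwargs, "__name__", repr(callable_with_kwargs))
        logger.warning("Ignoring unexpected kwargs %s passed to %s.", sorted(kwargs), name)


def metric_from_json(name: str, lower_is_better: bool = False, **kwargs: Any) -> dict[str, Any]:
    # Unknown fields from a newer serializer land in **kwargs instead of failing.
    warn_on_kwargs_sketch(metric_from_json, **kwargs)
    return {"name": name, "lower_is_better": lower_is_better}


decoded = metric_from_json(**{"name": "accuracy", "lower_is_better": False, "new_field": 1})
```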
Logger
- class ax.utils.common.logger.AxOutputNameFilter(name='')
Bases:
Filter
A filter that sets the record’s output_name if it is not already configured.
- ax.utils.common.logger.build_file_handler(filepath: str, level: int = 20) → StreamHandler[Any]
Build a file handler that logs entries to the given file, using the same formatting as the stream handler.
- Parameters:
filepath – Location of the file to log output to. If the file exists, output will be appended. If it does not exist, a new file will be created.
level – The log level. By default, sets level to INFO
- Returns:
A logging.FileHandler instance
- ax.utils.common.logger.build_stream_handler(level: int = 20) → StreamHandler[Any]
Build the default stream handler used for most Ax logging. Sets default level to INFO, instead of WARNING.
- Parameters:
level – The log level. By default, sets level to INFO
- Returns:
A logging.StreamHandler instance
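Both builders can be sketched with the standard logging module. The key points are that one Formatter is shared and that level 20 is logging.INFO. The format string and the helper names are assumptions, and the actual Ax format may differ. logging.FileHandler subclasses StreamHandler, which is consistent with the documented return type.

```python
import logging
import os
import tempfile

# Assumed format string; the actual Ax format may differ.
FORMATTER = logging.Formatter("[%(levelname)s %(asctime)s] %(name)s: %(message)s")


def build_stream_handler_sketch(level: int = logging.INFO) -> logging.StreamHandler:
    handler = logging.StreamHandler()
    handler.setFormatter(FORMATTER)
    handler.setLevel(level)
    return handler


def build_file_handler_sketch(filepath: str, level: int = logging.INFO) -> logging.StreamHandler:
    # mode="a" appends if the file exists and creates it otherwise.
    handler = logging.FileHandler(filepath, mode="a")
    handler.setFormatter(FORMATTER)
    handler.setLevel(level)
    return handler


stream_handler = build_stream_handler_sketch()

log_path = os.path.join(tempfile.mkdtemp(), "ax_sketch.log")
file_logger = logging.getLogger("ax_sketch.file")
file_logger.setLevel(logging.DEBUG)
file_logger.propagate = False
file_handler = build_file_handler_sketch(log_path)
file_logger.addHandler(file_handler)
file_logger.info("kept")
file_logger.debug("dropped by the handler's INFO level")
file_handler.close()
with open(log_path) as f:
    contents = f.read()
```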
- class ax.utils.common.logger.disable_logger(name: str, level: int = 40)
Bases:
ClassDecorator
- class ax.utils.common.logger.disable_loggers(names: list[str], level: int = 40)
Bases:
ClassDecorator
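The default level 40 is logging.ERROR, so only errors and above get through while the decorated code runs. Below is a function-only sketch of that idea: it raises the named logger's level for the duration of a call and restores it afterwards. The real class also works as a ClassDecorator on whole classes. disable_logger_sketch and the logger names are hypothetical.

```python
import functools
import logging
from typing import Any, Callable


def disable_logger_sketch(name: str, level: int = logging.ERROR) -> Callable:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            target = logging.getLogger(name)
            previous = target.level
            target.setLevel(level)  # silence anything below `level`
            try:
                return func(*args, **kwargs)
            finally:
                target.setLevel(previous)  # always restore, even on error

        return wrapper

    return decorator


noisy = logging.getLogger("ax_sketch.noisy")
noisy.setLevel(logging.INFO)


@disable_logger_sketch("ax_sketch.noisy")
def quiet_work() -> int:
    return noisy.level


level_inside = quiet_work()
```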
- ax.utils.common.logger.get_logger(name: str, level: int = 20, force_name: bool = False) → Logger
Get an Ax logger.
To set a human-readable “output_name” that appears in logger outputs, add {"output_name": "[MY_OUTPUT_NAME]"} to the logger’s contextual information. By default, the logger’s name is used.
NOTE: To change the log level on particular outputs (e.g. STDERR logs), set the proper log level on the relevant handler instead of on the logger, e.g. logger.handlers[0].setLevel(INFO).
- Parameters:
name – The name of the logger.
level – The level at which to actually log. Logs below this level of importance will be discarded.
force_name – If set to False and the module specified by name is not ultimately a descendant of the ax module, “ax.” will be prepended to name.
- Returns:
The logging.Logger object.
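In the standard library, "contextual information" is typically supplied through logging.LoggerAdapter's extra dict. A sketch of how that combines with a filter that mirrors the described AxOutputNameFilter behavior (fill in output_name from the logger name when it is missing) follows. The filter class, the format string, and the logger names are assumptions.

```python
import io
import logging


class OutputNameFilterSketch(logging.Filter):
    """Hypothetical stand-in: default output_name to the logger's name."""

    def filter(self, record: logging.LogRecord) -> bool:
        if not hasattr(record, "output_name"):
            record.output_name = record.name
        return True


stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.addFilter(OutputNameFilterSketch())
handler.setFormatter(logging.Formatter("%(output_name)s: %(message)s"))

base_logger = logging.getLogger("ax_sketch.named")
base_logger.setLevel(logging.INFO)
base_logger.propagate = False
base_logger.addHandler(handler)

base_logger.info("default name")
# The adapter injects output_name into every record it creates.
logging.LoggerAdapter(base_logger, {"output_name": "MY_OUTPUT_NAME"}).info("custom name")
lines = stream.getvalue().splitlines()
```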
- ax.utils.common.logger.make_indices_str(indices: Iterable[int]) → str
Generate a string representation of an iterable of indices. If the indices are contiguous, returns a string formatted like ‘<min_idx> - <max_idx>’; otherwise, returns a string formatted like ‘[idx_1, idx_2, …, idx_n]’.
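A sketch of the two output formats. It assumes the indices are sorted before the contiguity check, which the description does not specify. The handling of empty input is also a guess. make_indices_str_sketch is a hypothetical name.

```python
from typing import Iterable


def make_indices_str_sketch(indices: Iterable[int]) -> str:
    idx = sorted(indices)  # assumption: order does not matter
    if idx and idx == list(range(idx[0], idx[-1] + 1)):
        return f"{idx[0]} - {idx[-1]}"  # contiguous run
    return str(idx)  # e.g. "[0, 2, 5]"


contiguous = make_indices_str_sketch([0, 1, 2, 3])
unordered_contiguous = make_indices_str_sketch([3, 1, 2])
gappy = make_indices_str_sketch([0, 2, 5])
```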
Mock Torch
- ax.utils.common.mock.mock_patch_method_original(mock_path: str, original_method: Callable[[...], T]) → MagicMock
Context manager for patching a method returning type T on class C, to track calls to it while still executing the original method. There is not a native way to do this with mock.patch.
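One standard-library way to get a call-tracking patch that still runs the original method combines autospec=True with side_effect set to the original. With autospec, the replacement behaves like a function on the class, so self is passed through. Ax's helper takes a dotted mock_path. This hypothetical variant takes the class and attribute name instead, so the example runs without an importable module path.

```python
from contextlib import contextmanager
from typing import Any, Iterator
from unittest import mock


@contextmanager
def patch_method_keep_original(cls: type, name: str) -> Iterator[Any]:
    original = getattr(cls, name)
    # autospec=True binds like a real method, so side_effect gets `self` too.
    with mock.patch.object(cls, name, autospec=True, side_effect=original) as patched:
        yield patched


class Calculator:
    def add(self, a: int, b: int) -> int:
        return a + b


calc = Calculator()
with patch_method_keep_original(Calculator, "add") as patched:
    result = calc.add(2, 3)  # runs the real add and records the call
patched.assert_called_once_with(calc, 2, 3)
```

A plain MagicMock with wraps= would not receive self when called through an instance, which is why autospec is used here.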
Random
- ax.utils.common.random.set_rng_seed(seed: int) → None
Sets seeds for random number generators from numpy, pytorch, and the native random module.
- Parameters:
seed – The random number generator seed.
- ax.utils.common.random.with_rng_seed(seed: int | None) → Generator[None, None, None]
Context manager that sets the random number generator seeds to a given value and restores the previous state on exit.
If the seed is None, the context manager does nothing. This makes it possible to use the context manager without having to change the code based on whether the seed is specified.
- Parameters:
seed – The random number generator seed.
Result
- class ax.utils.common.result.Err(value: E)
Bases:
Generic[T, E], Result[T, E]
Contains the error value.
- property err: E
- map(op: Callable[[T], U]) → Result[U, E]
Maps a Result[T, E] to Result[U, E] by applying a function to a contained Ok value, leaving an Err value untouched. This function can be used to compose the results of two functions.
- map_err(op: Callable[[E], F]) → Result[T, F]
Maps a Result[T, E] to Result[T, F] by applying a function to a contained Err value, leaving an Ok value untouched. This function can be used to pass through a successful result while handling an error.
- map_or(default: U, op: Callable[[T], U]) → U
Returns the provided default (if Err), or applies a function to the contained value (if Ok).
- map_or_else(default_op: Callable[[], U], op: Callable[[T], U]) → U
Maps a Result[T, E] to U by applying the fallback function default_op to a contained Err value, or the function op to a contained Ok value. This function can be used to unpack a successful result while handling an error.
- unwrap() → NoReturn
Returns the contained Ok value.
Because this function may raise a RuntimeError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_err() → E
Returns the contained Err value.
Because this function may raise a RuntimeError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_or_else(op: Callable[[E], T]) → T
Returns the contained Ok value or computes it from a Callable.
- property value: E
- class ax.utils.common.result.ExceptionE(message: str, exception: Exception)
Bases:
object
A class that holds an Exception and can be used as the E type in Result[T, E].
- class ax.utils.common.result.Ok(value: T)
Bases:
Generic[T, E], Result[T, E]
Contains the success value.
- map(op: Callable[[T], U]) → Result[U, E]
Maps a Result[T, E] to Result[U, E] by applying a function to a contained Ok value, leaving an Err value untouched. This function can be used to compose the results of two functions.
- map_err(op: Callable[[E], F]) → Result[T, F]
Maps a Result[T, E] to Result[T, F] by applying a function to a contained Err value, leaving an Ok value untouched. This function can be used to pass through a successful result while handling an error.
- map_or(default: U, op: Callable[[T], U]) → U
Returns the provided default (if Err), or applies a function to the contained value (if Ok).
- map_or_else(default_op: Callable[[], U], op: Callable[[T], U]) → U
Maps a Result[T, E] to U by applying the fallback function default_op to a contained Err value, or the function op to a contained Ok value. This function can be used to unpack a successful result while handling an error.
- property ok: T
- unwrap() → T
Returns the contained Ok value.
Because this function may raise a RuntimeError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_err() → NoReturn
Returns the contained Err value.
Because this function may raise a RuntimeError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_or_else(op: Callable[[E], T]) → T
Returns the contained Ok value or computes it from a Callable.
- property value: T
- class ax.utils.common.result.Result
A minimal implementation of a rusty Result monad. See https://doc.rust-lang.org/std/result/enum.Result.html for more information.
- abstractmethod map(op: Callable[[T], U]) → Result[U, E]
Maps a Result[T, E] to Result[U, E] by applying a function to a contained Ok value, leaving an Err value untouched. This function can be used to compose the results of two functions.
- abstractmethod map_err(op: Callable[[E], F]) → Result[T, F]
Maps a Result[T, E] to Result[T, F] by applying a function to a contained Err value, leaving an Ok value untouched. This function can be used to pass through a successful result while handling an error.
- abstractmethod map_or(default: U, op: Callable[[T], U]) → U
Returns the provided default (if Err), or applies a function to the contained value (if Ok).
- abstractmethod map_or_else(default_op: Callable[[], U], op: Callable[[T], U]) → U
Maps a Result[T, E] to U by applying the fallback function default_op to a contained Err value, or the function op to a contained Ok value. This function can be used to unpack a successful result while handling an error.
- abstractmethod unwrap() → T
Returns the contained Ok value.
Because this function may raise a RuntimeError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- abstractmethod unwrap_err() → E
Returns the contained Err value.
Because this function may raise a RuntimeError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- abstractmethod unwrap_or(default: U) → T | U
Returns the contained Ok value or a provided default.
- abstractmethod unwrap_or_else(op: Callable[[E], T]) → T
Returns the contained Ok value or computes it from a Callable.
- abstract property value: T | E
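A stripped-down sketch of the Ok/Err split. It implements only map, map_err, unwrap_or, and unwrap, and stores the payload as a plain value attribute. Err.unwrap raises RuntimeError, consistent with the warning above. All class names are local to the example, and parse_int is hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Callable, Generic, NoReturn, TypeVar

T = TypeVar("T")
E = TypeVar("E")
U = TypeVar("U")
F = TypeVar("F")


class Result(ABC, Generic[T, E]):
    @abstractmethod
    def map(self, op: Callable[[T], U]) -> "Result[U, E]": ...

    @abstractmethod
    def map_err(self, op: Callable[[E], F]) -> "Result[T, F]": ...

    @abstractmethod
    def unwrap_or(self, default: U) -> "T | U": ...

    @abstractmethod
    def unwrap(self) -> T: ...


class Ok(Result[T, E]):
    def __init__(self, value: T) -> None:
        self.value = value

    def map(self, op):
        return Ok(op(self.value))

    def map_err(self, op):
        return self  # a success passes through error handling untouched

    def unwrap_or(self, default):
        return self.value

    def unwrap(self):
        return self.value


class Err(Result[T, E]):
    def __init__(self, value: E) -> None:
        self.value = value

    def map(self, op):
        return self  # an error skips the success-path transformation

    def map_err(self, op):
        return Err(op(self.value))

    def unwrap_or(self, default):
        return default

    def unwrap(self) -> NoReturn:
        raise RuntimeError(f"Called unwrap on Err: {self.value!r}")


def parse_int(text: str) -> Result[int, str]:
    try:
        return Ok(int(text))
    except ValueError:
        return Err(f"not an int: {text!r}")


doubled = parse_int("21").map(lambda x: x * 2).unwrap_or(0)
fallback = parse_int("abc").map(lambda x: x * 2).unwrap_or(0)
message = parse_int("abc").map_err(str.upper).value
```

Chaining map calls keeps the error path implicit until a single unwrap_or at the end. This is the main reason to prefer Result over raising exceptions in composed pipelines.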
Serialization
- class ax.utils.common.serialization.SerializationMixin
Bases:
object
Base class for Ax objects that define their JSON serialization and deserialization logic at the class level, most commonly Runner and Metric subclasses.
NOTE: Using this class for Ax objects that receive other Ax objects as inputs is recommended only if the parent object (the one that would inherit from this base class) is not enrolled in CORE_ENCODER/DECODER_REGISTRY. Inheriting from this mixin with an Ax object that is in CORE_ENCODER/DECODER_REGISTRY will result in a circular dependency, so such classes should implement their encoding and decoding logic within the json_store module rather than on the classes.
For example, TransitionCriterion classes take TrialStatus as inputs and are defined in the CORE_ENCODER/DECODER_REGISTRY, so TransitionCriterion should not inherit from SerializationMixin and should define custom encoding/decoding logic within the json_store module.
- classmethod deserialize_init_args(args: dict[str, Any], decoder_registry: dict[str, type[T] | Callable[[...], T]] | None = None, class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]] | None = None) → dict[str, Any]
Given a dictionary, deserialize the properties needed to initialize the object. Used for storage.
- ax.utils.common.serialization.extract_init_args(args: dict[str, Any], class_: type[Any]) → dict[str, Any]
Given a dictionary, extract the arguments required for the given class’s constructor.
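A sketch of constructor-argument extraction using inspect.signature. It keeps only the entries that name parameters of class_.__init__, so storage-only fields are dropped before reconstruction. The "__type" key is a hypothetical example of such a field, and the Ax version may additionally validate that required arguments are present. extract_init_args_sketch and MetricSketch are hypothetical names.

```python
import inspect
from typing import Any


def extract_init_args_sketch(args: dict[str, Any], class_: type[Any]) -> dict[str, Any]:
    params = inspect.signature(class_.__init__).parameters
    # Drop keys that are not constructor parameters (and never pass `self`).
    return {k: v for k, v in args.items() if k in params and k != "self"}


class MetricSketch:
    """Hypothetical stand-in for a Metric subclass."""

    def __init__(self, name: str, lower_is_better: bool = False) -> None:
        self.name = name
        self.lower_is_better = lower_is_better


stored = {"name": "accuracy", "lower_is_better": False, "__type": "MetricSketch"}
init_args = extract_init_args_sketch(stored, MetricSketch)
metric = MetricSketch(**init_args)
```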
Testutils
Support functions for tests
- class ax.utils.common.testutils.TestCase(methodName: str = 'runTest')
Bases:
TestCase
The base Ax test case, which contains various helper functions for writing unit tests.
- assertAllClose(input: Any, other: Any, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) → None
Assert that two tensors are close.
Calls torch.testing.assert_close, using the signature and default behavior of torch.allclose.
The formula asserted is abs(input - other) <= atol + rtol * abs(other).
- Parameters:
input – First tensor or tensor-or-scalar-like to compare
other – Second tensor or tensor-or-scalar-like to compare
rtol – Relative tolerance
atol – Absolute tolerance
equal_nan – If True, consider NaN values as equal
- Example output:
AssertionError: Scalars are not close!
Absolute difference: 1.0000034868717194 (up to 0.0001 allowed)
Relative difference: 0.8348668001940709 (up to 1e-05 allowed)
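The asserted rule abs(input - other) <= atol + rtol * abs(other) is asymmetric: the relative tolerance scales with other only. A scalar sketch makes this concrete. allclose_sketch is a hypothetical helper that mirrors the parameter names above.

```python
def allclose_sketch(input: float, other: float, rtol: float = 1e-05, atol: float = 1e-08) -> bool:
    # torch.allclose's elementwise rule; `input` mirrors the documented parameter name.
    return abs(input - other) <= atol + rtol * abs(other)


near = allclose_sketch(1.0, 1.000001)  # difference 1e-6 is within about 1e-5
far = allclose_sketch(2.0, 1.0)
# Swapping the arguments changes the outcome when one side is zero.
scaled_by_other = allclose_sketch(0.0, 1e-5, rtol=1.0, atol=0.0)
swapped = allclose_sketch(1e-5, 0.0, rtol=1.0, atol=0.0)
```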
- assertAxBaseEqual(first: Base, second: Base, msg: str | None = None, skip_db_id_check: bool = False) → None
Check that two Ax objects that subclass Base are equal, or raise an assertion error otherwise.
- Parameters:
first – Base-subclassing object to compare to second.
second – Base-subclassing object to compare to first.
msg – Message to put into the assertion error raised on inequality; if not specified, a default message is used.
skip_db_id_check – If True, will exclude the db_id attributes from the equality check. Useful for ensuring that all attributes of an object are equal except the ids with which one or both of them were saved to the database (e.g. when comparing an object before it was saved to the version reloaded from the DB).
- assertDictsAlmostEqual(a: dict[str, Any], b: dict[str, Any], consider_nans_equal: bool = False) → None
Testing utility that checks that 1) the keys of a and b are identical, and 2) the values of a and b are almost equal if they have a floating point type (treating NaNs as equal when consider_nans_equal is True), and otherwise exactly equal.
- Parameters:
test – The test case object.
a – A dictionary.
b – Another dictionary.
consider_nans_equal – Whether to consider NaNs equal when comparing floating point numbers.
- assertEqual(first: Any, second: Any, msg: str | None = None) → None
Fail if the two objects are unequal as determined by the ‘==’ operator.
- assertIsSubDict(subdict: dict[str, Any], superdict: dict[str, Any], almost_equal: bool = False, consider_nans_equal: bool = False) → None
Testing utility that checks that all keys and values of subdict are contained in superdict.
- Parameters:
subdict – A smaller dictionary.
superdict – A larger dictionary which should contain all keys of subdict and the same values as subdict for the corresponding keys.
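A boolean sketch of the sub-dict check, including the almost_equal and consider_nans_equal flags. Assumption: "almost equal" follows unittest's assertAlmostEqual default of agreement to 7 decimal places. The tolerance that Ax actually uses is not stated above. The helper names are hypothetical.

```python
import math
from typing import Any


def _values_match(x: Any, y: Any, almost_equal: bool, consider_nans_equal: bool) -> bool:
    if isinstance(x, float) and isinstance(y, float):
        if math.isnan(x) or math.isnan(y):
            # NaN != NaN under ==, so NaNs only match when explicitly allowed.
            return consider_nans_equal and math.isnan(x) and math.isnan(y)
        if almost_equal:
            return round(x - y, 7) == 0  # assumed tolerance: 7 decimal places
    return x == y


def is_sub_dict_sketch(
    subdict: dict[str, Any],
    superdict: dict[str, Any],
    almost_equal: bool = False,
    consider_nans_equal: bool = False,
) -> bool:
    return all(
        key in superdict
        and _values_match(value, superdict[key], almost_equal, consider_nans_equal)
        for key, value in subdict.items()
    )


nan = float("nan")
```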
- assertRaisesOn(exc: type[Exception], line: str | None = None, regex: str | None = None) → _AssertRaisesContextOn
Assert that an exception is raised on a specific line.
Timeutils
- ax.utils.common.timeutils.current_timestamp_in_millis() → int
Grab current timestamp in milliseconds as an int.
Typeutils
- ax.utils.common.typeutils.assert_is_instance_dict(d: dict[X, Y], key_type: type[K], val_type: type[V]) → dict[K, V]
Asserts that all keys and values in the dictionary are instances of the given classes.
- Parameters:
d – the dictionary to check
key_type – the type to check against for keys
val_type – the type to check against for values
- Returns:
the d argument, unchanged
- ax.utils.common.typeutils.assert_is_instance_list(old_l: list[V], typ: type[T]) → list[T]
Asserts that all items in a list are instances of the given type.
- Parameters:
old_l – the list to check
typ – the type to check against
- Returns:
the old_l argument, unchanged
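Both helpers return their argument unchanged. Their value is that the return annotation narrows the static type once the runtime check has passed. A sketch follows. Raising TypeError, rather than using assert, is an assumption chosen so the check survives python -O, and the helper names are hypothetical.

```python
from typing import Any, TypeVar

K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")


def assert_is_instance_list_sketch(old_l: list[Any], typ: type[T]) -> list[T]:
    for item in old_l:
        if not isinstance(item, typ):
            raise TypeError(f"{item!r} is not an instance of {typ.__name__}.")
    return old_l  # same object, now typed as list[T]


def assert_is_instance_dict_sketch(
    d: dict[Any, Any], key_type: type[K], val_type: type[V]
) -> dict[K, V]:
    for key, val in d.items():
        if not isinstance(key, key_type) or not isinstance(val, val_type):
            raise TypeError(
                f"Entry {key!r}: {val!r} is not ({key_type.__name__}, {val_type.__name__})."
            )
    return d


names = ["a", "b"]
counts = {"a": 1, "b": 2}
try:
    assert_is_instance_list_sketch([1, "b"], str)
    rejected = False
except TypeError:
    rejected = True
```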
Typeutils Non-Native
Typeutils Torch
Measurement
Synthetic Functions
- class ax.utils.measurement.synthetic_functions.Aug_Branin
Bases:
SyntheticFunction
Augmented Branin function (3-dimensional with infinitely many global minima).
- class ax.utils.measurement.synthetic_functions.Aug_Hartmann6
Bases:
Hartmann6
Augmented Hartmann6 function (7-dimensional with 1 global minimum).
- class ax.utils.measurement.synthetic_functions.Branin
Bases:
SyntheticFunction
Branin function (2-dimensional with 3 global minima).
- class ax.utils.measurement.synthetic_functions.FromBotorch(botorch_synthetic_function: SyntheticTestFunction)
Bases:
SyntheticFunction
- class ax.utils.measurement.synthetic_functions.Hartmann6
Bases:
SyntheticFunction
Hartmann6 function (6-dimensional with 1 global minimum).
- class ax.utils.measurement.synthetic_functions.SyntheticFunction
Bases:
ABC
- property domain: list[tuple[float, float]]
Domain on which function is evaluated.
The list is of the same length as the dimensionality of the inputs, where each element of the list is a tuple corresponding to the min and max of the domain for that dimension.
- f(X: ndarray[tuple[Any, ...], dtype[_ScalarT]]) → float | ndarray[tuple[Any, ...], dtype[_ScalarT]]
Synthetic function implementation.
- Parameters:
X (numpy.ndarray) – an n by d array, where n represents the number of observations and d is the dimensionality of the inputs.
- Returns:
an n-dimensional array.
- Return type:
numpy.ndarray
- property maximums: list[tuple[float, ...]]
List of global maximums.
Each element of the list is a d-tuple, where d is the dimensionality of the inputs. There may be more than one global maximum.
- ax.utils.measurement.synthetic_functions.from_botorch(botorch_synthetic_function: SyntheticTestFunction) → SyntheticFunction
Utility to generate Ax synthetic functions from BoTorch synthetic functions.
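A sketch of the SyntheticFunction shape (a domain, known optima, and an f that maps n rows of d inputs to n values), using the standard Branin function with its usual constants. It uses plain Python lists instead of NumPy arrays and defines only a minimums property. SyntheticFunctionSketch and BraninSketch are hypothetical names. At each of the three listed minimizers the value should be about 0.397887.

```python
import math
from abc import ABC, abstractmethod


class SyntheticFunctionSketch(ABC):
    @property
    @abstractmethod
    def domain(self) -> list[tuple[float, float]]: ...

    @property
    @abstractmethod
    def minimums(self) -> list[tuple[float, ...]]: ...

    @abstractmethod
    def _f(self, x: tuple[float, ...]) -> float: ...

    def f(self, X: list[tuple[float, ...]]) -> list[float]:
        # Evaluate each of the n rows of an n-by-d input.
        return [self._f(x) for x in X]


class BraninSketch(SyntheticFunctionSketch):
    @property
    def domain(self) -> list[tuple[float, float]]:
        return [(-5.0, 10.0), (0.0, 15.0)]

    @property
    def minimums(self) -> list[tuple[float, ...]]:
        return [(-math.pi, 12.275), (math.pi, 2.275), (9.42478, 2.475)]

    def _f(self, x: tuple[float, ...]) -> float:
        x1, x2 = x
        b = 5.1 / (4 * math.pi**2)
        c = 5 / math.pi
        t = 1 / (8 * math.pi)
        return (x2 - b * x1**2 + c * x1 - 6) ** 2 + 10 * (1 - t) * math.cos(x1) + 10


branin = BraninSketch()
values = branin.f(branin.minimums)
```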
Notebook
Plotting
Report
Render
- ax.utils.report.render.link_html(text: str, href: str) → str
Embed text and reference address into link tag.
- ax.utils.report.render.render_report_elements(experiment_name: str, html_elements: list[str], header: bool = True, offline: bool = False, notebook_env: bool = False) → str
Generate Ax HTML report for a given experiment from HTML elements.
Uses Jinja2 for template. Injects Plotly JS for graph rendering.
Example:
    html_elements = [
        h2_html("Subsection with plot"),
        p_html("This is an example paragraph."),
        plot_html(plot_fitted(gp_model, 'perf_metric')),
        h2_html("Subsection with table"),
        pandas_html(data.df),
    ]
    html = render_report_elements('My experiment', html_elements)
- Parameters:
experiment_name – the name of the experiment to use for title.
html_elements – list of HTML strings to render in report body.
header – if True, render experiment title as a header. Meant to be used for standalone reports (e.g. via email), as opposed to served on the front-end.
offline – if True, entire Plotly library is bundled with report.
notebook_env – if True, caps the report width to 700px for viewing in a notebook environment.
- Returns:
HTML string.
- Return type:
str
- ax.utils.report.render.table_cell_html(text: str, width: str | None = None) → str
Embed text or an HTML element into a table cell tag.
- ax.utils.report.render.table_heading_cell_html(text: str) → str
Embed text or an HTML element into a table heading cell tag.
- ax.utils.report.render.table_html(table_rows: list[str]) → str
Embed a list of HTML elements into a table tag.
Sensitivity
Derivative GP
- ax.utils.sensitivity.derivative_gp.get_KXX_inv(gp: Model) → Tensor
Get the inverse matrix of K(X,X).
- Parameters:
gp – BoTorch model.
- Returns:
The inverse of K(X,X).
- ax.utils.sensitivity.derivative_gp.get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = 'rbf') → Tensor
Computes the analytic derivative of the kernel K(x,X) w.r.t. x.
- Parameters:
gp – BoTorch model.
x – (n x D) Test points.
kernel_type – Takes “rbf” or “matern”.
- Returns:
Tensor (n x D) The derivative of the kernel K(x,X) w.r.t. x.
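For the RBF kernel k(x, x') = exp(-||x - x'||^2 / (2 l^2)), the derivative with respect to the test point has the closed form -(x - x') / l^2 * k(x, x'). The sketch below shows this in plain Python for a single test point. It assumes one shared lengthscale l and an outputscale, and the helper names are hypothetical. The real function reads its hyperparameters from the BoTorch model, works on tensors, and also supports "matern".

```python
import math


def rbf_kernel(x, xp, lengthscale=1.0, outputscale=1.0):
    """RBF kernel k(x, x') with one shared lengthscale (an assumption)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, xp))
    return outputscale * math.exp(-0.5 * sq_dist / lengthscale**2)


def rbf_kernel_dx(x, X, lengthscale=1.0, outputscale=1.0):
    """Gradient of k(x, x') w.r.t. x for each training point x' in X.

    Returns one D-dimensional gradient per row of X (an N x D nested list).
    """
    grads = []
    for xp in X:
        k = rbf_kernel(x, xp, lengthscale, outputscale)
        # d/dx exp(-||x - x'||^2 / (2 l^2)) = -(x - x') / l^2 * k(x, x')
        grads.append([-(a - b) / lengthscale**2 * k for a, b in zip(x, xp)])
    return grads
```

A central finite difference of rbf_kernel is a quick way to check the closed form.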
- ax.utils.sensitivity.derivative_gp.get_Kxx_dx2(gp: Model, kernel_type: str = 'rbf') Tensor[source]
Computes the analytic second derivative of the kernel w.r.t. the training data.
- Parameters:
gp – BoTorch model.
kernel_type – Takes “rbf” or “matern”.
- Returns:
Tensor (n x D x D) The second derivative of the kernel w.r.t. the training data.
- ax.utils.sensitivity.derivative_gp.posterior_derivative(gp: Model, x: Tensor, kernel_type: str = 'rbf') MultivariateNormal[source]
Computes the posterior of the derivative of the GP w.r.t. the given test points x. This follows the derivation used by GIBO in Sarah Muller, Alexander von Rohr, Sebastian Trimpe. “Local policy search with Bayesian optimization”, Advances in Neural Information Processing Systems 34, NeurIPS 2021.
- Parameters:
gp – BoTorch model.
x – (n x D) Test points.
kernel_type – Takes “rbf” or “matern”.
- Returns:
A Botorch Posterior.
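The mean of the derivative posterior comes from differentiating the GP posterior mean k(x, X) K(X, X)^{-1} y with respect to x. This needs only the kernel derivative K(x, X) w.r.t. x and the inverse of K(X, X), which is applied here as a linear solve.

The sketch below is one-dimensional and assumes a zero-mean GP with an RBF kernel and a small fixed noise term. It computes only the mean, not the derivative covariance that posterior_derivative also returns. All names are hypothetical.

```python
import math


def rbf(a, b, ls):
    """1D RBF kernel with lengthscale ls and unit outputscale."""
    return math.exp(-0.5 * (a - b) ** 2 / ls**2)


def solve(A, y):
    """Solve A w = y by Gaussian elimination with partial pivoting."""
    n = len(y)
    M = [row[:] + [y[i]] for i, row in enumerate(A)]  # augmented matrix
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, n):
            f = M[r][c] / M[c][c]
            for k in range(c, n + 1):
                M[r][k] -= f * M[c][k]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        w[r] = (M[r][n] - sum(M[r][k] * w[k] for k in range(r + 1, n))) / M[r][r]
    return w


def gp_mean_and_derivative(x, X, y, ls=0.5, noise=1e-6):
    """Posterior mean and its x-derivative at a scalar x (zero-mean GP)."""
    K = [
        [rbf(a, b, ls) + (noise if i == j else 0.0) for j, b in enumerate(X)]
        for i, a in enumerate(X)
    ]
    alpha = solve(K, y)  # alpha = K(X, X)^{-1} y
    mean = sum(rbf(x, b, ls) * w for b, w in zip(X, alpha))
    # Differentiate k(x, X) alpha term by term: d/dx k(x, b) = -(x - b) / ls^2 * k(x, b).
    dmean = sum(-(x - b) / ls**2 * rbf(x, b, ls) * w for b, w in zip(X, alpha))
    return mean, dmean
```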
Derivative Measures
- class ax.utils.sensitivity.derivative_measures.GpDGSMGpMean(model: Model, bounds: Tensor, derivative_gp: bool = False, kernel_type: str | None = None, Y_scale: float = 1.0, num_mc_samples: int = 10000, input_qmc: bool = False, dtype: dtype = torch.float64, num_bootstrap_samples: int = 1, discrete_features: list[int] | None = None)[source]
Bases:
object
- gradient_absolute_measure() Tensor[source]
Computes the gradient absolute measure:
- Returns:
- if self.num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
- gradient_measure() Tensor[source]
Computes the gradient measure:
- Returns:
- if self.num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
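The sketch below illustrates the two aggregations. It assumes that gradient_measure averages the signed partial derivatives and that gradient_absolute_measure averages their absolute values, over inputs spread uniformly across the bounds. To stay self-contained, it uses a known analytic gradient in place of the GP posterior mean and a midpoint grid in place of Monte Carlo samples. dgsm_measures is a hypothetical name.

```python
import itertools


def dgsm_measures(grad_f, bounds, points_per_dim=50):
    """Average signed and absolute partial derivatives over a box.

    grad_f maps a point (list of floats) to its gradient (list of floats).
    A midpoint grid stands in for the Monte Carlo samples (an assumption).
    """
    lower, upper = bounds
    dim = len(lower)
    axes = [
        [lo + (k + 0.5) * (hi - lo) / points_per_dim for k in range(points_per_dim)]
        for lo, hi in zip(lower, upper)
    ]
    signed, absolute, n = [0.0] * dim, [0.0] * dim, 0
    for point in itertools.product(*axes):
        g = grad_f(list(point))
        for i in range(dim):
            signed[i] += g[i]
            absolute[i] += abs(g[i])
        n += 1
    return [s / n for s in signed], [a / n for a in absolute]


def grad(x):
    # Gradient of f(x) = (x0 - 0.5)**2 - 3 * x1.
    return [2 * (x[0] - 0.5), -3.0]


signed, absolute = dgsm_measures(grad, ([0.0, 0.0], [1.0, 1.0]))
# signed is close to [0.0, -3.0]; absolute is close to [0.5, 3.0]
```

Here x0 gets a signed measure of about 0 even though f depends on it, because its positive and negative slopes cancel. The absolute measure (0.5) still exposes the effect.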
- class ax.utils.sensitivity.derivative_measures.GpDGSMGpSampling(model: Model, bounds: Tensor, num_gp_samples: int, derivative_gp: bool = False, kernel_type: str | None = None, Y_scale: float = 1.0, num_mc_samples: int = 10000, input_qmc: bool = False, gp_sample_qmc: bool = False, dtype: dtype = torch.float64, num_bootstrap_samples: int = 1)[source]
Bases:
GpDGSMGpMean
- ax.utils.sensitivity.derivative_measures.compute_derivatives_from_model_list(model_list: Sequence[Model], bounds: Tensor, discrete_features: list[int] | None = None, fixed_features: dict[int, float] | None = None, **kwargs: Any) Tensor[source]
Computes average derivatives of a list of models on a bounded domain. Estimation is according to the GP posterior mean function.
- Parameters:
model_list – A list of m botorch.models.model.Model types for which to compute the average derivative.
bounds – A 2 x d Tensor of lower and upper bounds of the domain of the models.
discrete_features – If specified, the inputs associated with the indices in this list are generated using an integer-valued uniform distribution, rather than the default (pseudo-)random continuous uniform distribution.
fixed_features – If specified, a dictionary mapping feature indices to fixed values. These features will be held constant and their derivatives will not be computed. The bounds tensor should include all features.
kwargs – Passed along to GpDGSMGpMean.
- Returns:
A (m x d’) tensor of gradient measures, where d’ is the number of non-fixed features.
- ax.utils.sensitivity.derivative_measures.sample_discrete_parameters(input_mc_samples: Tensor, discrete_features: None | list[int], bounds: Tensor, num_mc_samples: int) Tensor[source]
Samples the input parameters uniformly at random for the discrete features.
- Parameters:
input_mc_samples – The input mc samples tensor to be modified.
discrete_features – A list of integers (or None) of indices corresponding to discrete features.
bounds – The parameter bounds.
num_mc_samples – The number of Monte Carlo grid samples.
- Returns:
A modified input mc samples tensor.
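The resampling step can be sketched as follows. The sketch assumes that bounds are given as (lower, upper) sequences and that each discrete feature takes integer values between its bounds, inclusive. It is a hypothetical helper, not the library code. Unlike the real function, it operates on plain lists rather than tensors and infers the number of samples from the rows instead of taking num_mc_samples.

```python
import math
import random


def resample_discrete_columns(input_mc_samples, discrete_features, bounds, rng=None):
    """Overwrite discrete columns with integer-valued uniform samples.

    input_mc_samples: list of rows (lists of floats); modified in place and returned.
    bounds: (lower, upper) sequences with one entry per feature.
    """
    if not discrete_features:
        return input_mc_samples
    rng = rng or random.Random()
    lower, upper = bounds
    for row in input_mc_samples:
        for i in discrete_features:
            # Uniform over the integers in [ceil(lower_i), floor(upper_i)].
            row[i] = float(rng.randint(math.ceil(lower[i]), math.floor(upper[i])))
    return input_mc_samples
```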
Sobol Measures
- class ax.utils.sensitivity.sobol_measures.SobolSensitivity(bounds: Tensor, input_function: Callable[[Tensor], Tensor] | None = None, num_mc_samples: int = 10000, input_qmc: bool = False, second_order: bool = False, first_order_idcs: Tensor | None = None, num_bootstrap_samples: int = 1, bootstrap_array: bool = False, discrete_features: list[int] | None = None)[source]
Bases:
object
- evalute_function(f_A_B_ABi: Tensor | None = None) None[source]
Evaluates the objective function and divides the evaluations into the torch.Tensors needed for the indices computation.
- Parameters:
f_A_B_ABi – Function evaluations on the entire grid of size M(d+2).
- first_order_indices() Tensor[source]
Computes the first order Sobol indices:
- Returns:
- if num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
- second_order_indices(first_order_idcs: Tensor | None = None, first_order_idcs_btsp: Tensor | None = None) Tensor[source]
Computes the second order Sobol indices.
- Parameters:
first_order_idcs – Tensor of previously computed first order indices, where first_order_idcs.shape = torch.Size([dim]).
first_order_idcs_btsp – Tensor of all first order indices given by bootstrap.
- Returns:
- if num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
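The M(d+2) evaluation grid consists of two independent sample matrices A and B plus, for each input i, a matrix AB_i: a copy of A with column i taken from B. The sketch below computes first-order indices on that grid using the Saltelli (2010) estimator on the unit cube. Whether SobolSensitivity uses exactly this estimator is an assumption, and first_order_sobol is a hypothetical name.

```python
import random
import statistics


def first_order_sobol(f, dim, num_mc_samples=20000, seed=0):
    """First-order Sobol indices on [0, 1]^dim (Saltelli 2010 estimator).

    Uses f on A, B and each AB_i: num_mc_samples * (dim + 2) evaluations.
    """
    rng = random.Random(seed)
    A = [[rng.random() for _ in range(dim)] for _ in range(num_mc_samples)]
    B = [[rng.random() for _ in range(dim)] for _ in range(num_mc_samples)]
    f_A = [f(a) for a in A]
    f_B = [f(b) for b in B]
    total_var = statistics.pvariance(f_A + f_B)
    indices = []
    for i in range(dim):
        # AB_i: rows of A with column i replaced by the matching column of B.
        f_ABi = [f(a[:i] + [b[i]] + a[i + 1:]) for a, b in zip(A, B)]
        v_i = statistics.fmean(fb * (fab - fa) for fb, fab, fa in zip(f_B, f_ABi, f_A))
        indices.append(v_i / total_var)
    return indices
```

For f(x) = x0 + 2 * x1 with uniform inputs, the analytic first-order indices are 1/5 and 4/5, and the estimates land close to those values.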
- class ax.utils.sensitivity.sobol_measures.SobolSensitivityGPMean(model: ~botorch.models.gpytorch.GPyTorchModel, bounds: ~torch.Tensor, num_mc_samples: int = 10000, second_order: bool = False, input_qmc: bool = False, num_bootstrap_samples: int = 1, link_function: ~collections.abc.Callable[[~torch.Tensor, ~torch.Tensor], ~torch.Tensor] = <function GaussianLinkMean>, discrete_features: list[int] | None = None)[source]
Bases:
object
- first_order_indices() Tensor[source]
Computes the first order Sobol indices:
- Returns:
- if num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
- class ax.utils.sensitivity.sobol_measures.SobolSensitivityGPSampling(model: Model, bounds: Tensor, num_gp_samples: int = 1000, num_mc_samples: int = 10000, second_order: bool = False, input_qmc: bool = False, gp_sample_qmc: bool = False, num_bootstrap_samples: int = 1, discrete_features: list[int] | None = None)[source]
Bases:
object
- first_order_indices() Tensor[source]
Computes the first order Sobol indices:
- Returns:
- if num_bootstrap_samples>1
Tensor: (values, var_gp, stderr_gp, var_mc, stderr_mc) x dim
- else
Tensor: (values, var, stderr) x dim
- ax.utils.sensitivity.sobol_measures.array_with_string_indices_to_dict(rows: list[str], cols: list[str], A: ndarray[tuple[Any, ...], dtype[_ScalarT]]) dict[str, dict[str, ndarray[tuple[Any, ...], dtype[_ScalarT]]]][source]
- Parameters:
rows – A list of strings with which to index rows of A.
cols – A list of strings with which to index columns of A.
A – A matrix, with len(rows) rows and len(cols) columns.
- Returns:
A dictionary dict that satisfies dict[rows[i]][cols[j]] = A[i, j].
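A plain-Python sketch of the same mapping follows. It assumes A is given as a list of lists rather than a NumPy array. array_to_nested_dict is a hypothetical name.

```python
def array_to_nested_dict(rows, cols, A):
    """Return d with d[rows[i]][cols[j]] == A[i][j]."""
    return {r: {c: A[i][j] for j, c in enumerate(cols)} for i, r in enumerate(rows)}
```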
- ax.utils.sensitivity.sobol_measures.ax_parameter_sens(adapter: TorchAdapter, metrics: list[str] | None = None, order: str = 'first', signed: bool = True, exclude_map_key: bool = True, exclude_task: bool = False, **sobol_kwargs: Any) dict[str, dict[str, ndarray[tuple[Any, ...], dtype[_ScalarT]]]][source]
Compute sensitivity for all metrics on a TorchAdapter.
Sobol measures are always non-negative, regardless of the direction in which the parameter influences f. If signed is set to True, each parameter's Sobol measure is given the sign of the average gradient with respect to that parameter across the search space. Thus, important parameters that decrease f when increased will have large negative values, and unimportant parameters will have values close to 0.
- Parameters:
adapter – An adapter whose models have been fit.
metrics – The names of the metrics and outcomes for which to compute sensitivities. Preferably, these should be metrics with a good model fit. Defaults to adapter.outcomes.
order – A string specifying the order of the Sobol indices to be computed. Supports “first” and “total” and defaults to “first”.
signed – A bool for whether the measure should be signed.
exclude_map_key – If True (default), the MAP_KEY (“step”) feature will be excluded from sensitivity analysis by fixing it at the maximum step value. This makes the sensitivity analysis more interpretable for users who care about the effect of parameters on final performance.
exclude_task – If True, task parameters (those with is_task=True, e.g. synthetic parameters from the TrialAsTask transform) will be excluded from the sensitivity results.
sobol_kwargs – keyword arguments passed on to SobolSensitivityGPMean and, if signed, to GpDGSMGpMean.
- Returns:
Dictionary {‘metric_name’: {‘parameter_name’ or (‘parameter_name_1’, ‘parameter_name_2’): sensitivity_value}}, where the sensitivity value is cast to a Numpy array in order to be compatible with plot_feature_importance_by_feature.
- Return type:
dict[str, dict[str, ndarray]]
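The signing rule multiplies each Sobol index by the sign of the corresponding average gradient. A minimal sketch over plain dictionaries follows, with hypothetical parameter names and values. The real function computes both quantities from the adapter's models.

```python
def signed_sobol(sobol_indices, mean_gradients):
    """Give each (non-negative) Sobol index the sign of its average gradient."""
    return {
        name: (-1.0 if mean_gradients[name] < 0 else 1.0) * index
        for name, index in sobol_indices.items()
    }


signed = signed_sobol({"lr": 0.6, "momentum": 0.05}, {"lr": -2.3, "momentum": 0.1})
# {"lr": -0.6, "momentum": 0.05}
```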
- ax.utils.sensitivity.sobol_measures.compute_sobol_indices_from_model_list(model_list: list[GPyTorchModel], bounds: Tensor, order: str = 'first', discrete_features: list[int] | None = None, fixed_features: dict[int, float] | None = None, **sobol_kwargs: Any) Tensor[source]
Computes Sobol indices of a list of models on a bounded domain.
- Parameters:
model_list – A list of botorch.models.model.Model types for which to compute the Sobol indices.
bounds – A 2 x d Tensor of lower and upper bounds of the domain of the models.
order – A string specifying the order of the Sobol indices to be computed. Supports “first”, “second” and “total” and defaults to “first”. “total” measures the importance of a variable including its main effect and all of its higher-order interactions, whereas “first” and “second” measure the variance explained when the variable is varied in isolation or together with one other variable, respectively.
discrete_features – If specified, the inputs associated with the indices in this list are generated using an integer-valued uniform distribution, rather than the default (pseudo-)random continuous uniform distribution.
fixed_features – If specified, a dictionary mapping feature indices to fixed values. These features will be held constant during sensitivity analysis, and their sensitivity will not be computed. The bounds tensor should still include all features (including fixed ones).
sobol_kwargs – keyword arguments passed on to SobolSensitivityGPMean.
- Returns:
With m GPs, returns an (m x d’) tensor of Sobol indices of the requested order, where d’ is the number of non-fixed features.
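Total-order indices can be estimated from the same A, B, AB_i sample design used for first-order indices. The sketch below uses the Jansen estimator on the unit cube. This is one standard choice and not necessarily what SobolSensitivityGPMean implements. total_order_sobol is a hypothetical name.

```python
import random
import statistics


def total_order_sobol(f, dim, num_mc_samples=20000, seed=0):
    """Total-order Sobol indices on [0, 1]^dim via the Jansen estimator."""
    rng = random.Random(seed)
    A = [[rng.random() for _ in range(dim)] for _ in range(num_mc_samples)]
    B = [[rng.random() for _ in range(dim)] for _ in range(num_mc_samples)]
    f_A = [f(a) for a in A]
    total_var = statistics.pvariance(f_A + [f(b) for b in B])
    indices = []
    for i in range(dim):
        # Only x_i differs between a row of A and the matching row of AB_i.
        f_ABi = [f(a[:i] + [b[i]] + a[i + 1:]) for a, b in zip(A, B)]
        v_ti = statistics.fmean((fa - fab) ** 2 for fa, fab in zip(f_A, f_ABi)) / 2
        indices.append(v_ti / total_var)
    return indices
```

Take f(x) = x0 * x1 with uniform inputs. Each input has a first-order index of 3/7 (about 0.43) but a total-order index of 4/7 (about 0.57). The difference arises because the interaction is counted in the total but not in the first-order index.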
Stats
Statstools
- ax.utils.stats.statstools.agresti_coull_sem(n_numer: Series | ndarray[tuple[Any, ...], dtype[_ScalarT]] | int, n_denom: Series | ndarray[tuple[Any, ...], dtype[_ScalarT]] | int, prior_successes: int = 2, prior_failures: int = 2) ndarray[tuple[Any, ...], dtype[_ScalarT]] | float[source]
Compute the Agresti-Coull style standard error for a binomial proportion.
Reference: Agresti, Alan, and Brent A. Coull. “Approximate Is Better than ‘Exact’ for Interval Estimation of Binomial Proportions.” The American Statistician, vol. 52, no. 2, 1998, pp. 119-126. JSTOR, www.jstor.org/stable/2685469.
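A scalar sketch of the Agresti-Coull idea follows: add prior_successes pseudo-successes and prior_failures pseudo-failures, then apply the usual binomial standard error to the adjusted counts. The denominator uses the adjusted total n + prior_successes + prior_failures, following the textbook form. This is an assumption about the library's exact formula, and agresti_coull_sem_scalar is a hypothetical name.

```python
import math


def agresti_coull_sem_scalar(n_numer, n_denom, prior_successes=2, prior_failures=2):
    """Standard error of a binomial proportion using Agresti-Coull pseudo-counts."""
    n_tilde = n_denom + prior_successes + prior_failures  # adjusted trial count
    p_tilde = (n_numer + prior_successes) / n_tilde  # adjusted proportion
    return math.sqrt(p_tilde * (1 - p_tilde) / n_tilde)
```

Unlike the plain sqrt(p * (1 - p) / n), this stays positive when there are zero successes.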
- ax.utils.stats.statstools.inverse_variance_weight(means: ndarray[tuple[Any, ...], dtype[_ScalarT]], variances: ndarray[tuple[Any, ...], dtype[_ScalarT]], conflicting_noiseless: str = 'warn') tuple[float, float][source]
Perform inverse variance weighting.
- Parameters:
means – The means of the observations.
variances – The variances of the observations.
conflicting_noiseless – How to handle the case of multiple observations with zero variance but different means. Options are “warn” (default), “ignore” or “raise”.
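Inverse variance weighting gives each observation the weight 1 / variance. The combined mean is the weighted average, and the combined variance is 1 / sum(1 / variance). The sketch below assumes strictly positive variances, so it omits the zero-variance handling that conflicting_noiseless controls. The name ivw is hypothetical.

```python
def ivw(means, variances):
    """Inverse-variance-weighted mean and its variance (positive variances only)."""
    weights = [1.0 / v for v in variances]
    total = sum(weights)
    mean = sum(w * m for w, m in zip(weights, means)) / total
    return mean, 1.0 / total
```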
- ax.utils.stats.statstools.marginal_effects(df: DataFrame, covariates: list[str] | None = None) DataFrame[source]
This method calculates the relative (in %) change in the outcome achieved by using any individual factor level versus randomizing across all factor levels. It does this by estimating a baseline under the experiment by marginalizing over all factors/levels. For each factor level, then, it conditions on that level for the individual factor and then marginalizes over all levels for all other factors.
- Parameters:
df – Dataframe containing columns named mean and sem. All other columns are assumed to be factors for which to calculate marginal effects.
covariates – List of columns to be used as covariates. If None, then use all columns in df that are not named “mean” or “sem”.
- Returns:
- A dataframe containing columns “Name”, “Level”, “Beta” and “SE”
corresponding to the factor, level, effect and standard error. Results are relativized as percentage changes.
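The sketch below computes the point estimates only. It assumes a balanced factorial design, in which marginalizing over factors reduces to a plain average of cell means. It omits the SE column, and the library's handling of unbalanced data may differ. relative_marginal_effects and its row format are hypothetical.

```python
from collections import defaultdict
from statistics import fmean


def relative_marginal_effects(rows, factors):
    """Relative change (%) of each factor level versus the grand average.

    rows: dicts holding one cell mean under "mean" plus one key per factor.
    Assumes a balanced design, so marginalizing is a plain average.
    """
    baseline = fmean(r["mean"] for r in rows)
    effects = []
    for factor in factors:
        by_level = defaultdict(list)
        for r in rows:
            by_level[r[factor]].append(r["mean"])
        for level, vals in by_level.items():
            beta = 100.0 * (fmean(vals) / baseline - 1.0)
            effects.append({"Name": factor, "Level": level, "Beta": beta})
    return effects
```

With cell means 10, 20, 30 and 40 in a 2 x 2 design, the baseline is 25. Level a1 of factor A averages 15, so its Beta is -40%.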
- ax.utils.stats.statstools.positive_part_james_stein(means: ndarray[tuple[Any, ...], dtype[_ScalarT]] | list[float], sems: ndarray[tuple[Any, ...], dtype[_ScalarT]] | list[float]) tuple[ndarray[tuple[Any, ...], dtype[_ScalarT]], ndarray[tuple[Any, ...], dtype[_ScalarT]]][source]
Estimation method for Positive-part James-Stein estimator.
This method takes a vector of K means (y_i) and standard errors (sigma_i) and calculates the positive-part James-Stein estimator.
Resulting estimates are the shrunk means and standard errors. The positive part James-Stein estimator shrinks each constituent average to the grand average:
y_i - phi_i * y_i + phi_i * ybar
The variable phi_i determines the amount of shrinkage. For phi_i = 1, mu_hat is equal to ybar (the mean of all y_i), while for phi_i = 0, mu_hat is equal to y_i. It can be shown that restricting phi_i <= 1 dominates the unrestricted estimator, so this method restricts phi_i in this manner. The amount of shrinkage, phi_i, is determined by:
(K - 3) * sigma2_i / s2
where sigma2_i = sigma_i ** 2 and s2 = sum_i (y_i - ybar)^2 is the sum of squared deviations of the y_i from ybar. That is, less shrinkage is applied when individual means are estimated with greater precision, and more shrinkage is applied when the individual means are tightly clustered together. As noted above, phi_i is capped at 1.
The variance of the mean estimator is:
(1 - phi_i) * sigma2_i + phi_i * sigma2_i / K + 2 * phi_i ** 2 * (y_i - ybar)^2 / (K - 3)
The first term is the variance component from y_i, the second term is the contribution from the mean of all y_i, and the third term is the contribution from the uncertainty in the sum of squared deviations of y_i from the mean of all y_i.
For more information, see https://ax.dev/docs/models.html#empirical-bayes-and-thompson-sampling.
- Parameters:
means – Means of each arm
sems – Standard errors of each arm
- Returns:
A tuple (mu_hat_i, sem_i), where mu_hat_i is the Empirical Bayes estimate of each arm’s mean and sem_i is the Empirical Bayes estimate of each arm’s sem.
- Return type:
tuple[ndarray, ndarray]
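The shrinkage and variance formulas translate directly into plain Python. Here s2 is taken as the sum of squared deviations of the means from ybar. Two details are assumptions of this sketch rather than statements from the library: the inputs are returned unchanged when K <= 3 (because the factor K - 3 would not be positive), and phi_i is set to 1 when all means are equal (s2 = 0). james_stein_shrink is a hypothetical name.

```python
import math


def james_stein_shrink(means, sems):
    """Positive-part James-Stein shrinkage of K means toward their grand mean."""
    K = len(means)
    if K <= 3:
        return list(means), list(sems)  # (K - 3) must be positive (assumption)
    ybar = sum(means) / K
    s2 = sum((y - ybar) ** 2 for y in means)  # sum of squared deviations
    mu_hat, sem_hat = [], []
    for y, sigma in zip(means, sems):
        sigma2 = sigma**2
        # Shrinkage factor, capped at 1 (positive part).
        phi = 1.0 if s2 == 0 else min(1.0, (K - 3) * sigma2 / s2)
        mu_hat.append(y - phi * y + phi * ybar)
        var = (1 - phi) * sigma2 + phi * sigma2 / K + 2 * phi**2 * (y - ybar) ** 2 / (K - 3)
        sem_hat.append(math.sqrt(var))
    return mu_hat, sem_hat
```

As an example, take means [1, 2, 3, 4, 5] with unit sems. Then K = 5, s2 = 10 and phi_i = 0.2, so the estimates shrink to [1.4, 2.2, 3.0, 3.8, 4.6].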
Model Fit Metrics
- class ax.utils.stats.model_fit_stats.ModelFitMetricDirection(*values)[source]
Bases:
StrEnum
Model fit metric directions.
- MAXIMIZE = 'maximize'
- MINIMIZE = 'minimize'
- class ax.utils.stats.model_fit_stats.ModelFitMetricProtocol(*args, **kwargs)[source]
Bases:
Protocol
Structural type for model fit metrics.
- ax.utils.stats.model_fit_stats.coefficient_of_determination(y_obs: ndarray[tuple[Any, ...], dtype[_ScalarT]], y_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]], se_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]] | None = None, eps: float = 1e-12) float[source]
Computes coefficient of determination, the proportion of variance in y_obs accounted for by predictions y_pred.
- Parameters:
y_obs – An array of observations for a single metric.
y_pred – An array of the predicted values corresponding to y_obs.
se_pred – Not used, kept for API compatibility.
eps – A small constant to add to the denominator for numerical stability.
- Returns:
The scalar coefficient of determination, “R squared”.
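A plain-Python sketch of R squared = 1 - SS_res / SS_tot follows. Adding eps to the denominator is an assumption about where the library applies it, and r_squared is a hypothetical name.

```python
def r_squared(y_obs, y_pred, eps=1e-12):
    """1 - SS_res / SS_tot, with eps guarding against a zero denominator."""
    mean_obs = sum(y_obs) / len(y_obs)
    ss_res = sum((o - p) ** 2 for o, p in zip(y_obs, y_pred))
    ss_tot = sum((o - mean_obs) ** 2 for o in y_obs)
    return 1.0 - ss_res / (ss_tot + eps)
```

Predicting the mean of y_obs gives about 0. Predictions worse than the mean give negative values.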
- ax.utils.stats.model_fit_stats.compute_model_fit_metrics(y_obs: Mapping[str, ndarray[tuple[Any, ...], dtype[_ScalarT]]], y_pred: Mapping[str, ndarray[tuple[Any, ...], dtype[_ScalarT]]], se_pred: Mapping[str, ndarray[tuple[Any, ...], dtype[_ScalarT]]], fit_metrics_dict: Mapping[str, ModelFitMetricProtocol]) dict[str, dict[str, float]][source]
Computes the model fit metrics for each experimental metric in the input dicts.
- Parameters:
y_obs – A dictionary mapping from experimental metric name to observed values.
y_pred – A dictionary mapping from experimental metric name to predicted values.
se_pred – A dictionary mapping from experimental metric name to predicted standard errors.
fit_metrics_dict – A dictionary mapping from model fit metric name to a ModelFitMetricProtocol function that evaluates a model fit metric.
- Returns:
A nested dictionary mapping from model fit and experimental metric names to their corresponding model fit metrics values.
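The nesting can be sketched with dictionary comprehensions. The sketch assumes the outer key is the model fit metric name and the inner key is the experimental metric name. It also uses a hypothetical squared-error fit metric with the (y_obs, y_pred, se_pred) call signature.

```python
def compute_fit_metrics(y_obs, y_pred, se_pred, fit_metrics):
    """Return {fit_metric_name: {experimental_metric_name: value}}."""
    return {
        fit_name: {m: fn(y_obs[m], y_pred[m], se_pred[m]) for m in y_obs}
        for fit_name, fn in fit_metrics.items()
    }


def mean_squared_error(y_obs, y_pred, se_pred):
    # se_pred is accepted for signature compatibility but unused here.
    return sum((o - p) ** 2 for o, p in zip(y_obs, y_pred)) / len(y_obs)


result = compute_fit_metrics(
    {"a": [1.0, 2.0]}, {"a": [1.0, 4.0]}, {"a": [0.1, 0.1]}, {"MSE": mean_squared_error}
)
# {"MSE": {"a": 2.0}}
```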
- ax.utils.stats.model_fit_stats.entropy_of_observations(y_obs: ndarray[tuple[Any, ...], dtype[_ScalarT]], y_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]], se_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]], bandwidth: float = 0.1) float[source]
Computes the entropy of the observations y_obs using a kernel density estimator. This can be used to quantify how “clustered” the outcomes are. NOTE: y_pred and se_pred are not used, but are required for the API.
- Parameters:
y_obs – An array of observations for a single metric.
y_pred – Unused.
se_pred – Unused.
bandwidth – The kernel bandwidth. Defaults to 0.1, which is a reasonable value for standardized outcomes y_obs. The rank ordering of the results on a set of y_obs data sets is not generally sensitive to the bandwidth, if it is held fixed across the data sets. The absolute value of the results however changes significantly with the bandwidth.
- Returns:
The scalar entropy of the observations.
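One way to compute such an entropy is the resubstitution estimate -mean(log p_hat(y_i)) under a Gaussian kernel density estimate. This sketch is an assumption about the method, not the library source, which may use a different KDE implementation. kde_entropy is a hypothetical name.

```python
import math


def kde_entropy(y_obs, bandwidth=0.1):
    """Resubstitution entropy estimate -mean(log p_hat(y_i)) with a Gaussian KDE."""
    n = len(y_obs)
    norm = n * bandwidth * math.sqrt(2 * math.pi)
    total_log_density = 0.0
    for y in y_obs:
        # KDE density at y, built from all observations (including y itself).
        density = sum(math.exp(-0.5 * ((y - z) / bandwidth) ** 2) for z in y_obs) / norm
        total_log_density += math.log(density)
    return -total_log_density / n
```

Tightly clustered outcomes yield a lower entropy than spread-out outcomes when both use the same bandwidth.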
- ax.utils.stats.model_fit_stats.mean_of_the_standardized_error(y_obs: ndarray[tuple[Any, ...], dtype[_ScalarT]], y_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]], se_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]]) float[source]
Computes the mean of the error standardized by the model's predictive standard deviation se_pred. If the model makes good predictions and its uncertainty is well calibrated, the standardized errors should be approximately normally distributed with a mean close to 0.
NOTE: This assumes that se_pred is the predictive standard deviation of the observations of the objective y, not the predictive standard deviation of the objective f itself. In practice, this will matter for very noisy observations.
- Parameters:
y_obs – An array of observations for a single metric.
y_pred – An array of the predicted values corresponding to y_obs.
se_pred – An array of the standard errors of the predicted values.
- Returns:
The scalar mean of the standardized error.
- ax.utils.stats.model_fit_stats.std_of_the_standardized_error(y_obs: ndarray[tuple[Any, ...], dtype[_ScalarT]], y_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]], se_pred: ndarray[tuple[Any, ...], dtype[_ScalarT]]) float[source]
Computes the standard deviation of the error standardized by the model's predictive standard deviation se_pred. If the uncertainty is well calibrated, this should be close to 1.
NOTE: This assumes that se_pred is the predictive standard deviation of the observations of the objective y, not the predictive standard deviation of the objective f itself. In practice, this will matter for very noisy observations.
- Parameters:
y_obs – An array of observations for a single metric.
y_pred – An array of the predicted values corresponding to y_obs.
se_pred – An array of the standard errors of the predicted values.
- Returns:
The scalar standard deviation of the standardized error.
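Both statistics are computed on the standardized errors z_i = (y_obs_i - y_pred_i) / se_pred_i. The sketch below assumes the population standard deviation (ddof = 0), and the function name is hypothetical.

```python
import statistics


def standardized_error_stats(y_obs, y_pred, se_pred):
    """Mean and population std of z_i = (y_obs_i - y_pred_i) / se_pred_i."""
    z = [(o - p) / s for o, p, s in zip(y_obs, y_pred, se_pred)]
    return statistics.fmean(z), statistics.pstdev(z)
```

Halving se_pred models an overconfident model, and it doubles the standard deviation of z.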
Testing
Backend Orchestrator
Backend Simulator
- class ax.utils.testing.backend_simulator.BackendSimulator(options: BackendSimulatorOptions | None = None, queued: list[SimTrial] | None = None, running: list[SimTrial] | None = None, failed: list[SimTrial] | None = None, completed: list[SimTrial] | None = None, verbose_logging: bool = True)[source]
Bases:
Base
Simulator for a backend deployment with concurrent dispatch and a queue.
- get_sim_trial_by_index(trial_index: int) SimTrial | None[source]
Get a SimTrial by trial_index.
- Parameters:
trial_index – The index of the trial to return.
- Returns:
A SimTrial with the index trial_index, or None if not found.
- lookup_trial_index_status(trial_index: int) TrialStatus[source]
Look up the trial status of a trial_index.
- Parameters:
trial_index – The index of the trial to check.
- Returns:
A TrialStatus.
- new_trial(trial: SimTrial, status: TrialStatus) None[source]
Register a trial into the simulator.
- Parameters:
trial – A new trial to add.
status – The status of the new trial, either STAGED (add to self._queued) or RUNNING (add to self._running).
- run_trial(trial_index: int, runtime: float) None[source]
Run a simulated trial.
- Parameters:
trial_index – The index of the trial (usually the Ax trial index)
runtime – The runtime of the simulation. Typically sampled from the runtime model of a simulation model.
Internally, the runtime is scaled by the time_scaling factor, so that the simulation can run arbitrarily faster than the underlying evaluation.
- state() BackendSimulatorState[source]
Return a BackendSimulatorState containing the state of the simulator.
- status() SimStatus[source]
Return the internal status of the simulator.
- Returns:
A SimStatus object representing the current simulator state.
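A toy model can illustrate how the queue interacts with max_concurrency: submitted trials wait in a queue and start as soon as fewer than max_concurrency trials are running. MiniBackend is hypothetical and far simpler than BackendSimulator, with no failures, time scaling, or state serialization. In this sketch, a queued trial starts at the time of the update that frees capacity.

```python
from dataclasses import dataclass, field


@dataclass
class MiniBackend:
    """Toy queue-plus-concurrency model (hypothetical; not BackendSimulator)."""

    max_concurrency: int = 1
    queued: list[tuple[int, float]] = field(default_factory=list)  # (index, runtime)
    running: dict[int, float] = field(default_factory=dict)  # index -> end time
    completed: list[int] = field(default_factory=list)

    def submit(self, trial_index: int, now: float, runtime: float) -> None:
        self.queued.append((trial_index, runtime))
        self.update(now)

    def update(self, now: float) -> None:
        # Retire finished trials first (earliest end time first), freeing capacity.
        for idx, end in sorted(self.running.items(), key=lambda kv: kv[1]):
            if end <= now:
                del self.running[idx]
                self.completed.append(idx)
        # Start queued trials in FIFO order while capacity remains.
        while self.queued and len(self.running) < self.max_concurrency:
            idx, runtime = self.queued.pop(0)
            self.running[idx] = now + runtime
```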
- class ax.utils.testing.backend_simulator.BackendSimulatorOptions(max_concurrency: int = 1, time_scaling: float = 1.0, failure_rate: float = 0.0, internal_clock: float | None = None, use_update_as_start_time: bool = False)[source]
Bases:
object
Settings for the BackendSimulator.
- Parameters:
max_concurrency – The maximum number of trials that can be run in parallel.
time_scaling – The factor by which to scale down the runtime of the tasks. If runtime is the actual runtime of a trial, the simulation time will be runtime / time_scaling.
failure_rate – The rate at which trials fail. For now, trials fail independently, based on a coin flip with that rate.
internal_clock – The initial state of the internal clock. If None, the simulator uses time.time() as the clock.
use_update_as_start_time – Whether the start time of a new trial should be logged as the current time (at time of update) or as the end time of the previous trial. This makes sense when using the internal clock and the BackendSimulator is simulated forward by an external process (such as Orchestrator).
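The effect of time_scaling and failure_rate can be sketched as a hypothetical per-trial helper: the simulated runtime is runtime / time_scaling, and each trial fails independently with probability failure_rate. This is an illustration, not BackendSimulator's code.

```python
import random


def simulate_trial_outcome(runtime, time_scaling=1.0, failure_rate=0.0, rng=None):
    """Return (simulated_runtime, failed) for one trial (hypothetical helper)."""
    rng = rng or random.Random()
    failed = rng.random() < failure_rate  # independent coin flip per trial
    return runtime / time_scaling, failed
```

For example, with time_scaling=60 a trial that really takes 600 seconds finishes after 10 simulated seconds.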
- class ax.utils.testing.backend_simulator.BackendSimulatorState(options: BackendSimulatorOptions, verbose_logging: bool, queued: list[dict[str, float | None]], running: list[dict[str, float | None]], failed: list[dict[str, float | None]], completed: list[dict[str, float | None]])[source]
Bases:
object
State of the BackendSimulator.
- Parameters:
options – The BackendSimulatorOptions associated with this simulator.
verbose_logging – Whether the simulator is using verbose logging.
queued – Currently queued trials.
running – Currently running trials.
failed – Currently failed trials.
completed – Currently completed trials.
- options: BackendSimulatorOptions
- class ax.utils.testing.backend_simulator.SimStatus(queued: list[int], running: list[int], failed: list[int], time_remaining: list[float], completed: list[int])[source]
Bases:
object
Container for the status of the simulation.
- class ax.utils.testing.backend_simulator.SimTrial(trial_index: int, sim_runtime: float, sim_start_time: float | None = None, sim_queued_time: float | None = None, sim_completed_time: float | None = None)[source]
Bases:
object
Container for the simulation tasks.
Benchmark Stubs
- class ax.benchmark.testing.benchmark_stubs.DeterministicGenerationNode(search_space: SearchSpace)[source]
Bases:
ExternalGenerationNode
A GenerationNode that deterministically explores a discrete search space with one parameter.
- get_next_candidate(pending_parameters: list[dict[str, None | str | bool | float | int]]) dict[str, None | str | bool | float | int][source]
Get the parameters for the next candidate configuration to evaluate.
- Parameters:
pending_parameters – A list of parameters of the candidates pending evaluation. This is often used to avoid generating duplicate candidates.
- Returns:
A dictionary mapping parameter names to parameter values for the next candidate suggested by the method.
- update_generator_state(experiment: Experiment, data: Data) None[source]
A method used to update the state of the generator. This includes any models, predictors, or other custom state used by the generation node. This method will be called with the up-to-date experiment and data before get_next_candidate is called to generate the next trial(s). Note that get_next_candidate may be called multiple times (to generate multiple candidates) after a call to update_generator_state.
- Parameters:
experiment – The Experiment object representing the current state of the experiment. The key properties include trials, search_space, and optimization_config. The data is provided as a separate arg.
data – The data / metrics collected on the experiment so far.
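The deterministic rule for a single discrete parameter can be sketched as follows: walk the values in order and return the first one that is neither already evaluated nor pending. This is a hypothetical sketch. Whether the stub tracks evaluated values this way, and what it does once the space is exhausted, are assumptions.

```python
def next_deterministic_candidate(values, evaluated, pending_parameters, name="x"):
    """Return the first value of parameter `name` that is not yet tried or pending."""
    taken = set(evaluated) | {p[name] for p in pending_parameters}
    for v in values:
        if v not in taken:
            return {name: v}
    raise ValueError("search space exhausted")
```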
- class ax.benchmark.testing.benchmark_stubs.DummyTestFunction(*, outcome_names: list[str] = <factory>, n_steps: int = 1, num_outcomes: int = 1, dim: int = 6)[source]
Bases:
BenchmarkTestFunction
- ax.benchmark.testing.benchmark_stubs.get_adapter(experiment: Experiment) TorchAdapter[source]
Create a generic adapter for testing different surrogate model types.
- ax.benchmark.testing.benchmark_stubs.get_aggregated_benchmark_result() AggregatedBenchmarkResult[source]
- ax.benchmark.testing.benchmark_stubs.get_async_benchmark_method(early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, max_pending_trials: int = 2) BenchmarkMethod[source]
- ax.benchmark.testing.benchmark_stubs.get_async_benchmark_problem(map_data: bool, step_runtime_fn: TBenchmarkStepRuntimeFunction | None = None, n_steps: int = 1, lower_is_better: bool = False, report_inference_value_as_trace: bool = False, num_objectives: int = 1, num_constraints: int = 0) BenchmarkProblem[source]
Create an early-stopping benchmark problem with MAP_KEY data.
- Parameters:
map_data – Whether to use map metrics (required for early stopping).
step_runtime_fn – Optional runtime function for steps.
n_steps – Number of steps per trial.
lower_is_better – Whether lower values are better (for SOO).
report_inference_value_as_trace – Whether to report inference trace.
num_objectives – Number of objectives (1 for SOO, >1 for MOO).
num_constraints – Number of outcome constraints to add.
- Returns:
A BenchmarkProblem suitable for early-stopping evaluation.
- ax.benchmark.testing.benchmark_stubs.get_benchmark_time_varying_metric() BenchmarkTimeVaryingMetric[source]
- ax.benchmark.testing.benchmark_stubs.get_discrete_search_space(n_values: int = 20) SearchSpace[source]
- ax.benchmark.testing.benchmark_stubs.get_jenatton_arm(i: int) Arm[source]
- Parameters:
i – Non-negative int.
- ax.benchmark.testing.benchmark_stubs.get_mock_lcbench_data() LCBenchData[source]
Used for mocking out load_lcbench_data to avoid downloading data from the internet.
- ax.benchmark.testing.benchmark_stubs.get_multi_objective_benchmark_problem(observe_noise_sd: bool = False, num_trials: int = 4, test_problem_class: type[BraninCurrin] = <class 'botorch.test_functions.multi_objective.BraninCurrin'>, report_inference_value_as_trace: bool = False) BenchmarkProblem[source]
- ax.benchmark.testing.benchmark_stubs.get_saas_adapter(experiment: Experiment) TorchAdapter[source]
Create an adapter with SaasFullyBayesianSingleTaskGP model.
- ax.benchmark.testing.benchmark_stubs.get_single_objective_benchmark_problem(observe_noise_sd: bool = False, num_trials: int = 4, test_problem_kwargs: dict[str, Any] | None = None, report_inference_value_as_trace: bool = False, noise_std: float | dict[str, float] = 0.0, status_quo_params: dict[str, None | str | bool | float | int] | None = None) BenchmarkProblem[source]
Core Stubs
- class ax.utils.testing.core_stubs.CustomTestMetric(name: str, test_attribute: str, lower_is_better: bool | None = None)[source]
Bases:
Metric
- class ax.utils.testing.core_stubs.DummyEarlyStoppingStrategy(early_stop_trials: dict[int, str | None] | None = None)[source]
Bases:
BaseEarlyStoppingStrategy
- class ax.utils.testing.core_stubs.DummyGlobalStoppingStrategy(min_trials: int, trial_to_stop: int)[source]
Bases:
BaseGlobalStoppingStrategy
A dummy Global Stopping Strategy that stops the optimization after a pre-specified number of trials are completed.
- class ax.utils.testing.core_stubs.TestTrial(experiment: core.experiment.Experiment, trial_type: str | None = None, ttl_seconds: int | None = None, index: int | None = None)[source]
Bases:
BaseTrial
Trial class used to test the unsupported trial type error.
- add_arm(arm: Arm, candidate_metadata: dict[str, Any] | None = None) Self[source]
Add arm to the trial.
- Returns:
The trial instance.
- add_generator_run(generator_run: GeneratorRun) Self[source]
Add a generator run to the trial.
The arms and weights from the generator run will be merged with the existing arms and weights on the trial, and the generator run object will be linked to the trial for tracking.
- Parameters:
generator_run – The generator run to be added.
- Returns:
The trial instance.
- ax.utils.testing.core_stubs.get_arm_weights1() MutableMapping[Arm, float][source]
- ax.utils.testing.core_stubs.get_arm_weights2() MutableMapping[Arm, float][source]
- ax.utils.testing.core_stubs.get_arms_from_dict(arm_weights_dict: MutableMapping[Arm, float]) list[Arm][source]
- ax.utils.testing.core_stubs.get_batch_trial(abandon_arm: bool = True, experiment: Experiment | None = None, constrain_search_space: bool = True, with_status_quo: bool = True) BatchTrial[source]
- ax.utils.testing.core_stubs.get_botorch_model_with_default_acquisition_class() BoTorchGenerator[source]
- ax.utils.testing.core_stubs.get_botorch_model_with_surrogate_spec(with_covar_module: bool = True) BoTorchGenerator[source]
- ax.utils.testing.core_stubs.get_branin_data(trial_indices: Iterable[int] | None = None, trials: Iterable[Trial] | None = None, metrics: Iterable[str] | None = None) Data[source]
- ax.utils.testing.core_stubs.get_branin_data_batch(batch: BatchTrial, fill_vals: dict[str, float] | None = None, metrics: list[str] | None = None) Data[source]
- ax.utils.testing.core_stubs.get_branin_data_multi_objective(trial_indices: Iterable[int] | None = None, arm_names: Iterable[str] | None = None, outcomes: Sequence[str] | None = None) Data[source]
- ax.utils.testing.core_stubs.get_branin_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_trial: bool = False, with_status_quo: bool = False, status_quo_unknown_parameters: bool = False, with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, with_str_choice_param: bool = False, with_derived_parameter: bool = False, with_parameter_constraint: bool = False, search_space: SearchSpace | None = None, minimize: bool = False, named: bool = True, num_trial: int = 1, num_batch_trial: int = 1, with_completed_batch: bool = False, with_completed_trial: bool = False, num_arms_per_trial: int = 15, with_relative_constraint: bool = False, with_absolute_constraint: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_branin_experiment_with_multi_objective(has_optimization_config: bool = True, has_objective_thresholds: bool = False, with_batch: bool = False, with_status_quo: bool = False, status_quo_unknown_parameters: bool = False, with_fidelity_parameter: bool = False, num_objectives: int = 2, with_trial: bool = False, num_trial: int = 1, with_completed_trial: bool = False, with_completed_batch: bool = False, with_relative_constraint: bool = False, with_absolute_constraint: bool = False, with_choice_parameter: bool = False, with_fixed_parameter: bool = False, with_derived_parameter: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_branin_experiment_with_status_quo_trials(num_sobol_trials: int = 5, multi_objective: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_branin_experiment_with_timestamp_map_metric(with_status_quo: bool = False, noise_sd: float = 0.0, rate: float | None = None, map_tracking_metric: bool = False, decay_function_name: str = 'exp_decay', with_trials_and_data: bool = False, multi_objective: bool = False, has_objective_thresholds: bool = False, bounds: list[float] | None = None, with_choice_parameter: bool = False, with_outcome_constraint: bool = False) Experiment[source]
Returns a Branin experiment whose metrics are timestamp map metrics.
- Parameters:
with_status_quo – Whether to include a status quo arm.
noise_sd – Standard deviation of noise to add to the metric.
rate – Rate of decay for the map metric.
map_tracking_metric – Whether to include a tracking map metric.
decay_function_name – Name of the decay function to use.
with_trials_and_data – Whether to include trials and data.
multi_objective – Whether to include multiple objectives and tracking metrics.
has_objective_thresholds – For multi-objective experiments, toggles adding objective thresholds.
bounds – For multi-objective experiments where has_objective_thresholds is True, bounds determines the precise objective thresholds.
with_choice_parameter – Whether to include a choice parameter. If true, x2 will be a ChoiceParameter.
with_outcome_constraint – If True, adds an outcome constraint with an additional non-map Branin metric.
- Returns:
A Branin single or multi-objective experiment with map metrics.
- ax.utils.testing.core_stubs.get_branin_metric(name: str = 'branin', lower_is_better: bool = True) BraninMetric[source]
- ax.utils.testing.core_stubs.get_branin_multi_objective(num_objectives: int = 2) MultiObjective[source]
- ax.utils.testing.core_stubs.get_branin_multi_objective_optimization_config(has_objective_thresholds: bool = False, num_objectives: int = 2, with_relative_constraint: bool = False, with_absolute_constraint: bool = False) MultiObjectiveOptimizationConfig[source]
- ax.utils.testing.core_stubs.get_branin_objective(name: str = 'branin', minimize: bool = False) Objective[source]
- ax.utils.testing.core_stubs.get_branin_optimization_config(minimize: bool = False, with_relative_constraint: bool = False, with_absolute_constraint: bool = False) OptimizationConfig[source]
- ax.utils.testing.core_stubs.get_branin_outcome_constraint(name: str = 'branin') OutcomeConstraint[source]
- ax.utils.testing.core_stubs.get_branin_search_space(with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, with_str_choice_param: bool = False, with_derived_parameter: bool = False, with_parameter_constraint: bool = False, with_fixed_parameter: bool = False) SearchSpace[source]
- ax.utils.testing.core_stubs.get_branin_with_multi_task(with_multi_objective: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_data(metric_name: str = 'ax_test_metric', trial_index: int = 0, num_non_sq_arms: int = 4, include_sq: bool = True) Data[source]
- ax.utils.testing.core_stubs.get_dataset(num_samples: int = 2, d: int = 2, m: int = 2, has_observation_noise: bool = False, feature_names: list[str] | None = None, outcome_names: list[str] | None = None, tkwargs: dict[str, Any] | None = None, seed: int | None = None) SupervisedDataset[source]
Constructs a SupervisedDataset based on the given arguments.
- Parameters:
num_samples – The number of samples in the dataset.
d – The dimension of the features.
m – The number of outcomes.
has_observation_noise – If True, includes Yvar in the dataset.
feature_names – A list of feature names. Defaults to x0, x1…
outcome_names – A list of outcome names. Defaults to y0, y1…
tkwargs – Optional dictionary of tensor kwargs, such as dtype and device.
seed – An optional seed used to generate the data.
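The arguments above amount to a small data-generation routine. Below is a minimal stdlib-only sketch of that idea, not the Ax implementation. `ToyDataset` and `make_toy_dataset` are hypothetical names. Plain nested lists stand in for tensors, so `tkwargs` is omitted. Filling the data with uniform random values is also an assumption.

```python
from __future__ import annotations

import random
from dataclasses import dataclass


@dataclass
class ToyDataset:
    """Hypothetical stand-in for a supervised dataset (not Ax's SupervisedDataset)."""

    X: list[list[float]]
    Y: list[list[float]]
    Yvar: list[list[float]] | None
    feature_names: list[str]
    outcome_names: list[str]


def make_toy_dataset(
    num_samples: int = 2,
    d: int = 2,
    m: int = 2,
    has_observation_noise: bool = False,
    feature_names: list[str] | None = None,
    outcome_names: list[str] | None = None,
    seed: int | None = None,
) -> ToyDataset:
    # A local RNG makes the output reproducible when a seed is given.
    rng = random.Random(seed)
    # Default names follow the x0, x1, ... / y0, y1, ... convention.
    feature_names = feature_names or [f"x{i}" for i in range(d)]
    outcome_names = outcome_names or [f"y{j}" for j in range(m)]
    if len(feature_names) != d or len(outcome_names) != m:
        raise ValueError("Name lists must have lengths d and m.")
    X = [[rng.random() for _ in range(d)] for _ in range(num_samples)]
    Y = [[rng.random() for _ in range(m)] for _ in range(num_samples)]
    # Yvar (observation noise) is only populated when requested.
    Yvar = (
        [[rng.random() for _ in range(m)] for _ in range(num_samples)]
        if has_observation_noise
        else None
    )
    return ToyDataset(X, Y, Yvar, feature_names, outcome_names)
```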
- ax.utils.testing.core_stubs.get_equality_parameter_constraint(param_x: str = 'x', param_y: str = 'w') ParameterConstraint[source]
- ax.utils.testing.core_stubs.get_experiment(with_status_quo: bool = True, constrain_search_space: bool = True) Experiment[source]
- ax.utils.testing.core_stubs.get_experiment_with_batch_trial(constrain_search_space: bool = True, with_status_quo: bool = True) Experiment[source]
- ax.utils.testing.core_stubs.get_experiment_with_custom_runner_and_metric(constrain_search_space: bool = True, immutable: bool = False, multi_objective: bool = False, scalarized_objective: bool = False, num_trials: int = 3, has_outcome_constraint: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_experiment_with_map_data_type() Experiment[source]
Returns an experiment with the search space including parameters (with mixture of types) [“w”, “x”, “y”, “z”], a status quo, a single objective optimization config with MapMetric “m1”, and a tracking MapMetric “tracking”, both using the default MapKeyInfo.
- ax.utils.testing.core_stubs.get_experiment_with_observations(observations: Sequence[Sequence[float]], minimize: bool = False, scalarized: bool = False, constrained: bool = False, with_tracking_metrics: bool = False, search_space: SearchSpace | None = None, parameterizations: Sequence[Mapping[str, None | str | bool | float | int]] | None = None, sems: list[list[float]] | None = None, optimization_config: OptimizationConfig | None = None, candidate_metadata: Sequence[dict[str, Any] | None] | None = None, additional_data_columns: Sequence[Mapping[str, Any]] | None = None, signature_suffix: bool = False, status_quo: Arm | None = None) Experiment[source]
- ax.utils.testing.core_stubs.get_experiment_with_scalarized_objective_and_outcome_constraint() Experiment[source]
- ax.utils.testing.core_stubs.get_factorial_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_status_quo: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_factorial_metric(name: str = 'success_metric') FactorialMetric[source]
- ax.utils.testing.core_stubs.get_fixed_parameter(with_dependents: bool = False) FixedParameter[source]
- ax.utils.testing.core_stubs.get_hierarchical_choice_parameter(parameter_type: ParameterType) ChoiceParameter[source]
- ax.utils.testing.core_stubs.get_hierarchical_search_space(with_fixed_parameter: bool = False) SearchSpace[source]
- ax.utils.testing.core_stubs.get_hierarchical_search_space_experiment(num_observations: int = 0, use_map_data: bool = False) Experiment[source]
Create an experiment with a hierarchical search space and optional observations.
- Parameters:
num_observations – The number of trials in the experiment.
use_map_data – Whether data has a column “step.” This flag is for testing the transform MapKeyToFloat, which is applied to the search space only if the experiment’s data has a “step” column.
- Returns:
An experiment with a hierarchical search space and some optional observations.
NOTE: We have fixed the random seeds in the Sobol generator and torch.rand. Otherwise, MapKeyToFloatTransformTest is flaky.
- ax.utils.testing.core_stubs.get_high_dimensional_branin_experiment(with_batch: bool = False, with_status_quo: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_improvement_global_stopping_strategy() ImprovementGlobalStoppingStrategy[source]
- ax.utils.testing.core_stubs.get_large_factorial_search_space(num_levels: int = 10, num_parameters: int = 6) SearchSpace[source]
- ax.utils.testing.core_stubs.get_large_ordinal_search_space(n_ordinal_choice_parameters: int, n_continuous_range_parameters: int) SearchSpace[source]
- ax.utils.testing.core_stubs.get_many_branin_objective_opt_config(n_objectives: int) MultiObjectiveOptimizationConfig[source]
- ax.utils.testing.core_stubs.get_map_metric(name: str, noise_sd: float = 0.0, rate: float | None = None, decay_function_name: str = 'exp_decay') BraninTimestampMapMetric[source]
- ax.utils.testing.core_stubs.get_model_parameter(with_fixed_parameter: bool = False) ChoiceParameter[source]
- ax.utils.testing.core_stubs.get_model_predictions() tuple[dict[str, list[float]], dict[str, dict[str, list[float]]]][source]
- ax.utils.testing.core_stubs.get_model_predictions_per_arm() dict[str, tuple[dict[str, float], dict[str, dict[str, float]] | None]][source]
- ax.utils.testing.core_stubs.get_multi_objective_optimization_config(custom_metric: bool = False, relative: bool = True, outcome_constraint: bool = True) MultiObjectiveOptimizationConfig[source]
- ax.utils.testing.core_stubs.get_multi_type_experiment(add_trial_type: bool = True, add_trials: bool = False, num_arms: int = 10) MultiTypeExperiment[source]
- ax.utils.testing.core_stubs.get_non_failed_arm_names(experiment: Experiment) set[str][source]
Get the names of all arms from non-failed trials.
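Here is a minimal sketch of the filtering this helper describes. It assumes a simplified trial model. `Status`, `ToyTrial`, and `non_failed_arm_names` are hypothetical stand-ins for Ax's trial classes, not its API.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum


class Status(Enum):
    # Hypothetical subset of trial statuses, for illustration only.
    CANDIDATE = "candidate"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class ToyTrial:
    status: Status
    arm_names: list[str] = field(default_factory=list)


def non_failed_arm_names(trials: list[ToyTrial]) -> set[str]:
    # Union the arm names of every trial whose status is not FAILED;
    # an arm shared by several trials appears only once in the set.
    return {
        name
        for trial in trials
        if trial.status is not Status.FAILED
        for name in trial.arm_names
    }
```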
- ax.utils.testing.core_stubs.get_objective_threshold(metric_name: str = 'm1', bound: float = -0.25, comparison_op: ComparisonOp = ComparisonOp.GEQ) ObjectiveThreshold[source]
- ax.utils.testing.core_stubs.get_offline_experiments() list[Experiment][source]
Returns a list of experiments that resemble those seen in an offline context, i.e., single-arm Trial experiments with both single- and multi-objective optimization configs, with data attached.
We also include combinations with and without choice parameters, fixed parameters, absolute parameter constraints, and relative parameter constraints.
- ax.utils.testing.core_stubs.get_offline_experiments_subset() list[Experiment][source]
Set of 4 experiments that includes:
1. Single-objective with a choice parameter and a parameter constraint.
2. Multi-objective with an objective threshold, an absolute constraint, a choice parameter, and a fixed parameter.
3. Multi-objective with no thresholds, constraints, or special parameters.
4. Multi-objective with an objective threshold and a fixed parameter.
- ax.utils.testing.core_stubs.get_online_experiments() list[Experiment][source]
Returns a list of Branin experiments that resemble those seen in an online context, i.e., BatchTrial experiments with both single- and multi-objective optimization configs, with data attached and at least one trial in a CANDIDATE state.
We also include combinations with and without choice parameters, fixed parameters, absolute parameter constraints, and relative parameter constraints.
- ax.utils.testing.core_stubs.get_online_experiments_subset() list[Experiment][source]
Set of 4 experiments that includes:
1. Single-objective with a choice parameter, a parameter constraint, and a relative constraint.
2. Multi-objective with a choice parameter, a fixed parameter, and relative and absolute constraints.
3. Multi-objective with a fixed parameter and a relative constraint.
4. Multi-objective with no constraints but with both a fixed and a choice parameter.
- ax.utils.testing.core_stubs.get_online_sobol_mbm_generation_strategy() GenerationStrategy[source]
Constructs a GenerationStrategy with Sobol and MBM nodes for simulating online optimization.
- ax.utils.testing.core_stubs.get_optimization_config(outcome_constraint: bool = True, relative: bool = True) OptimizationConfig[source]
- ax.utils.testing.core_stubs.get_optimization_config_no_constraints(minimize: bool = False) OptimizationConfig[source]
- ax.utils.testing.core_stubs.get_outcome_constraint(metric: Metric | None = None, relative: bool = True, bound: float = -0.25) OutcomeConstraint[source]
- ax.utils.testing.core_stubs.get_parameter_constraint(param_x: str = 'x', param_y: str = 'w') ParameterConstraint[source]
- ax.utils.testing.core_stubs.get_pausing_criterion() list[PausingCriterion][source]
Returns a list of PausingCriterion for testing.
- ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy() PercentileEarlyStoppingStrategy[source]
- ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy_with_non_objective_metric_signature() PercentileEarlyStoppingStrategy[source]
- ax.utils.testing.core_stubs.get_scalarized_outcome_constraint() ScalarizedOutcomeConstraint[source]
- ax.utils.testing.core_stubs.get_search_space(constrain_search_space: bool = True) SearchSpace[source]
- ax.utils.testing.core_stubs.get_search_space_for_range_value(min: float = 3.0, max: float = 6.0) SearchSpace[source]
- ax.utils.testing.core_stubs.get_search_space_for_range_values(min: float = 3.0, max: float = 6.0, parameter_names: list[str] | None = None) SearchSpace[source]
- ax.utils.testing.core_stubs.get_search_space_with_choice_parameters(num_ordered_parameters: int = 2, num_unordered_choices: int = 5) SearchSpace[source]
- ax.utils.testing.core_stubs.get_status_quo_branin(with_fidelity_parameter: bool = False, with_str_choice_param: bool = False, with_derived_parameter: bool = False, with_fixed_parameter: bool = False, status_quo_unknown_parameters: bool = False) Arm[source]
- ax.utils.testing.core_stubs.get_surrogate_spec_with_inputs(model_class: type[Model] | None = None, covar_module_class: type[Kernel] | None = None) SurrogateSpec[source]
- ax.utils.testing.core_stubs.get_test_map_data_experiment(num_trials: int, num_fetches: int, num_complete: int, map_tracking_metric: bool = False, multi_objective: bool = False, bounds: list[float] | None = None, has_objective_thresholds: bool = False) Experiment[source]
- ax.utils.testing.core_stubs.get_threshold_early_stopping_strategy() ThresholdEarlyStoppingStrategy[source]
- ax.utils.testing.core_stubs.get_trial_based_criterion() list[TransitionCriterion][source]
Returns a list of trial-based TransitionCriteria for testing.
- ax.utils.testing.core_stubs.get_weights_from_dict(arm_weights_dict: MutableMapping[Arm, float]) list[float][source]
- ax.utils.testing.core_stubs.run_branin_experiment_with_generation_strategy(generation_strategy: GenerationStrategy, num_trials: int = 6, kwargs_for_get_branin_experiment: dict[str, Any] | None = None) Experiment[source]
Gets a Branin experiment using any given kwargs and runs num_trials trials using the given generation strategy.
Modeling Stubs
- ax.utils.testing.modeling_stubs.check_sobol_node(test_case: TestCase, gs: GenerationStrategy, expected_num_trials: int, expected_min_trials_observed: int | None = None) None[source]
Helper to check common Sobol node properties.
- Parameters:
test_case – The test case instance for assertions.
gs – The generation strategy to check.
expected_num_trials – The expected number of trials that need to be generated before the transition to the next node.
expected_min_trials_observed – The expected number of trials that need to be observed (i.e., completed) before the transition to the next node. If None, this check is skipped.
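The helper follows a simple pattern: it asserts through the passed-in test case and skips the observation check when that argument is None. That pattern can be sketched with the standard library. `ToySobolNode` and `check_toy_sobol_node` are hypothetical. Flattening the node's transition thresholds into two plain fields is an assumption made for illustration.

```python
from __future__ import annotations

import unittest
from dataclasses import dataclass


@dataclass
class ToySobolNode:
    # Hypothetical flattened view of a Sobol node's transition thresholds.
    num_trials: int
    min_trials_observed: int


def check_toy_sobol_node(
    test_case: unittest.TestCase,
    node: ToySobolNode,
    expected_num_trials: int,
    expected_min_trials_observed: int | None = None,
) -> None:
    # Assertions go through the caller's TestCase so failures are reported there.
    test_case.assertEqual(node.num_trials, expected_num_trials)
    # Mirror the documented behavior: None means "skip this check".
    if expected_min_trials_observed is not None:
        test_case.assertEqual(node.min_trials_observed, expected_min_trials_observed)
```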
- ax.utils.testing.modeling_stubs.get_default_generation_strategy_at_MBM_node(experiment: Experiment) GenerationStrategy[source]
- ax.utils.testing.modeling_stubs.get_generation_strategy(with_experiment: bool = False, with_generation_nodes: bool = False) GenerationStrategy[source]
- ax.utils.testing.modeling_stubs.get_legacy_list_surrogate_generation_step_as_dict() dict[str, Any][source]
For use ensuring backwards compatibility loading the now deprecated ListSurrogate.
- ax.utils.testing.modeling_stubs.get_observation1(first_metric_signature: str = 'a', second_metric_signature: str = 'b') Observation[source]
- ax.utils.testing.modeling_stubs.get_observation1trans(first_metric_signature: str = 'a', second_metric_signature: str = 'b') Observation[source]
- ax.utils.testing.modeling_stubs.get_observation2(first_metric_signature: str = 'a', second_metric_signature: str = 'b') Observation[source]
- ax.utils.testing.modeling_stubs.get_surrogate_as_dict() dict[str, Any][source]
For use ensuring backwards compatibility when loading Surrogate with input_transform and outcome_transform kwargs.
- ax.utils.testing.modeling_stubs.get_surrogate_generation_node() GenerationNode[source]
Returns a GenerationNode with surrogate configuration for testing.
- ax.utils.testing.modeling_stubs.get_surrogate_generation_step() GenerationStep[source]
Returns a GenerationStep with surrogate configuration for testing.
Note: This is kept for backward compatibility testing. New code should use get_surrogate_generation_node() instead.
- ax.utils.testing.modeling_stubs.get_surrogate_spec_as_dict(model_class: str | None = None, with_legacy_input_transform: bool = False) dict[str, Any][source]
For use ensuring backwards compatibility when loading SurrogateSpec with input_transform and outcome_transform kwargs.
- ax.utils.testing.modeling_stubs.sobol_gpei_generation_node_gs(with_model_selection: bool = False, with_auto_transition: bool = False, with_previous_node: bool = False, with_input_constructors_all_n: bool = False, with_input_constructors_remaining_n: bool = False, with_input_constructors_repeat_n: bool = False, with_input_constructors_target_trial: bool = False, with_unlimited_gen_mbm: bool = False, with_trial_type: bool = False, with_is_SOO_transition: bool = False) GenerationStrategy[source]
Returns a basic Sobol + MBM GenerationStrategy built from GenerationNodes, for testing.
- Parameters:
with_model_selection – If True, will add a second GeneratorSpec in the MBM node. This can be used for testing model selection.
Preference Stubs
- ax.utils.testing.preference_stubs.experimental_metric_eval(parameters: dict[str, Any], metric_names: list[str]) dict[str, Mapping[str, int | float | floating | integer | tuple[int | float | floating | integer, int | float | floating | integer | None]] | int | float | floating | integer | tuple[int | float | floating | integer, int | float | floating | integer | None] | Sequence[tuple[float, Mapping[str, int | float | floating | integer | tuple[int | float | floating | integer, int | float | floating | integer | None]]]]][source]
Evaluates experimental metrics.
- Parameters:
parameters – Dict of arm name to parameterization
metric_names – List of metric names
- Returns:
Dict of arm name to metric name to (mean, sem)
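The sketch below reproduces the return shape described above: arm name → metric name → (mean, sem). The metric model is a hypothetical placeholder, not what the Ax stub computes. It uses the sum of the parameter values plus Gaussian noise, with a fixed SEM of 0.1. `toy_metric_eval` is likewise a hypothetical name.

```python
from __future__ import annotations

import random


def toy_metric_eval(
    parameters: dict[str, dict[str, float]],
    metric_names: list[str],
    seed: int = 0,
) -> dict[str, dict[str, tuple[float, float]]]:
    rng = random.Random(seed)
    sem = 0.1  # Hypothetical fixed standard error for every observation.
    results: dict[str, dict[str, tuple[float, float]]] = {}
    for arm_name, parameterization in parameters.items():
        # Placeholder mean: sum of the arm's parameter values.
        base = sum(parameterization.values())
        results[arm_name] = {
            metric: (base + rng.gauss(0.0, sem), sem) for metric in metric_names
        }
    return results
```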
- ax.utils.testing.preference_stubs.get_pbo_experiment(num_parameters: int = 2, num_experimental_metrics: int = 3, parameter_names: list[str] | None = None, tracking_metric_names: list[str] | None = None, num_experimental_trials: int = 3, num_preference_trials: int = 3, num_preference_trials_w_repeated_arm: int = 5, include_sq: bool = True, partial_data: bool = False, unbounded_search_space: bool = False, experiment_name: str = 'pref_experiment', optimization_config: OptimizationConfig | None = None) Experiment[source]
Creates a synthetic preferential BO experiment.
- ax.utils.testing.preference_stubs.pairwise_pref_metric_eval(parameters: dict[str, dict[str, None | str | bool | float | int]]) dict[str, Mapping[str, int | float | floating | integer | tuple[int | float | floating | integer, int | float | floating | integer | None]] | int | float | floating | integer | tuple[int | float | floating | integer, int | float | floating | integer | None] | Sequence[tuple[float, Mapping[str, int | float | floating | integer | tuple[int | float | floating | integer, int | float | floating | integer | None]]]]][source]
Evaluates pairwise comparisons using utility_func.
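One way a utility-based pairwise comparison could work is sketched below under stated assumptions. `toy_pairwise_pref_eval` is hypothetical, and the caller supplies the utility function. The outcome encoding is also an assumption, not the Ax stub's documented format. The preferred arm gets (1.0, 0.0) and the other arm gets (0.0, 0.0).

```python
from __future__ import annotations

from collections.abc import Callable


def toy_pairwise_pref_eval(
    parameters: dict[str, dict[str, float]],
    utility_func: Callable[[dict[str, float]], float],
) -> dict[str, tuple[float, float]]:
    if len(parameters) != 2:
        raise ValueError("Expected exactly two arms to compare.")
    (arm_a, params_a), (arm_b, params_b) = parameters.items()
    # Ties are resolved in favor of the first arm (an arbitrary choice).
    a_wins = utility_func(params_a) >= utility_func(params_b)
    # Winner gets mean 1.0, loser 0.0; SEM is 0.0 because the comparison is noiseless.
    return {
        arm_a: (1.0 if a_wins else 0.0, 0.0),
        arm_b: (0.0 if a_wins else 1.0, 0.0),
    }
```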
Utils Testing Stubs
- ax.utils.testing.utils_testing_stubs.get_backend_simulator_with_trials() BackendSimulator[source]
Mocking
- ax.utils.testing.mock.minimal_optimize_with_nsgaii(*args: Any, **kwargs: Any) tuple[Tensor, Tensor][source]
- ax.utils.testing.mock.mock_botorch_optimize(f: Callable) Callable[source]
Wraps f in mock_botorch_optimize_context_manager for use as a decorator.
- ax.utils.testing.mock.mock_botorch_optimize_context_manager(force: bool = False) Generator[None, None, None][source]
A context manager that uses mocks to speed up optimization for testing. Currently, the primary tactic is to force the underlying scipy methods to stop after just one iteration.
This context manager uses BoTorch’s mock_optimize_context_manager, and adds some additional mocks that are not possible to cover in BoTorch due to the need to mock the functions where they are used.
- Parameters:
force – If True, does not raise an AssertionError when no mocks are called. USE RESPONSIBLY.
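The pattern described above can be sketched with `unittest.mock`. It patches a slow routine so that it stops after one iteration, and it fails if the patch was never hit unless `force=True`. The `optim` namespace and the `fast_optimize_context`/`fast_optimize` names are hypothetical. The real context manager builds on BoTorch's mock_optimize_context_manager and also patches call sites inside Ax.

```python
from __future__ import annotations

import functools
from collections.abc import Callable, Generator
from contextlib import contextmanager
from types import SimpleNamespace
from unittest import mock

# Hypothetical "module" exposing an expensive optimizer that tests want to shortcut.
optim = SimpleNamespace(optimize=lambda maxiter=1000: f"ran {maxiter} iterations")


@contextmanager
def fast_optimize_context(force: bool = False) -> Generator[None, None, None]:
    original = optim.optimize

    def one_iteration(*args, **kwargs):
        kwargs["maxiter"] = 1  # Force the optimizer to stop after one iteration.
        return original(*args, **kwargs)

    with mock.patch.object(optim, "optimize", side_effect=one_iteration) as patched:
        yield
    # Fail loudly if the mock was never exercised, unless the caller opts out.
    if not force and patched.call_count == 0:
        raise AssertionError("No mocks were called in the context manager.")


def fast_optimize(f: Callable) -> Callable:
    # Decorator form: run the wrapped function inside the context manager.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        with fast_optimize_context():
            return f(*args, **kwargs)

    return wrapper
```

The decorator mirrors how mock_botorch_optimize wraps a function in its context manager. The patch is undone when the context exits, so code outside the test sees the original routine.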
Test Init Files
Torch Stubs
Utils
- ax.utils.testing.utils.run_trials_with_gs(experiment: Experiment, gs: GenerationStrategy, num_trials: int) None[source]
Runs and completes num_trials trials for the given experiment with the given GS. The trials are completed with random metric values between 0 and 1.
- Parameters:
experiment – The experiment to run trials on. Must have an optimization config.
gs – The generation strategy to use.
num_trials – The number of trials to run.
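Here is a stdlib sketch of the loop this helper describes: generate a trial, then complete it with random metric values in [0, 1). `SimpleTrial` and `run_toy_trials` are hypothetical. The `generate` callable stands in for the generation strategy. Passing metric names as an argument stands in for reading them from the experiment's optimization config.

```python
from __future__ import annotations

import random
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class SimpleTrial:
    parameters: dict[str, float]
    status: str = "RUNNING"
    data: dict[str, float] = field(default_factory=dict)


def run_toy_trials(
    generate: Callable[[], dict[str, float]],
    metric_names: list[str],
    num_trials: int,
    seed: int | None = None,
) -> list[SimpleTrial]:
    rng = random.Random(seed)
    trials: list[SimpleTrial] = []
    for _ in range(num_trials):
        # "Run" a trial by asking the generator for a new parameterization.
        trial = SimpleTrial(parameters=generate())
        # Complete it with uniform random metric values in [0, 1).
        trial.data = {name: rng.random() for name in metric_names}
        trial.status = "COMPLETED"
        trials.append(trial)
    return trials
```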