# Source code for ax.api.utils.instantiation.from_config

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from typing import cast

import numpy as np
from ax.api.configs import (
    ChoiceParameterConfig,
    DerivedParameterConfig,
    RangeParameterConfig,
)
from ax.core.parameter import (
    ChoiceParameter,
    DerivedParameter,
    FixedParameter,
    Parameter,
    ParameterType as CoreParameterType,
    RangeParameter,
)
from ax.core.types import TParamValue
from ax.exceptions.core import UserInputError


def parameter_from_config(
    config: RangeParameterConfig | ChoiceParameterConfig | DerivedParameterConfig,
) -> Parameter:
    """
    Create a RangeParameter, ChoiceParameter, FixedParameter or DerivedParameter
    from a ParameterConfig.

    * ``RangeParameterConfig`` without ``step_size`` -> ``RangeParameter``
      (log-scaled when ``config.scaling == "log"``).
    * ``RangeParameterConfig`` with ``step_size`` -> ordered ``ChoiceParameter``
      whose values run from the lower to the upper bound (both included) in
      increments of ``step_size``.
    * ``DerivedParameterConfig`` -> ``DerivedParameter``.
    * Any other config (treated as a ``ChoiceParameterConfig``) -> ``FixedParameter``
      when it lists exactly one value, otherwise ``ChoiceParameter``.

    Raises:
        UserInputError: If the parameter type is unsupported, or if ``step_size``
            is combined with non-linear scaling, is not positive, or does not
            evenly divide the parameter's range.
    """
    if isinstance(config, RangeParameterConfig):
        lower, upper = config.bounds

        # TODO[mpolson64] Add support for RangeParameterConfig.step_size native to
        # RangeParameter instead of converting to ChoiceParameter
        if (step_size := config.step_size) is not None:
            if config.scaling not in (None, "linear"):
                raise UserInputError(
                    "Non-linear parameter scaling is not supported when using "
                    "step_size."
                )
            return ChoiceParameter(
                name=config.name,
                parameter_type=_parameter_type_converter(config.parameter_type),
                values=_step_size_values(
                    lower=lower, upper=upper, step_size=step_size
                ),
                is_ordered=True,
            )

        return RangeParameter(
            name=config.name,
            parameter_type=_parameter_type_converter(config.parameter_type),
            lower=lower,
            upper=upper,
            log_scale=config.scaling == "log",
        )

    if isinstance(config, DerivedParameterConfig):
        return DerivedParameter(
            name=config.name,
            parameter_type=_parameter_type_converter(config.parameter_type),
            expression_str=config.expression_str,
        )

    # Remaining case: a ChoiceParameterConfig.
    dependents = cast(
        dict[TParamValue, list[str]] | None,
        config.dependent_parameters,
    )

    # If there is only one value, create a FixedParameter instead of a
    # ChoiceParameter
    if len(config.values) == 1:
        return FixedParameter(
            name=config.name,
            parameter_type=_parameter_type_converter(config.parameter_type),
            value=config.values[0],
            dependents=dependents,
        )

    return ChoiceParameter(
        name=config.name,
        parameter_type=_parameter_type_converter(config.parameter_type),
        values=cast(list[TParamValue], config.values),
        is_ordered=config.is_ordered,
        dependents=dependents,
        sort_values=config.parameter_type != "str",  # Matches default behavior.
    )


def _step_size_values(
    lower: float, upper: float, step_size: float
) -> list[TParamValue]:
    """
    Enumerate ``lower, lower + step_size, ..., upper`` as built-in Python scalars.

    Raises:
        UserInputError: If ``step_size`` is not positive, or if it does not evenly
            divide ``upper - lower``.
    """
    # A zero step would otherwise raise ZeroDivisionError from ``%`` below, and a
    # negative step would silently produce an empty list of values.
    if step_size <= 0:
        raise UserInputError(f"step_size must be positive, got {step_size}.")

    remainder = (upper - lower) % step_size
    # Use tolerance-based comparison to handle floating point precision issues
    # (the remainder may land just below ``step_size`` instead of at 0).
    if not np.isclose(remainder, 0) and not np.isclose(remainder, step_size):
        raise UserInputError(
            "The range of the parameter must be evenly divisible by the "
            "step size."
        )

    if all(isinstance(v, (int, np.integer)) for v in (lower, upper, step_size)):
        # Integer grids are exact, and arange keeps the integer dtype.
        grid = np.arange(lower, upper + step_size, step_size)
    else:
        # np.arange with a float step can overshoot its stop value, e.g. bounds
        # (0, 0.2) with step 0.1 would also yield ~0.30000000000000004 > upper.
        # linspace with an explicit count includes both bounds exactly and
        # never exceeds them.
        num_steps = int(round((upper - lower) / step_size))
        grid = np.linspace(lower, upper, num_steps + 1)

    # tolist() converts NumPy scalars (e.g. np.int64) to built-in int/float.
    return grid.tolist()
def _parameter_type_converter(parameter_type: str) -> CoreParameterType: """ Convert from an API ParameterType to a core Ax ParameterType. """ if parameter_type == "bool": return CoreParameterType.BOOL elif parameter_type == "float": return CoreParameterType.FLOAT elif parameter_type == "int": return CoreParameterType.INT elif parameter_type == "str": return CoreParameterType.STRING else: raise UserInputError(f"Unsupported parameter type {parameter_type}.")