# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import cast
import numpy as np
from ax.api.configs import (
ChoiceParameterConfig,
DerivedParameterConfig,
RangeParameterConfig,
)
from ax.core.parameter import (
ChoiceParameter,
DerivedParameter,
FixedParameter,
Parameter,
ParameterType as CoreParameterType,
RangeParameter,
)
from ax.core.types import TParamValue
from ax.exceptions.core import UserInputError
[docs]
def parameter_from_config(
    config: RangeParameterConfig | ChoiceParameterConfig | DerivedParameterConfig,
) -> Parameter:
    """
    Create a RangeParameter, ChoiceParameter, FixedParameter or DerivedParameter
    from a ParameterConfig.

    - A ``RangeParameterConfig`` becomes a ``RangeParameter``. If it sets
      ``step_size``, it instead becomes an ordered ``ChoiceParameter`` listing
      every step from the lower to the upper bound, inclusive.
    - A ``DerivedParameterConfig`` becomes a ``DerivedParameter``.
    - Any other config (a ``ChoiceParameterConfig``) becomes a
      ``FixedParameter`` when it lists exactly one value, otherwise a
      ``ChoiceParameter``.

    Raises:
        UserInputError: If ``step_size`` is combined with non-linear scaling,
            is not positive, does not evenly divide the range, or the bounds
            are inverted; or if the parameter type is not supported.
    """
    if isinstance(config, RangeParameterConfig):
        lower, upper = config.bounds
        # TODO[mpolson64] Add support for RangeParameterConfig.step_size native to
        # RangeParameter instead of converting to ChoiceParameter
        if (step_size := config.step_size) is not None:
            if not (config.scaling == "linear" or config.scaling is None):
                raise UserInputError(
                    "Non-linear parameter scaling is not supported when using "
                    "step_size."
                )
            # Validate and enumerate the steps before building the parameter so
            # step-size errors surface before parameter-type errors, as before.
            values = _values_from_step_size(
                lower=lower, upper=upper, step_size=step_size
            )
            return ChoiceParameter(
                name=config.name,
                parameter_type=_parameter_type_converter(config.parameter_type),
                values=values,
                is_ordered=True,
            )
        return RangeParameter(
            name=config.name,
            parameter_type=_parameter_type_converter(config.parameter_type),
            lower=lower,
            upper=upper,
            log_scale=config.scaling == "log",
        )
    elif isinstance(config, DerivedParameterConfig):
        return DerivedParameter(
            name=config.name,
            parameter_type=_parameter_type_converter(config.parameter_type),
            expression_str=config.expression_str,
        )
    else:
        # If there is only one value, create a FixedParameter instead of a
        # ChoiceParameter
        if len(config.values) == 1:
            return FixedParameter(
                name=config.name,
                parameter_type=_parameter_type_converter(config.parameter_type),
                value=config.values[0],
                dependents=cast(
                    dict[TParamValue, list[str]] | None,
                    config.dependent_parameters,
                ),
            )
        return ChoiceParameter(
            name=config.name,
            parameter_type=_parameter_type_converter(config.parameter_type),
            values=cast(list[TParamValue], config.values),
            is_ordered=config.is_ordered,
            dependents=cast(
                dict[TParamValue, list[str]] | None,
                config.dependent_parameters,
            ),
            sort_values=config.parameter_type != "str",  # Matches default behavior.
        )


def _values_from_step_size(
    lower: float, upper: float, step_size: float
) -> list[TParamValue]:
    """
    Enumerate ``lower, lower + step_size, ..., upper`` for a stepped range.

    The number of steps is derived from the span instead of using ``np.arange``.
    With a float step, ``np.arange(lower, upper + step_size, step_size)`` can
    emit an extra value past ``upper``: bounds ``(1.0, 1.2)`` with
    ``step_size=0.1`` would yield ``1.3``. It also returns numpy scalars rather
    than plain Python numbers.

    Raises:
        UserInputError: If ``step_size`` is not positive, ``lower > upper``, or
            the span is not (within tolerance) a multiple of ``step_size``.
    """
    if step_size <= 0:
        raise UserInputError("step_size must be positive.")
    if lower > upper:
        raise UserInputError(
            "The lower bound must not be greater than the upper bound."
        )
    remainder = (upper - lower) % step_size
    # Use tolerance-based comparison to handle floating point precision issues:
    # an evenly divisible span can leave a remainder just above 0 or just
    # below step_size.
    if not np.isclose(remainder, 0) and not np.isclose(remainder, step_size):
        raise UserInputError(
            "The range of the parameter must be evenly divisible by the "
            "step size."
        )
    num_steps = round((upper - lower) / step_size)
    # Compute each value directly from ``lower`` rather than by accumulation,
    # and clamp so rounding error cannot push the last value past ``upper``.
    return [min(lower + i * step_size, upper) for i in range(num_steps + 1)]
def _parameter_type_converter(parameter_type: str) -> CoreParameterType:
"""
Convert from an API ParameterType to a core Ax ParameterType.
"""
if parameter_type == "bool":
return CoreParameterType.BOOL
elif parameter_type == "float":
return CoreParameterType.FLOAT
elif parameter_type == "int":
return CoreParameterType.INT
elif parameter_type == "str":
return CoreParameterType.STRING
else:
raise UserInputError(f"Unsupported parameter type {parameter_type}.")