Source code for ax.utils.common.typeutils

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, TypeVar

from pyre_extensions import assert_is_instance

T = TypeVar("T")
V = TypeVar("V")
K = TypeVar("K")
X = TypeVar("X")
Y = TypeVar("Y")


def assert_is_instance_optional(val: V | None, typ: type[T]) -> T | None:
    """
    Type-check a value that is allowed to be None.

    Args:
        val: the value to check; None is accepted without any check
        typ: the type a non-None ``val`` is checked against

    Returns:
        None when ``val`` is None; otherwise whatever
        ``assert_is_instance(val, typ)`` returns. Any failure for a
        non-None value is raised by ``assert_is_instance`` itself.
    """
    # None short-circuits: the instance check only runs for real values.
    return None if val is None else assert_is_instance(val, typ)
def assert_is_instance_list(old_l: list[V], typ: type[T]) -> list[T]:
    """
    Type-check every item of a list.

    Args:
        old_l: the list whose items are checked
        typ: the type each item is checked against

    Returns:
        A newly built list holding the result of
        ``assert_is_instance(item, typ)`` for each item of ``old_l``, in the
        same order. ``old_l`` itself is not modified. Any failure is raised
        by ``assert_is_instance``.
    """
    checked: list[T] = [assert_is_instance(item, typ) for item in old_l]
    return checked
def assert_is_instance_dict(
    d: dict[X, Y], key_type: type[K], val_type: type[V]
) -> dict[K, V]:
    """
    Type-check every key and every value of a dictionary.

    Args:
        d: the dictionary to check
        key_type: the type each key is checked against
        val_type: the type each value is checked against

    Returns:
        A newly built dictionary mapping each checked key to its checked
        value, in the iteration order of ``d``. ``d`` itself is not
        modified. Any failure is raised by ``assert_is_instance``.
    """
    # For each entry the key is checked before the value, so a bad key is
    # reported before a bad value of the same entry.
    return {
        assert_is_instance(k, key_type): assert_is_instance(v, val_type)
        for k, v in d.items()
    }
def assert_is_instance_of_tuple(val: V, typ: tuple[type[V], ...]) -> V:
    """
    Check that a value is an instance of at least one of several types.

    Args:
        val: the value being checked
        typ: the tuple of candidate types, passed straight to ``isinstance``

    Returns:
        The ``val`` argument itself, unchanged.

    Raises:
        TypeError: if ``val`` is not an instance of any type in ``typ``.
    """
    if isinstance(val, typ):
        return val
    raise TypeError(f"Value was not of any type {typ!r}:\n{val!r}")
def _argparse_type_encoder(arg: Any) -> type[Any]:
    """
    Map an argument to the type used as its lookup-key component.

    Applied to the arguments passed to `optimizer_argparse.__call__` at
    runtime, where the method-lookup key is built as
    `tuple(map(arg_transform, args))`. This custom arg_transform allows
    types themselves to be passed at runtime in place of instances.

    Args:
        arg: either a type or an arbitrary object

    Returns:
        ``arg`` itself when it is already a type, otherwise ``type(arg)``.
    """
    # A type passed directly stands for itself rather than for `type`.
    if isinstance(arg, type):
        return arg
    return type(arg)