# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections.abc import Callable, Sequence
from copy import deepcopy
from functools import partial
from typing import Any
import torch
from ax.utils.sensitivity.derivative_gp import posterior_derivative
from ax.utils.sensitivity.fixed_feature_model import (
FixedFeatureModel,
prepare_fixed_feature_inputs,
)
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import unnormalize
from gpytorch.distributions import MultivariateNormal
from pyre_extensions import assert_is_instance, none_throws
[docs]
def sample_discrete_parameters(
input_mc_samples: torch.Tensor,
discrete_features: None | list[int],
bounds: torch.Tensor,
num_mc_samples: int,
) -> torch.Tensor:
r"""Samples the input parameters uniformly at random for the discrete features.
Args:
input_mc_samples: The input mc samples tensor to be modified.
discrete_features: A list of integers (or None) of indices corresponding
to discrete features.
bounds: The parameter bounds.
num_mc_samples: The number of Monte Carlo grid samples.
Returns:
A modified input mc samples tensor.
"""
if discrete_features is None:
return input_mc_samples
all_low = bounds[0, discrete_features].to(dtype=torch.int).tolist()
all_high = (bounds[1, discrete_features]).to(dtype=torch.int).tolist()
for i, low, high in zip(discrete_features, all_low, all_high):
randint = partial(torch.randint, low=low, high=high + 1)
input_mc_samples[:, i] = randint(size=torch.Size([num_mc_samples]))
return input_mc_samples
[docs]
class GpDGSMGpMean:
    r"""Derivative-based global sensitivity measures from the posterior mean.

    On construction, a grid of ``num_mc_samples`` input points is drawn inside
    ``bounds`` (uniformly at random or from a Sobol sequence). The gradient of
    the model's posterior mean is computed at every grid point. Afterwards the
    per-dimension gradient, absolute-gradient and squared-gradient measures can
    be read off via the ``*_measure`` methods.
    """

    mean_gradients: torch.Tensor | None = None
    bootstrap_indices: torch.Tensor | None = None
    mean_gradients_btsp: list[torch.Tensor] | None = None

    def __init__(
        self,
        model: Model,
        bounds: torch.Tensor,
        derivative_gp: bool = False,
        kernel_type: str | None = None,
        Y_scale: float = 1.0,
        num_mc_samples: int = 10**4,
        input_qmc: bool = False,
        dtype: torch.dtype = torch.double,
        num_bootstrap_samples: int = 1,
        discrete_features: list[int] | None = None,
    ) -> None:
        r"""Computes three types of derivative based measures:
        the gradient, the gradient square and the gradient absolute measures.

        Args:
            model: A BoTorch model.
            bounds: Parameter bounds over which to evaluate model sensitivity.
            derivative_gp: If true, the derivative of the GP is used to compute
                the gradient instead of backward.
            kernel_type: Takes "rbf" or "matern", set only if `derivative_gp` is
                true.
            Y_scale: Scale the derivatives by this amount, to undo scaling
                done on the training data.
            num_mc_samples: The number of Monte Carlo grid samples.
            input_qmc: If True, a qmc Sobol grid is used instead of uniformly
                random.
            dtype: Can be provided if the GP is fit to data of type
                `torch.float`.
            num_bootstrap_samples: If higher than 1, the method will compute the
                dgsm measure `num_bootstrap_samples` times by selecting
                subsamples from the `input_mc_samples` and return the variance
                and standard error across all computed measures.
            discrete_features: If specified, the inputs associated with the
                indices in this list are generated using an integer-valued
                uniform distribution, rather than the default (pseudo-)random
                continuous uniform distribution.

        Raises:
            ValueError: If `derivative_gp` is True but `kernel_type` is None.
        """
        # Use bounds to determine dimension - this is more robust than
        # train_inputs when using FixedFeatureModel wrappers that reduce the
        # effective dimension.
        self.dim: int = bounds.shape[-1]
        self.derivative_gp = derivative_gp
        self.kernel_type = kernel_type
        self.bootstrap: bool = num_bootstrap_samples > 1
        # Deduct 1 because the first "sample" is meant to be the full grid.
        self.num_bootstrap_samples: int = num_bootstrap_samples - 1
        self.torch_device: torch.device = bounds.device
        if self.derivative_gp and (self.kernel_type is None):
            raise ValueError("Kernel type has to be specified to use derivative GP")
        self.num_mc_samples = num_mc_samples
        if input_qmc:
            self.input_mc_samples: torch.Tensor = (
                draw_sobol_samples(bounds=bounds, n=num_mc_samples, q=1, seed=1234)
                .squeeze(1)
                .to(dtype)
            )
        else:
            self.input_mc_samples = unnormalize(
                torch.rand(
                    num_mc_samples, self.dim, dtype=dtype, device=self.torch_device
                ),
                bounds=bounds,
            )
        # Uniform integer distribution for discrete features.
        self.input_mc_samples = sample_discrete_parameters(
            input_mc_samples=self.input_mc_samples,
            discrete_features=discrete_features,
            bounds=bounds,
            num_mc_samples=num_mc_samples,
        )
        if self.derivative_gp:
            posterior = posterior_derivative(
                model, self.input_mc_samples, none_throws(self.kernel_type)
            )
            self._compute_gradient_quantities(posterior, Y_scale)
        else:
            # Overrides of `_compute_mean_gradients` may differentiate with
            # respect to the full grid tensor directly, so it must track grads.
            self.input_mc_samples.requires_grad = True
            self._compute_mean_gradients(model, Y_scale)

    def _compute_gradient_quantities(
        self, posterior: GPyTorchPosterior | MultivariateNormal, Y_scale: float
    ) -> None:
        r"""Store the (scaled) mean gradients and draw bootstrap subsets.

        With `derivative_gp`, the posterior's mean already holds the
        gradients. Otherwise, the posterior mean is back-propagated to
        `self.input_mc_samples`.
        """
        if self.derivative_gp:
            self.mean_gradients = (
                assert_is_instance(posterior.mean, torch.Tensor) * Y_scale
            )
        else:
            predictive_mean = posterior.mean
            torch.sum(predictive_mean).backward()
            self.mean_gradients = (
                assert_is_instance(self.input_mc_samples.grad, torch.Tensor) * Y_scale
            )
        self._compute_bootstrap()

    def _compute_mean_gradients(self, model: Model, Y_scale: float) -> None:
        """Compute mean gradients with batched posterior to limit memory.

        Processes MC samples in chunks to avoid OOM with models that have
        large internal batch dimensions (e.g., SaasFullyBayesianSingleTaskGP
        with 256 MCMC samples can use ~94 GiB when evaluated unbatched on
        10,000 points with autograd enabled).
        """
        batch_size = 1024
        all_grads = []
        for batch in self.input_mc_samples.split(batch_size):
            # Fresh leaf per chunk so each backward only touches this chunk.
            batch_input = batch.detach().requires_grad_(True)
            posterior = assert_is_instance(
                model.posterior(batch_input), GPyTorchPosterior
            )
            torch.sum(posterior.mean).backward()
            all_grads.append(assert_is_instance(batch_input.grad, torch.Tensor).clone())
            del posterior
        self.mean_gradients = torch.cat(all_grads, dim=0) * Y_scale
        self._compute_bootstrap()

    def _compute_bootstrap(self) -> None:
        r"""Draw bootstrap subsets of the mean gradients (if bootstrapping).

        Draws `self.num_bootstrap_samples` index sets of size 2 (independently,
        hence with replacement) into the MC grid. The selected rows of
        `self.mean_gradients` are stored in `self.mean_gradients_btsp`.
        """
        if not self.bootstrap:
            return
        mean_gradients = assert_is_instance(self.mean_gradients, torch.Tensor)
        subset_size = 2
        # The indices must live on the same device as the gradients:
        # `torch.index_select` rejects a CPU index into a CUDA tensor.
        bootstrap_indices = torch.randint(
            0,
            self.num_mc_samples,
            (self.num_bootstrap_samples, subset_size),
            device=mean_gradients.device,
        )
        self.bootstrap_indices = bootstrap_indices
        self.mean_gradients_btsp = [
            torch.index_select(mean_gradients, 0, indices)
            for indices in bootstrap_indices
        ]

    def _transformed_column_means(
        self,
        transform_fun: Callable[[torch.Tensor], torch.Tensor],
        gradients: torch.Tensor,
    ) -> torch.Tensor:
        """Mean of `transform_fun` applied to each of the `dim` columns."""
        return torch.tensor(
            [torch.mean(transform_fun(gradients[:, i])) for i in range(self.dim)]
        )

    def aggregation(
        self, transform_fun: Callable[[torch.Tensor], torch.Tensor]
    ) -> torch.Tensor:
        r"""Aggregate transformed mean gradients into one value per dimension.

        Args:
            transform_fun: Applied to each column of gradients (e.g.
                `torch.abs`) before averaging over the grid points.

        Returns:
            Without bootstrapping (constructor `num_bootstrap_samples <= 1`), a
            tensor of `dim` values. With bootstrapping, a `dim x 3` tensor whose
            columns are the mean over the full-grid and bootstrap measures, their
            variance (var_mc) and standard error (stderr_mc).
        """
        gradients_measure = self._transformed_column_means(
            transform_fun, none_throws(self.mean_gradients)
        )
        if not self.bootstrap:
            return gradients_measure
        mean_gradients_btsp = none_throws(self.mean_gradients_btsp)
        # Row 0: full-grid measure; rows 1..: one per bootstrap subset.
        gradients_measures_btsp = torch.stack(
            [gradients_measure]
            + [
                self._transformed_column_means(transform_fun, mean_gradients_btsp[b])
                for b in range(self.num_bootstrap_samples)
            ],
            dim=0,
        )
        var_mc = gradients_measures_btsp.var(dim=0)
        return (
            torch.stack(
                [
                    gradients_measures_btsp.mean(dim=0),
                    var_mc,
                    torch.sqrt(var_mc / (self.num_bootstrap_samples + 1)),
                ],
                dim=0,
            )
            .t()
            .detach()
        )

    def gradient_measure(self) -> torch.Tensor:
        r"""Computes the gradient measure (mean gradient per dimension).

        Returns:
            If constructed with `num_bootstrap_samples > 1`, a `dim x 3` tensor
            with columns (values, var_mc, stderr_mc); otherwise `dim` values.
        """
        # `torch.as_tensor` acts as the identity transform here.
        return self.aggregation(torch.as_tensor)

    def gradient_absolute_measure(self) -> torch.Tensor:
        r"""Computes the gradient absolute measure (mean |gradient| per dim).

        Returns:
            If constructed with `num_bootstrap_samples > 1`, a `dim x 3` tensor
            with columns (values, var_mc, stderr_mc); otherwise `dim` values.
        """
        return self.aggregation(torch.abs)

    def gradients_square_measure(self) -> torch.Tensor:
        r"""Computes the gradient square measure (mean gradient^2 per dim).

        Returns:
            If constructed with `num_bootstrap_samples > 1`, a `dim x 3` tensor
            with columns (values, var_mc, stderr_mc); otherwise `dim` values.
        """
        return self.aggregation(torch.square)
[docs]
class GpDGSMGpSampling(GpDGSMGpMean):
    r"""Derivative-based sensitivity measures from posterior *sample* gradients.

    Instead of differentiating the posterior mean, `num_gp_samples` samples are
    drawn from the posterior over the MC grid and each is differentiated. This
    additionally yields a GP-sampling variance / standard error per measure.
    """

    samples_gradients: torch.Tensor | None = None
    samples_gradients_btsp: list[torch.Tensor] | None = None

    def __init__(
        self,
        model: Model,
        bounds: torch.Tensor,
        num_gp_samples: int,
        derivative_gp: bool = False,
        kernel_type: str | None = None,
        Y_scale: float = 1.0,
        num_mc_samples: int = 10**4,
        input_qmc: bool = False,
        gp_sample_qmc: bool = False,
        dtype: torch.dtype = torch.double,
        num_bootstrap_samples: int = 1,
        discrete_features: list[int] | None = None,
    ) -> None:
        r"""Computes three types of derivative based measures:
        the gradient, the gradient square and the gradient absolute measures.

        Args:
            model: A BoTorch model.
            bounds: Parameter bounds over which to evaluate model sensitivity.
            num_gp_samples: If method is "GP samples", the number of GP samples
                has to be set.
            derivative_gp: If true, the derivative of the GP is used to compute
                the gradient instead of backward.
            kernel_type: Takes "rbf" or "matern", set only if `derivative_gp` is
                true.
            Y_scale: Scale the derivatives by this amount, to undo scaling done
                on the training data.
            num_mc_samples: The number of Monte Carlo grid samples.
            input_qmc: If True, a qmc Sobol grid is used instead of uniformly
                random.
            gp_sample_qmc: If True, the posterior sampling is done using
                `SobolQMCNormalSampler`.
            dtype: Can be provided if the GP is fit to data of type
                `torch.float`.
            num_bootstrap_samples: If higher than 1, the method will compute the
                dgsm measure `num_bootstrap_samples` times by selecting
                subsamples from the `input_mc_samples` and return the variance
                and standard error across all computed measures.
            discrete_features: If specified, the inputs associated with the
                indices in this list are generated using an integer-valued
                uniform distribution, rather than the default (pseudo-)random
                continuous uniform distribution.

        The values returned by gradient_measure, gradient_absolute_measure and
        gradients_square_measure change to the following:
            if `num_bootstrap_samples > 1`: a `dim x 5` tensor with columns
                (values, var_gp, stderr_gp, var_mc, stderr_mc);
            else: a `dim x 3` tensor with columns (values, var_gp, stderr_gp).
        """
        # Must be set before `super().__init__`, which immediately calls the
        # overridden gradient hooks below that read these attributes.
        self.num_gp_samples = num_gp_samples
        self.gp_sample_qmc = gp_sample_qmc
        super().__init__(
            model=model,
            bounds=bounds,
            derivative_gp=derivative_gp,
            kernel_type=kernel_type,
            Y_scale=Y_scale,
            num_mc_samples=num_mc_samples,
            input_qmc=input_qmc,
            dtype=dtype,
            num_bootstrap_samples=num_bootstrap_samples,
            discrete_features=discrete_features,
        )

    def _compute_mean_gradients(self, model: Model, Y_scale: float) -> None:
        """Override: sampling needs the full posterior for rsample()."""
        posterior = assert_is_instance(
            model.posterior(self.input_mc_samples), GPyTorchPosterior
        )
        self._compute_gradient_quantities(posterior, Y_scale)

    def _compute_gradient_quantities(
        self, posterior: Posterior | MultivariateNormal, Y_scale: float
    ) -> None:
        r"""Sample the posterior and store the (scaled) gradient of each sample.

        With `derivative_gp`, the samples already are gradients. Otherwise each
        sample is back-propagated to `self.input_mc_samples` one at a time.
        When bootstrapping, index subsets of size 2 into the MC grid are drawn
        and the selected rows are stored per GP sample.
        """
        if self.gp_sample_qmc:
            sampler = SobolQMCNormalSampler(
                sample_shape=torch.Size([self.num_gp_samples]), seed=0
            )
            samples = sampler(posterior)
        else:
            samples = posterior.rsample(torch.Size([self.num_gp_samples]))
        if self.derivative_gp:
            self.samples_gradients = samples * Y_scale
        else:
            per_sample_grads = []
            for j in range(self.num_gp_samples):
                # Keep the graph alive: every sample shares it.
                torch.sum(samples[j]).backward(retain_graph=True)
                grad = none_throws(self.input_mc_samples.grad)
                per_sample_grads.append(grad.clone())
                # Reset the accumulated gradient before the next sample.
                grad.zero_()
            self.samples_gradients = torch.stack(per_sample_grads, dim=0) * Y_scale
        if self.bootstrap:
            samples_gradients = none_throws(self.samples_gradients)
            subset_size = 2
            # Indices must share the gradients' device: `torch.index_select`
            # rejects a CPU index into a CUDA tensor.
            bootstrap_indices = torch.randint(
                0,
                self.num_mc_samples,
                (self.num_bootstrap_samples, subset_size),
                device=samples_gradients.device,
            )
            self.bootstrap_indices = bootstrap_indices
            # One entry per GP sample, stacking the rows selected by each
            # bootstrap index set along a new leading dimension.
            self.samples_gradients_btsp = [
                torch.stack(
                    [
                        torch.index_select(samples_gradients[j], 0, indices)
                        for indices in bootstrap_indices
                    ],
                    dim=0,
                )
                for j in range(self.num_gp_samples)
            ]

    def aggregation(
        self, transform_fun: Callable[[torch.Tensor], torch.Tensor]
    ) -> torch.Tensor:
        r"""Aggregate transformed sample gradients into per-dimension statistics.

        Args:
            transform_fun: Applied to each column of gradients (e.g.
                `torch.abs`) before averaging over the grid points.

        Returns:
            Without bootstrapping, a `dim x 3` tensor with columns
            (values, var_gp, stderr_gp). With bootstrapping, a `dim x 5` tensor
            with columns (values, var_gp, stderr_gp, var_mc, stderr_mc).
        """

        def column_means(gradients: torch.Tensor) -> torch.Tensor:
            # Mean of `transform_fun` over the grid, for each input dimension.
            return torch.tensor(
                [torch.mean(transform_fun(gradients[:, i])) for i in range(self.dim)]
            )

        samples_gradients = none_throws(self.samples_gradients)
        # (num_gp_samples x dim): the full-grid measure of every GP sample.
        gradients_measure_list = torch.stack(
            [column_means(samples_gradients[j]) for j in range(self.num_gp_samples)],
            dim=0,
        )
        if not self.bootstrap:
            gp_var = gradients_measure_list.var(dim=0)
            return torch.stack(
                [
                    gradients_measure_list.mean(dim=0),
                    gp_var,
                    torch.sqrt(gp_var / self.num_gp_samples),
                ],
                dim=1,
            )
        samples_gradients_btsp = none_throws(self.samples_gradients_btsp)
        # (num_gp_samples x (num_bootstrap_samples + 1) x dim): per GP sample,
        # the full-grid measure followed by one measure per bootstrap subset.
        gradients_measure_list_btsp = torch.stack(
            [
                torch.stack(
                    [gradients_measure_list[j]]
                    + [
                        column_means(samples_gradients_btsp[j][b])
                        for b in range(self.num_bootstrap_samples)
                    ],
                    dim=0,
                )
                for j in range(self.num_gp_samples)
            ],
            dim=0,
        )
        # GP variance: across GP samples, averaged over the MC subsets.
        gp_var = torch.var(gradients_measure_list_btsp, dim=0).mean(dim=0)
        gp_se = torch.sqrt(gp_var / self.num_gp_samples)
        # MC variance: across MC subsets, averaged over the GP samples.
        mc_var = torch.var(gradients_measure_list_btsp, dim=1).mean(dim=0)
        mc_se = torch.sqrt(mc_var / (self.num_bootstrap_samples + 1))
        total_mean = gradients_measure_list_btsp.reshape(-1, self.dim).mean(dim=0)
        return torch.stack([total_mean, gp_var, gp_se, mc_var, mc_se], dim=1)
[docs]
def compute_derivatives_from_model_list(
    model_list: Sequence[Model],
    bounds: torch.Tensor,
    discrete_features: list[int] | None = None,
    fixed_features: dict[int, float] | None = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Average derivatives of each model's posterior mean over a bounded domain.

    For every model, a `GpDGSMGpMean` estimator is built on `bounds` and its
    gradient measure is collected. The per-model results are stacked row-wise.

    Args:
        model_list: The m botorch.models.model.Model instances whose average
            derivatives are computed.
        bounds: A 2 x d tensor of lower and upper bounds of the models' domain.
        discrete_features: If specified, the inputs at these indices are
            generated from an integer-valued uniform distribution instead of
            the default (pseudo-)random continuous uniform distribution.
        fixed_features: If specified, a mapping from feature index to a fixed
            value. These features are held constant and their derivatives are
            not computed. `bounds` should include all features.
        kwargs: Forwarded to GpDGSMGpMean.

    Returns:
        An (m x d') tensor of gradient measures, d' being the number of
        non-fixed features.
    """
    models_to_use: Sequence[Model] | list[FixedFeatureModel] = model_list
    # Fixed features: reduce the bounds and wrap the models; the helper also
    # returns the matching `discrete_features`.
    if fixed_features:
        models_to_use, bounds, discrete_features = prepare_fixed_feature_inputs(
            model_list=list(model_list),
            bounds=bounds,
            discrete_features=discrete_features,
            fixed_features=fixed_features,
        )
    per_model_measures = [
        GpDGSMGpMean(
            model=wrapped,  # pyre-ignore[6]: FixedFeatureModel wraps Model
            bounds=bounds,
            discrete_features=discrete_features,
            **kwargs,
        ).gradient_measure()
        for wrapped in models_to_use
    ]
    return torch.stack(per_model_measures)