
perturbators

muppet.components.perturbator.base

Base perturbator classes for applying masks to create perturbed inputs.

This module provides the foundational infrastructure for perturbators in the MUPPET XAI framework. Perturbators implement the second step of the four-step perturbation-based explanation process: generate masks → apply perturbations → calculate attributions → aggregate results.
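To make the four steps concrete, here is a framework-free sketch in plain Python. Every function name is hypothetical and not part of the MUPPET API. The explorer enumerates all binary masks, the perturbator zeroes features whose mask value is 1, and the last two steps use a RISE-style average of model scores over the masks that keep each feature. Mapping steps 3 and 4 onto these two functions is an illustrative assumption, not the framework's own split.

```python
from itertools import product
from typing import Callable, List, Sequence

Mask = List[int]  # 1 = perturb, 0 = keep (MUPPET convention)


def explore(n_features: int) -> List[Mask]:
    """Step 1 (hypothetical explorer): enumerate every binary mask."""
    return [list(bits) for bits in product((0, 1), repeat=n_features)]


def perturbate(x: Sequence[float], masks: List[Mask]) -> List[List[float]]:
    """Step 2 (perturbator): zero out perturbed features, x * (1 - mask)."""
    return [[xi * (1 - mi) for xi, mi in zip(x, mask)] for mask in masks]


def score(model: Callable[[Sequence[float]], float], perturbed: List[List[float]]) -> List[float]:
    """Step 3 (hypothetical attribution step): evaluate the model on each perturbed input."""
    return [model(p) for p in perturbed]


def aggregate(scores: List[float], masks: List[Mask]) -> List[float]:
    """Step 4 (hypothetical aggregation): mean score over the masks that keep each feature."""
    importance = []
    for j in range(len(masks[0])):
        kept = [s for s, m in zip(scores, masks) if m[j] == 0]
        importance.append(sum(kept) / len(kept))
    return importance


weights = [3.0, 2.0, 1.0]


def linear_model(v: Sequence[float]) -> float:
    return sum(w * vi for w, vi in zip(weights, v))


x = [1.0, 1.0, 1.0]
masks = explore(len(x))
importance = aggregate(score(linear_model, perturbate(x, masks)), masks)
print(importance)  # [4.5, 4.0, 3.5]
```

For this linear model, features with larger weights receive higher importance, which is the behavior the pipeline is meant to reveal.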

Perturbators take the original input data and binary masks from explorers, then create perturbed versions of the input by modifying specific regions according to the masks. The perturbation strategy depends on the data modality and explanation method requirements.

The module contains:

  • Perturbator: Abstract base class defining the perturbation interface, with automatic batch processing and memory management.
  • TrainablePerturbator: Extended base class for perturbators that rely on trainable generators for more sophisticated perturbation strategies.

Mask Convention

All perturbators follow the MUPPET binary mask convention:
  • 0: Preserve the original input value (no perturbation)
  • 1: Perturb the input value (apply the perturbation strategy)

Note

Perturbators work closely with explorers (mask generation) and are consumed by attributors (feature importance calculation). The choice of perturbation strategy significantly impacts explanation quality.

Classes

Perturbator
Perturbator(max_batch_size)

Bases: ABC

Abstract base class for perturbation strategies in XAI methods.

Perturbators apply masks to input data to create perturbed versions for explanation purposes. They define how the original input should be modified based on the perturbation masks generated by explorers.

Perturbators implement the second step of the four-step perturbation-based explanation process: generate masks → apply perturbations → calculate attributions → aggregate results.

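The perturbators automatically handle:
  • Batch processing of masks in chunks of at most max_batch_size, to limit memory use
  • Preventive batch-size reduction (the batch_perturbation decorator caps the batch size up front; the code shown does not catch out-of-memory errors and retry)
  • Integration with trainable generators when needed (see TrainablePerturbator)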

Initialize the Perturbator with batch processing configuration.

Parameters:

  • max_batch_size (int) –

    Maximum batch size for processing perturbations.

Source code in muppet/components/perturbator/base.py
def __init__(self, max_batch_size: int) -> None:
    """Initialize the Perturbator with batch processing configuration.

    Args:
        max_batch_size (int): Maximum batch size for processing perturbations.
    """
    self.device = None
    self.max_batch_size = max_batch_size
    super().__init__()
Functions
reinitialize
reinitialize()

Reset the perturbator to its initial state.

Clears any internal state or cached data that may affect subsequent perturbation operations.

Source code in muppet/components/perturbator/base.py
def reinitialize(self):
    """Reset the perturbator to its initial state.

    Clears any internal state or cached data that may affect
    subsequent perturbation operations.
    """
    pass
perturbate abstractmethod
perturbate(x, masks)

Apply perturbations to input data using provided masks.

Mask Convention

All perturbators follow the MUPPET perturbation convention:
  • 0: Preserve the original input value (no perturbation)
  • 1: Perturb the input value (apply the perturbation strategy)

Parameters:

  • x (Tensor) –

    The original input example. Shape (b, *x.shape[1:]).

  • masks (Tensor) –

    Binary masks for perturbation. Shape (N, *mask_shape).

Returns:

  • Tensor

    torch.Tensor: Perturbed examples. Shape (N, *x.shape).

Source code in muppet/components/perturbator/base.py
@abstractmethod
def perturbate(
    self,
    x: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """Apply perturbations to input data using provided masks.

    Mask Convention:
        All perturbators follow the MUPPET perturbation convention:
        - 0: Preserve the original input value (no perturbation)
        - 1: Perturb the input value (apply perturbation strategy)

    Args:
        x (torch.Tensor): The original input example. Shape (b, *x.shape[1:]).
        masks (torch.Tensor): Binary masks for perturbation. Shape (N, *mask_shape).

    Returns:
        torch.Tensor: Perturbed examples. Shape (N, *x.shape).
    """
    raise NotImplementedError
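The contract a subclass must honor (implement perturbate and return one perturbed copy of x per mask) can be illustrated with a stdlib-only sketch. SketchPerturbator and ConstantPerturbator are hypothetical names, plain lists stand in for tensors, and x is simplified to a flat feature list rather than a (b, ...) tensor.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class SketchPerturbator(ABC):
    """Hypothetical stand-in mirroring the Perturbator interface."""

    def __init__(self, max_batch_size: int) -> None:
        self.device = None
        self.max_batch_size = max_batch_size

    def reinitialize(self) -> None:
        """Default: nothing cached, so nothing to reset."""

    @abstractmethod
    def perturbate(self, x: Sequence[float], masks: List[List[int]]) -> List[List[float]]:
        """Return one perturbed copy of x per mask (N outputs for N masks)."""
        raise NotImplementedError


class ConstantPerturbator(SketchPerturbator):
    """Hypothetical subclass: replace perturbed features with a fixed baseline."""

    def __init__(self, baseline: float, max_batch_size: int = 100) -> None:
        super().__init__(max_batch_size=max_batch_size)
        self.baseline = baseline

    def perturbate(self, x: Sequence[float], masks: List[List[int]]) -> List[List[float]]:
        # mask value 0 keeps x[j]; mask value 1 substitutes the baseline
        return [[self.baseline if m == 1 else xi for xi, m in zip(x, mask)] for mask in masks]


try:
    SketchPerturbator(max_batch_size=10)
    base_instantiable = True
except TypeError:
    # Abstract methods must be implemented before a class can be instantiated
    base_instantiable = False

p = ConstantPerturbator(baseline=-1.0)
out = p.perturbate([0.5, 0.7, 0.9], [[0, 1, 0], [1, 1, 1]])
print(out)  # [[0.5, -1.0, 0.9], [-1.0, -1.0, -1.0]]
```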
batch_perturbation staticmethod
batch_perturbation()

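Decorator factory that splits masks into batches of at most max_batch_size along the first dimension, calls the wrapped method on each batch, and concatenates the results. Although batch_perturbation is a staticmethod, the returned wrapper receives the instance as its first argument and reads max_batch_size from it.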

Source code in muppet/components/perturbator/base.py
@staticmethod
def batch_perturbation():
    """Decorator factory that splits `masks` into batches of at most
    `max_batch_size`. The wrapper receives the instance as its first
    argument, so it can read `instance_self.max_batch_size`.
    """

    def decorator(func):
        """Run `func` on consecutive chunks of `masks` and concatenate the results.

        The batch size is capped preventively at `max_batch_size`; this
        implementation does not catch torch.OutOfMemoryError or retry.

        This decorator assumes:
        1. The function is a method of a class.
        2. The argument after `x` (i.e. `masks`) is the tensor to be batched
        along its first dimension.
        3. The function returns a torch.Tensor that can be concatenated
        along dim 0.
        """

        @functools.wraps(func)
        def wrapper(instance_self, x, masks, *args, **kwargs):
            # `masks` is the tensor to be batched.
            # This assumes the signature method(self, x, masks, ...).

            # The initial batch size is the total number of items.
            tensor_to_batch = masks
            n = tensor_to_batch.size(0)
            batch_size = n

            if batch_size >= instance_self.max_batch_size:
                new_batch_size = instance_self.max_batch_size
                logger.debug(
                    f"Reducing batch size from {batch_size} to {new_batch_size} preventively"
                )
                batch_size = new_batch_size

            results = []
            logger.debug(
                f"Attempting to run '{func.__name__}' with batch size {batch_size}..."
            )
            # Process the input tensor in chunks of the current batch_size
            for i in range(0, n, batch_size):
                chunk = tensor_to_batch[i : i + batch_size]

                result_chunk = func(
                    instance_self, x, chunk, *args, **kwargs
                )
                results.append(result_chunk)

            # If all chunks were processed successfully, concatenate and return.
            logger.debug(
                f"Successfully processed '{func.__name__}' with batch size {batch_size}."
            )
            return torch.cat(results, dim=0)

        return wrapper

    return decorator
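The chunking performed by the wrapper can be sketched without PyTorch. run_in_chunks and fake_perturbate are hypothetical names, list.extend stands in for torch.cat(..., dim=0), and the sketch adds a guard for an empty mask set, which the wrapper above lacks (with zero masks its range() call would get a step of 0 and raise). Like the wrapper, it only caps the batch size up front and does not retry on out-of-memory errors.

```python
from typing import Any, Callable, List, Sequence


def run_in_chunks(
    func: Callable[[Sequence[Any]], List[Any]],
    masks: Sequence[Any],
    max_batch_size: int,
) -> List[Any]:
    """Call func on consecutive slices of masks and concatenate the outputs."""
    n = len(masks)
    if n == 0:
        # Guard added in this sketch; a zero step would make range() raise.
        return []
    # Same cap as the wrapper: never exceed max_batch_size.
    batch_size = min(n, max_batch_size)
    results: List[Any] = []
    for start in range(0, n, batch_size):
        # extend() plays the role of torch.cat(results, dim=0)
        results.extend(func(masks[start:start + batch_size]))
    return results


chunk_sizes: List[int] = []


def fake_perturbate(chunk: Sequence[int]) -> List[int]:
    chunk_sizes.append(len(chunk))
    return [m * 10 for m in chunk]


out = run_in_chunks(fake_perturbate, list(range(7)), max_batch_size=3)
print(chunk_sizes)  # [3, 3, 1]
print(out)  # [0, 10, 20, 30, 40, 50, 60]
```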
TrainablePerturbator
TrainablePerturbator(
    generator, train_loader, max_batch_size=100
)

Bases: Perturbator

Perturbator with trainable generator capabilities.

Extends the basic Perturbator with a trainable generator that can learn perturbation strategies from training data. Trainable perturbators are intended for strategies that produce plausible perturbations.

If the generator is not yet trained (generator.is_trained is False), it is trained on train_loader during initialization, after which the perturbator can be used in the MUPPET explanation pipeline like any other.

Technical Features:
  • Automatic generator training with validation splits and early stopping
  • Works with any generator that implements the TrainableGenerator interface (is_trained, train_generator)

Initialize the TrainablePerturbator with a generator and training data.

Parameters:

  • generator (TrainableGenerator) –

    The generator for creating perturbations.

  • train_loader (DataLoader) –

    Training data for generator training. Required only if the generator is not already trained.

  • max_batch_size (int, default: 100 ) –

    Maximum batch size for processing. Defaults to 100.

Source code in muppet/components/perturbator/base.py
def __init__(
    self,
    generator: TrainableGenerator,
    train_loader: DataLoader,
    max_batch_size: int = 100,
) -> None:
    """Initialize the TrainablePerturbator with a generator and training data.

    Args:
        generator (TrainableGenerator): The generator for creating perturbations.
        train_loader (DataLoader): Training data for generator training.
        max_batch_size (int): Maximum batch size for processing. Defaults to 100.
    """
    self.generator = generator
    if not self.generator.is_trained:
        assert train_loader is not None, (
            "If the generator training is required, you must provide the training data!"
        )
        train_losses, test_losses = self.generator.train_generator(
            train_loader=train_loader,
        )
        logger.info(
            "Finished training generator on provided data! Accessible through: 'perturbator.generator'."
        )
        logger.debug(
            "Training loss trend: " + str([i for i in train_losses])
        )
        logger.debug("Testing loss trend: " + str([i for i in test_losses]))

        # set generator state to trained
        self.generator.is_trained = True

    super().__init__(max_batch_size=max_batch_size)
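TrainablePerturbator trains its generator only when generator.is_trained is False. The sketch below mirrors that check with a hypothetical DummyGenerator exposing only the two members the constructor uses; its loss values are fabricated placeholders for illustration. It raises ValueError instead of using assert, a deliberate choice because assertions are skipped when Python runs with -O.

```python
from typing import Iterable, List, Optional, Tuple


class DummyGenerator:
    """Hypothetical generator with the two members TrainablePerturbator uses."""

    def __init__(self) -> None:
        self.is_trained = False
        self.train_calls = 0

    def train_generator(self, train_loader: Iterable[float]) -> Tuple[List[float], List[float]]:
        self.train_calls += 1
        # Placeholder loss curves; a real generator would fit on train_loader.
        train_losses = [1.0 / (epoch + 1) for epoch in range(3)]
        test_losses = [1.5 / (epoch + 1) for epoch in range(3)]
        return train_losses, test_losses


def ensure_trained(generator: DummyGenerator, train_loader: Optional[Iterable[float]]) -> DummyGenerator:
    """Train the generator once, mirroring the check in TrainablePerturbator.__init__."""
    if not generator.is_trained:
        if train_loader is None:
            raise ValueError("Training data is required because the generator is not trained.")
        generator.train_generator(train_loader=train_loader)
        generator.is_trained = True
    return generator


gen = DummyGenerator()
ensure_trained(gen, train_loader=[0.1, 0.2])
ensure_trained(gen, train_loader=None)  # already trained, so no data is needed
print(gen.train_calls)  # 1

try:
    ensure_trained(DummyGenerator(), train_loader=None)
    missing_data_rejected = False
except ValueError:
    missing_data_rejected = True
```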

muppet.components.perturbator.scale_feature_generator

Scale-based perturbators for tabular data using statistical generators.

This module implements perturbators that use statistical generators to create scaled perturbations of tabular data. These perturbators are designed for the MUPPET XAI framework's perturbation step, where input features are selectively replaced with generated values based on learned or observed data distributions.

The perturbators in this module specialize in tabular data explanations. They rely on generators fitted to the data, which model feature distributions, correlations, and data types, so the substituted values are more realistic than those produced by simple masking, leading to better explanation quality for tabular models.

The module contains:

  • ScaleFeaturePerturbator: Uses Gaussian-based generators for continuous tabular features, with statistical scaling and normalization.
  • RandomSamplePerturbator: Uses frequency-based sampling from training data distributions for mixed categorical/numerical features.

Key Features:
  • Statistical distribution preservation through generator training
  • Mixed data type handling (numerical and categorical)
  • Instance-centered perturbations for local explanations
  • Configurable sampling strategies and scaling approaches
  • Memory-efficient batch processing with automatic size adjustment

These perturbators are intended for tabular explanation methods such as LIME and SHAP, where the quality of the perturbed samples directly affects explanation fidelity and interpretability. They support realistic "what-if" scenarios by generating plausible alternative feature values.

Classes

ScaleFeaturePerturbator
ScaleFeaturePerturbator(generator, max_batch_size=100)

Bases: Perturbator

Perturbator for tabular data using Gaussian generators.

Specializes in perturbing tabular data by applying statistical generators that maintain feature distributions. Designed for structured data explanations where preserving data realism is crucial.

It generates controlled variations of the input using a Gaussian-based generator trained on continuous input features.

Parameters:

  • generator (StandardGaussianTabularGenerator) –

    An instance of a Gaussian generator for tabular data, which, once trained on the input data, generates continuous values based on a Gaussian distribution.

  • max_batch_size (int, default: 100 ) –

    Maximum batch size for processing. Defaults to 100.

Source code in muppet/components/perturbator/scale_feature_generator.py
def __init__(
    self,
    generator: StandardGaussianTabularGenerator,
    max_batch_size: int = 100,
) -> None:
    """Perturbator for tabular data based on a Gaussian generator.

    This perturbator is designed to modify tabular data by generating
    controlled variations using a Gaussian-based generator, which has
    been trained on continuous input features.

    Args:
        generator (StandardGaussianTabularGenerator): An instance of a Gaussian
            generator for tabular data, which, once trained on the input data,
            generates continuous values based on a Gaussian distribution.
        max_batch_size (int): Maximum batch size for processing. Defaults to 100.
    """
    self.generator = generator
    super().__init__(max_batch_size=max_batch_size)
Functions
perturbate
perturbate(x, masks)

Perturbs the input tensor x using the provided masks.

This method applies perturbations to the input x based on the given masks. The masks determine which features of the input should be perturbed (1 for perturbed, 0 for not perturbed). The generator is used to produce substitute values for the masked features.

Parameters:

  • x (Tensor) –

    The input tensor to be perturbed, with shape (1, f), where f is the number of features in the data.

  • masks (Tensor) –

    A tensor of binary masks (1 = perturb, 0 = keep) that determines which features in x are perturbed. Expected shape (N, f); a singleton dimension is inserted internally so that each mask lines up with x.

Returns:

  • Tensor

    torch.Tensor: A tensor containing the generated perturbed values, with shape (number_of_masks, *x.shape).

Source code in muppet/components/perturbator/scale_feature_generator.py
def perturbate(
    self,
    x: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """Perturbs the input tensor `x` using the provided masks.

    This method applies perturbations to the input `x` based on the given `masks`.
    The masks determine which features of the input should be perturbed (1 for
    perturbed, 0 for not perturbed). The generator is used to produce substitute
    values for the masked features.

    Args:
        x (torch.Tensor): The input tensor to be perturbed, with shape (1, f),
            where `f` is the number of features in the data.
        masks (torch.Tensor): Binary masks (1 = perturb, 0 = keep). Expected
            shape (N, f); a singleton dimension is inserted internally so
            that each mask lines up with `x` of shape (1, f).

    Returns:
        torch.Tensor: A tensor containing the generated perturbed values,
        with shape (number_of_masks, *x.shape).
    """
    # (N, f) -> (N, 1, f) so each mask matches x of shape (1, f)
    masks_expanded = masks.unsqueeze(1)
    # The generator receives x and the masks and returns the perturbed
    # samples, shape (number_masks, *x.shape)
    sampled_values_tensor = self.generator.generate(x, masks_expanded)

    return sampled_values_tensor
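StandardGaussianTabularGenerator's internals are not shown here, and perturbate delegates both sampling and blending to generator.generate. As a rough illustration of Gaussian-based replacement only, the sketch below assumes each perturbed feature is drawn independently from a normal distribution fitted per feature on training data. fit_gaussian and gaussian_perturbate are hypothetical names, and the real generator may scale or center values differently.

```python
import random
import statistics
from typing import List, Sequence, Tuple


def fit_gaussian(train_rows: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-feature mean and population standard deviation of the training data."""
    columns = list(zip(*train_rows))
    means = [statistics.fmean(c) for c in columns]
    stds = [statistics.pstdev(c) for c in columns]
    return means, stds


def gaussian_perturbate(
    x: Sequence[float],
    masks: Sequence[Sequence[int]],
    means: Sequence[float],
    stds: Sequence[float],
    rng: random.Random,
) -> List[List[float]]:
    """Replace each perturbed feature (mask == 1) with a draw from N(mean_j, std_j)."""
    return [
        [rng.gauss(mu, sigma) if m == 1 else xi for xi, m, mu, sigma in zip(x, mask, means, stds)]
        for mask in masks
    ]


train = [[1.0, 10.0], [3.0, 14.0], [2.0, 12.0]]
means, stds = fit_gaussian(train)
rng = random.Random(0)
perturbed = gaussian_perturbate([2.5, 11.0], [[0, 1], [1, 0]], means, stds, rng)
print(means)  # [2.0, 12.0]
```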
RandomSamplePerturbator
RandomSamplePerturbator(generator, max_batch_size=100)

Bases: Perturbator

A class to perturb tabular data using generated samples and binary masks.

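Uses frequency-based sampling from training data distributions for mixed categorical/numerical features.

Initializes the RandomSamplePerturbator with a generator.

Parameters:

  • generator (RandomSampleTabularGenerator) –

    An instance of RandomSampleTabularGenerator that will be used to generate random samples.

  • max_batch_size (int, default: 100 ) –

    Maximum batch size for processing. Defaults to 100.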

Source code in muppet/components/perturbator/scale_feature_generator.py
def __init__(
    self,
    generator: RandomSampleTabularGenerator,
    max_batch_size: int = 100,
) -> None:
    """Initializes the RandomSamplePerturbator with a generator.

    Args:
        generator (RandomSampleTabularGenerator): An instance of
            RandomSampleTabularGenerator that will be used to generate random samples.
        max_batch_size (int): Maximum batch size for processing. Defaults to 100.
    """
    self.generator = generator
    super().__init__(max_batch_size=max_batch_size)
Functions
perturbate
perturbate(x, masks)

Perturb the input tensor using generated samples and binary masks.

Parameters:

  • x (Tensor) –

    The input tensor to be perturbed. Should have shape (1, num_features).

  • masks (Tensor) –

    A tensor of binary masks with shape (number_masks, num_features). Each mask determines where to apply the perturbation.

Returns:

  • Tensor

    torch.Tensor: A tensor containing the perturbed samples with shape (number_masks, 1, num_features). Perturbations follow the masks: where a mask is 1, the value from the generated tensor is used; where it is 0, the original value from x is preserved.

Source code in muppet/components/perturbator/scale_feature_generator.py
def perturbate(
    self,
    x: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """Perturb the input tensor using generated samples and binary masks.

    Args:
        x (torch.Tensor): The input tensor to be perturbed. Should have shape (1, num_features).
        masks (torch.Tensor): A tensor containing binary masks with shape (number_masks, num_features).
            Each mask is used to determine where to apply the perturbation.

    Returns:
        torch.Tensor: A tensor containing the perturbed samples with shape (number_masks, 1, num_features).
            Perturbations are applied according to the masks: where masks are 1, the value from
            the generated tensor is used; where masks are 0, the original value `x` is preserved.
    """
    n_samples, n_features = masks.shape
    masks = masks.unsqueeze(1)  # (N, f) -> (N, 1, f)

    # Draw one generated sample per mask; only the number of samples is
    # passed to the generator. Move it to the device of x so the arithmetic
    # below does not mix devices.
    self.generated_tensor = self.generator.generate(n_samples).to(x.device)

    # Ensure masks and generated_tensor are compatible: (N, 1, f).
    # Assumes the generator returns n_samples * n_features values.
    self.generated_tensor = self.generated_tensor.reshape(masks.shape)

    # MUPPET convention: 1 -> generated value, 0 -> original value from x
    perturbations = masks * self.generated_tensor + (1 - masks) * x
    return perturbations
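The blend used here (generated value where the mask is 1, original value where it is 0) can be traced with a stdlib-only sketch. random_sample_perturbate is a hypothetical name, and drawing one whole training row per mask, uniformly with replacement, is an assumed sampling scheme: it reproduces the empirical frequencies of the training data, but the real RandomSampleTabularGenerator may sample differently. Because substitutes come from observed rows, categorical features always receive valid categories.

```python
import random
from typing import Any, List, Sequence


def random_sample_perturbate(
    x: Sequence[Any],
    masks: Sequence[Sequence[int]],
    train_rows: Sequence[Sequence[Any]],
    rng: random.Random,
) -> List[List[Any]]:
    """Replace perturbed features (mask == 1) with values from a randomly drawn training row."""
    perturbed = []
    for mask in masks:
        # Assumed scheme: one donor row per mask, drawn uniformly with replacement.
        donor = rng.choice(train_rows)
        perturbed.append([d if m == 1 else xi for xi, m, d in zip(x, mask, donor)])
    return perturbed


train = [["red", 1.0], ["blue", 2.0], ["green", 3.0]]
rng = random.Random(42)
out = random_sample_perturbate(["red", 9.9], [[1, 0], [0, 1]], train, rng)
```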

muppet.components.perturbator.simple

Simple perturbation strategies for basic masking operations.

Provides simple perturbator implementations for the MUPPET XAI framework, implementing the second step of the perturbation process. These modules apply straightforward transformations to input data based on binary masks from explorers.

Simple perturbators are fast and easy to interpret.

Classes:

  • SetToZeroPerturbator

    Masks features by setting them to zero.

  • BlurPerturbator

    Applies Gaussian blur to masked regions in image data, simulating information removal while maintaining spatial structure.

These simple perturbators are building blocks for many explanation methods including RISE. They provide baseline perturbation strategies that can be compared against more sophisticated approaches or used when interpretability and speed are prioritized over realism.

Classes

SetToZeroPerturbator
SetToZeroPerturbator(max_batch_size=100)

Bases: Perturbator

Simple perturbator that sets masked features to zero.

Provides basic perturbation by multiplying input features with the complement of the mask (1-mask). This creates binary perturbations where features are either preserved or zeroed out.

Initialize the SetToZeroPerturbator.

Parameters:

  • max_batch_size (int, default: 100 ) –

    Maximum batch size for processing. Defaults to 100.

Source code in muppet/components/perturbator/simple.py
def __init__(self, max_batch_size: int = 100) -> None:
    """Initialize the SetToZeroPerturbator.

    Args:
        max_batch_size (int): Maximum batch size for processing. Defaults to 100.
    """
    super().__init__(max_batch_size=max_batch_size)
Functions
perturbate
perturbate(x, masks)

Set masked features to zero by multiplying with (1-mask).

Parameters:

  • x (Tensor) –

    The input example. Shape (b, *input_dims).

  • masks (Tensor) –

    The perturbation masks. Shape (N, *mask_shape).

Returns:

  • Tensor

    torch.Tensor: Perturbed examples with masked features set to zero.

Source code in muppet/components/perturbator/simple.py
def perturbate(
    self,
    x: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """Set masked features to zero by multiplying with (1-mask).

    Args:
        x (torch.Tensor): The input example. Shape (b, *input_dims).
        masks (torch.Tensor): The perturbation masks. Shape (N, *mask_shape).

    Returns:
        torch.Tensor: Perturbed examples with masked features set to zero.
    """
    # repeat the input example across N masks
    x_t = x.unsqueeze(dim=0).repeat(
        masks.size(0), *[1 for _ in x.shape]
    )  # (*x.shape) => (N, *x.shape)

    perturbations = x_t * (1 - masks)  # (N, *x.shape)

    return perturbations
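Here the repeated input has shape (N, *x.shape) and the masks (N, *mask_shape). When the mask's channel dimension is 1 (as the BlurPerturbator docstring suggests for image masks), PyTorch broadcasting applies the same spatial mask to every channel. A stdlib-only sketch of that behavior, with nested lists of shape (c, h, w) standing in for tensors and set_to_zero as a hypothetical name:

```python
from typing import List

Image = List[List[List[float]]]  # shape (c, h, w)
Mask = List[List[int]]  # shape (h, w), shared by every channel


def set_to_zero(x: Image, masks: List[Mask]) -> List[Image]:
    """Return one copy of x per mask, with masked pixels zeroed in every channel."""
    return [
        [
            [[v * (1 - m) for v, m in zip(row, mask_row)] for row, mask_row in zip(channel, mask)]
            for channel in x
        ]
        for mask in masks
    ]


x = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # 2 channels, 2x2 pixels
masks = [[[1, 0], [0, 0]], [[0, 0], [0, 1]]]
out = set_to_zero(x, masks)
```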
BlurPerturbator
BlurPerturbator(
    add_noise=False,
    kernel_size=(11, 11),
    sigma=5,
    max_batch_size=100,
)

Bases: Perturbator

Perturbator that applies Gaussian blur to masked regions.

Creates perturbations by blurring masked regions instead of zeroing them. Blurring removes fine detail, simulating information removal, while maintaining spatial structure. This gives a more realistic perturbation than simple masking for image explanations where complete occlusion is too harsh.

Initialize the BlurPerturbator.

Parameters:

  • add_noise (bool, default: False ) –

    Whether to add standard normal noise to the perturbations. Defaults to False.

  • kernel_size (tuple[int, int], default: (11, 11) ) –

    Gaussian kernel size for blurring. Defaults to (11, 11).

  • sigma (int, default: 5 ) –

    Gaussian sigma parameter. Defaults to 5.

  • max_batch_size (int, default: 100 ) –

    Maximum batch size for processing. Defaults to 100.

Source code in muppet/components/perturbator/simple.py
def __init__(
    self,
    add_noise: bool = False,
    kernel_size: tuple[int, int] = (11, 11),
    sigma: int = 5,
    max_batch_size: int = 100,
) -> None:
    """Initialize the BlurPerturbator.

    Args:
        add_noise (bool): Whether to add normal noise to perturbations. Defaults to False.
        kernel_size (tuple[int, int]): Gaussian kernel size for blurring. Defaults to (11, 11).
        sigma (int): Gaussian sigma parameter. Defaults to 5.
        max_batch_size (int): Maximum batch size for processing. Defaults to 100.
    """
    self.add_noise = add_noise
    self.kernel_size = kernel_size
    self.sigma = sigma

    self.blurred_input = None
    super().__init__(max_batch_size=max_batch_size)
Functions
reinitialize
reinitialize()

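Reset the cached blurred input. The blurred version of x is computed on the first call to perturbate and reused afterwards, so call reinitialize before perturbing a different input.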

Source code in muppet/components/perturbator/simple.py
def reinitialize(self):
    """Reset the cached blurred input so the next perturbate call recomputes it."""
    self.blurred_input = None
perturbate
perturbate(x, masks)

Calculates the input perturbations as $x' = x \odot (1 - M) + \tilde{x} \odot M$, where $M$ is the mask and $\tilde{x}$ is the blurred input.

Parameters:

  • x (Tensor) –

    The input examples. Shape (b=1, *x.shape[1:]), e.g. (b=1, c, w, h).

  • masks (Tensor) –

    The generated masks used to perturb x. Shape (N, *mask_shape), with len(mask_shape) == x.dim(), e.g. mask_shape = (b=1, c=1, w, h), where:
      - N is the number of perturbation masks,
      - b is the batch dimension, expected to be 1 as only one example is explained for the moment,
      - c is the channel dimension,
      - w is the width,
      - h is the height.

Returns:

  • Tensor

    x' (torch.Tensor): Perturbed version of x. Shape (N, *x.shape)

Source code in muppet/components/perturbator/simple.py
def perturbate(
    self,
    x: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """Calculates the input perturbations by `$Input*(1-Mask) + BlurredInput*Mask$`

    Args:
        x (torch.Tensor): The input examples. Shape (b=1, *x.shape[1:]), e.g. (b=1, c, w, h)

        masks (torch.Tensor): The generated masks to use for perturbing x.
            Shape (N, *mask_shape), len(mask_shape)==x.dim(), e.g. mask_shape = (b=1, c=1, w, h)
            - b is the batch dimension, expected to be 1 as only one example is being explained for the moment,
            - c is the channel dimension,
            - w is the width,
            - h is the height,
            - N is the number of perturbation masks.

    Returns:
        x' (torch.Tensor): Perturbed version of x. Shape (N, *x.shape)

    """
    # blur the input example once for all
    if self.blurred_input is None:
        self.blurred_input = GaussianBlur(
            kernel_size=self.kernel_size, sigma=self.sigma
        )(x)  # E.g (b=1, c, w, h)

    N = masks.size(0)
    x = x.unsqueeze(dim=0).repeat(
        N, *[1 for _ in x.shape]
    )  # (*x.shape) => (N, *x.shape)
    x_b = self.blurred_input.unsqueeze(dim=0).repeat(
        N, *[1 for _ in self.blurred_input.shape]
    )  # (*x_b.shape) => (N, *x_b.shape)

    perturbations = x * (1 - masks) + x_b * masks  # (N, *x.shape)

    # add normal noise if requested
    if self.add_noise:
        # randn_like matches the shape, dtype, and device of perturbations
        noise = torch.randn_like(perturbations)
        perturbations = perturbations + noise

    return perturbations
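The blend $x \odot (1 - M) + \tilde{x} \odot M$ can be traced on a 1-D signal with a stdlib-only sketch. gaussian_kernel, blur_1d, and blur_perturbate are hypothetical names. The sketch clamps indices at the borders, whereas torchvision's GaussianBlur applies its own padding, so values near the edges can differ. As in BlurPerturbator, the blurred signal is computed once per input and reused for every mask.

```python
import math
from typing import List, Sequence


def gaussian_kernel(size: int, sigma: float) -> List[float]:
    """Normalized 1-D Gaussian weights centered on the middle tap."""
    half = size // 2
    weights = [math.exp(-((i - half) ** 2) / (2 * sigma ** 2)) for i in range(size)]
    total = sum(weights)
    return [w / total for w in weights]


def blur_1d(signal: Sequence[float], kernel: Sequence[float]) -> List[float]:
    """Convolve with the kernel, clamping out-of-range indices to the edges."""
    half = len(kernel) // 2
    n = len(signal)
    blurred = []
    for i in range(n):
        acc = 0.0
        for k, w in enumerate(kernel):
            j = min(max(i + k - half, 0), n - 1)
            acc += w * signal[j]
        blurred.append(acc)
    return blurred


def blur_perturbate(x: Sequence[float], masks: Sequence[Sequence[int]], kernel: Sequence[float]) -> List[List[float]]:
    """x * (1 - M) + blurred(x) * M for every mask M."""
    blurred = blur_1d(x, kernel)  # computed once per input
    return [[xi * (1 - m) + bi * m for xi, bi, m in zip(x, blurred, mask)] for mask in masks]


x = [0.0, 0.0, 10.0, 0.0, 0.0]
kernel = gaussian_kernel(3, 1.0)
out = blur_perturbate(x, [[0, 0, 1, 0, 0]], kernel)
```

Only the masked position changes: it takes the blurred value, which is lower than the original peak because the kernel spreads its weight over neighbors.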

muppet.components.perturbator.timestep_generator

Time series perturbators using trainable generators for temporal explanations.

This module implements specialized perturbators for time series data that use trainable generators to create realistic temporal perturbations. They are intended for explaining sequential models, where temporal dependencies and patterns are central to model behavior.

In the MUPPET framework's perturbation step, these perturbators go beyond simple masking by using learned generative models to impute missing values in time series data. This preserves temporal coherence and realistic data characteristics, leading to more meaningful explanations for sequential models.

The module contains:

  • GeneratorSamplingPertubator: Uses trainable generators to impute values at multiple timesteps of each mask, suitable for complex temporal patterns.
  • ConditionalSamplingGeneratorPertubator: Performs conditional sampling for single-timestep perturbations, ideal for feature-specific temporal explanations.

Key Technical Features:
  • Automatic generator training with validation splits and early stopping
  • Temporal dependency preservation through learned representations
  • Feature-wise conditional sampling for fine-grained explanations
  • NaN handling for padding when the model cannot handle signals of unequal length

These perturbators are designed for explaining sequential models in domains such as:
  • Financial time series forecasting
  • Sensor data analysis and anomaly detection
  • Medical signal processing and diagnosis

The generators learn to capture complex temporal patterns during training, enabling realistic counterfactual scenarios that maintain temporal consistency while revealing feature importance across time.

Classes

GeneratorSamplingPertubator
GeneratorSamplingPertubator(
    generator, train_loader, max_batch_size=100
)

Bases: TrainablePerturbator

Perturbator using generative models for time series imputation.

Employs trainable generators to create realistic substitutes for perturbed time series segments. Learns temporal patterns from training data to generate contextually appropriate perturbations.

Perturbator that uses a generative model (the generator) to impute the perturbed measurements.

Parameters:

  • generator (TrainableGenerator) –

    The generator used to impute the perturbed values through its generate method.

  • train_loader (DataLoader) –

    Training data, used to train the generator if it is not already trained.

  • max_batch_size (int, default: 100 ) –

    Maximum batch size for processing. Defaults to 100.

Source code in muppet/components/perturbator/timestep_generator.py
def __init__(
    self,
    generator: TrainableGenerator,
    train_loader: DataLoader,
    max_batch_size: int = 100,
) -> None:
    """Perturbator that uses a generative model (the generator) to impute the perturbed measurements.

    Args:
        generator (TrainableGenerator): The generator used to impute the perturbed values through its `generate` method.
        train_loader (DataLoader): Training data, used to train the generator if it is not already trained.
        max_batch_size (int): Maximum batch size for processing. Defaults to 100.
    """
    super().__init__(
        generator=generator,
        train_loader=train_loader,
        max_batch_size=max_batch_size,
    )
Functions
perturbate
perturbate(x, masks)

Perturb the input x according to the masks. For each mask, the generator imputes the values at every position the mask covers, one time step at a time. Time step 0 is never imputed because there is no past to condition on; positions masked at step 0 are set to 0.

Parameters:

  • x (Tensor) –

    The input example. Shape (b=1, f, t).

  • masks (Tensor) –

    Encodes which parts of x to perturb. Shape (N, *x.shape).

Returns:

  • Tensor

    torch.Tensor: Perturbed version of x. Shape (N, *x.shape).

Source code in muppet/components/perturbator/timestep_generator.py
def perturbate(
    self,
    x: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """Perturb the input x according to the masks. The generator imputes
    the values at each position covered by a mask; time step 0 is never imputed.

    Args:
        x (torch.Tensor): The input example. Shape (b=1, f, t)
        masks (torch.Tensor): Encodes which parts of x to perturb. Shape (N, *x.shape)

    Returns:
        torch.Tensor: Perturbed version of x. Shape (N, *x.shape)

    """
    N = masks.size(0)
    # match inputs and masks shape to multiply them
    x_t = x.unsqueeze(dim=0).repeat(
        N, *[1 for _ in x.shape]
    )  # (b=1, f, t) => (N, b=1, f, t)

    # the value 0 will be replaced/imputed and the rest stays as it is
    perturbations = x_t * (1 - masks)  # (N, x.shape)

    # release temp
    del x_t

    # loop through the received masks
    for idx, mask in enumerate(masks):
        # Iterate over all possible timesteps
        # Never perturb step 0, there is nothing to base the perturbation on

        for time_step in range(1, mask.shape[-1]):
            # Do not perturb if all values are 0 or NaN
            if (
                torch.equal(
                    mask[0, :, time_step],
                    torch.zeros_like(mask[0, :, time_step]),
                )
                or torch.isnan(mask[0, :, time_step]).all()
            ):
                continue

            features_to_perturb_at_this_step = torch.where(
                mask[0, :, time_step].flatten() == 1
            )[0].tolist()

            # get x0:t
            past = x[:, :, :time_step]
            # get x_at_t
            current = x[:, :, time_step]

            # impute the perturbed values using the generator
            sampled_values_at_time_step = self.generator.generate(
                past=past,
                current=current,
                features_to_perturb=features_to_perturb_at_this_step,
            )
            # The generator returns a vector with the same shape as current.
            # Features not listed in features_to_perturb_at_this_step should
            # keep their values from current; the others are imputed.
            assert (
                perturbations[idx, :, :, time_step].shape == current.shape
            )
            assert sampled_values_at_time_step.shape == current.shape
            perturbations[idx, :, :, time_step] = (
                sampled_values_at_time_step
            )

        perturbations[idx, mask.isnan()] = float("nan")
    logger.debug(
        f"Calculated perturbations: {str([i for i in perturbations])}"
    )

    return perturbations  # (N, x.shape)
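The per-mask, per-time-step loop can be traced with a stdlib-only sketch. Series are nested lists of shape (f, t); the b=1 dimension and the NaN handling are omitted. carry_forward_generator is a hypothetical stand-in for the trained generator that simply repeats each perturbed feature's last past value, whereas a real TrainableGenerator would sample conditioned on the past.

```python
from typing import Callable, List, Sequence, Set

Series = List[List[float]]  # shape (f, t): one row per feature
Imputer = Callable[[Series, List[float], Set[int]], List[float]]


def carry_forward_generator(past: Series, current: List[float], features_to_perturb: Set[int]) -> List[float]:
    """Hypothetical imputer: repeat the last past value of each perturbed feature."""
    return [past[f][-1] if f in features_to_perturb else current[f] for f in range(len(current))]


def timestep_perturbate(x: Series, masks: Sequence[List[List[int]]], generator: Imputer) -> List[Series]:
    n_features, n_steps = len(x), len(x[0])
    results = []
    for mask in masks:
        # Start from x * (1 - mask): every perturbed position becomes 0
        out = [[x[f][t] * (1 - mask[f][t]) for t in range(n_steps)] for f in range(n_features)]
        # Step 0 is skipped (no past to condition on), so a perturbed step 0 stays 0
        for t in range(1, n_steps):
            features = {f for f in range(n_features) if mask[f][t] == 1}
            if not features:
                continue
            past = [row[:t] for row in x]  # original values x[:, :t], as in the code above
            current = [row[t] for row in x]
            sampled = generator(past, current, features)
            for f in range(n_features):
                out[f][t] = sampled[f]
        results.append(out)
    return results


x = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
mask = [[0, 1, 0], [0, 0, 1]]
perturbed = timestep_perturbate(x, [mask], carry_forward_generator)
print(perturbed[0])  # [[1.0, 1.0, 3.0], [10.0, 20.0, 20.0]]
```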
ConditionalSamplingGeneratorPertubator
ConditionalSamplingGeneratorPertubator(
    generator, train_loader, max_batch_size=100
)

Bases: TrainablePerturbator

Conditional perturbator for advanced time series explanations.

Uses a conditional generator to impute one feature at one time step, conditioned on the values before that time step, so that perturbations respect feature dependencies and temporal relationships.

Perturbator that uses a GAN model (the generator) to impute the perturbed measurements.

Parameters:

  • generator (TrainableGenerator) –

    The generator used to impute the perturbed values through its generate method.

  • train_loader (DataLoader) –

    Training data, used to train the generator if it is not already trained.

  • max_batch_size (int, default: 100 ) –

    Maximum batch size for processing. Defaults to 100.

Source code in muppet/components/perturbator/timestep_generator.py
def __init__(
    self,
    generator: TrainableGenerator,
    train_loader: DataLoader,
    max_batch_size: int = 100,
) -> None:
    """Perturbator that uses a GAN model (the generator) to impute the perturbed measurements.

    Args:
        generator (TrainableGenerator): The generator used to impute the perturbed values through its `generate` method.
        train_loader (DataLoader): Training data, used to train the generator if it is not already trained.
        max_batch_size (int): Maximum batch size for processing. Defaults to 100.
    """
    super().__init__(
        generator=generator,
        train_loader=train_loader,
        max_batch_size=max_batch_size,
    )
Functions
perturbate
perturbate(x, masks)

Perturb the input x according to the masks. Each mask must mark exactly one (feature, time step) position with 1; the generator imputes the value of that feature at that time step, conditioned on x before that step.

Parameters:

  • x (Tensor) –

    The input example. Shape (b=1, f, t).

  • masks (Tensor) –

    Encodes which parts of x to perturb. Shape (N, *x.shape).

Returns:

  • Tensor

    torch.Tensor: Perturbed version of x. Shape (N, *x.shape).

Source code in muppet/components/perturbator/timestep_generator.py
def perturbate(
    self,
    x: torch.Tensor,
    masks: torch.Tensor,
) -> torch.Tensor:
    """Perturb the input x according to the masks. Each mask must mark
    exactly one (feature, time step) position, which the generator imputes.

    Args:
        x (torch.Tensor): The input example. Shape (b=1, f, t)
        masks (torch.Tensor): Encodes which parts of x to perturb. Shape (N, *x.shape)

    Returns:
        torch.Tensor: Perturbed version of x. Shape (N, *x.shape)

    """
    N = masks.size(0)
    # match inputs and masks shape to multiply them
    x_t = x.unsqueeze(dim=0).repeat(
        N, *[1 for _ in x.shape]
    )  # (b=1, f, t) => (N, b=1, f, t)

    # the value 0 will be replaced/imputed and the rest stays as it is
    perturbations = x_t * (1 - masks)  # (N, x.shape)

    # release temp
    del x_t

    # loop through the received masks
    for idx, mask in enumerate(masks):
        # get the time step to be perturbed
        _, feature, time_step = (
            int(el) for el in torch.where(mask == 1)
        )  # mask shape is (b=1, f, t)
        # get x0:t
        past = x[:, :, :time_step]
        # get x_at_t
        current = x[:, :, time_step]
        # impute the perturbed values using the generator
        sampled_values_at_time_step = self.generator.generate(
            past=past,
            current=current,
            features_to_perturb={feature},
            # perturb the current feature
        )

        assert torch.sum(perturbations[idx, :, feature, time_step]) == 0, (
            f"The perturbator expects the value of feature={feature} at time step={time_step} to be 0 before imputation, but found {perturbations[idx, :, feature, time_step]}"
        )
        # update the value at time_step=t from 0 to the sampled one
        perturbations[idx, :, :, time_step] = sampled_values_at_time_step
    logger.debug(
        f"Calculated perturbations: {str([i for i in perturbations])}"
    )

    return perturbations  # (N, x.shape)
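ConditionalSamplingGeneratorPertubator expects each mask to mark a single position; the int() unpacking of torch.where in the code above raises an error when a mask marks zero or several positions. Below is a stdlib-only sketch of that extraction with an explicit error, plus the past/current slices handed to the generator. single_position is a hypothetical helper, and lists of shape (f, t) stand in for the (b=1, f, t) tensors.

```python
from typing import List, Tuple


def single_position(mask: List[List[int]]) -> Tuple[int, int]:
    """Return the (feature, time_step) of the only 1 in an (f, t) mask."""
    hits = [(f, t) for f, row in enumerate(mask) for t, v in enumerate(row) if v == 1]
    if len(hits) != 1:
        raise ValueError(f"expected exactly one perturbed position, found {len(hits)}")
    return hits[0]


x = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
feature, time_step = single_position([[0, 0, 0], [0, 1, 0]])
past = [row[:time_step] for row in x]  # values before the perturbed step
current = [row[time_step] for row in x]  # all features at that step

try:
    single_position([[0, 1, 0], [0, 1, 0]])
    rejected = False
except ValueError:
    rejected = True
```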