Explorer

muppet.components.explorer.base

Base Explorer Component for MUPPET XAI Framework.

This module provides the abstract base class for all exploration strategies in the MUPPET framework. Explorers are the first component in the four-step perturbation-based XAI process, responsible for generating masks and exploration strategies that define how input data will be systematically perturbed to understand model behavior.

The Explorer component serves as the foundation for all perturbation-based explanation methods by defining the logic for mask generation strategies. Different exploration approaches can be implemented by extending this base class with specific mask generation logic suitable for various data modalities (images, tabular data, time series) and explanation requirements.

Classes:

  • Explorer

    Abstract base class defining the exploration interface for generating perturbation premises. Provides the iteration protocol, state management, and the premise generation framework.

The four-step MUPPET process
  1. Explorer (this module): Generate masks and exploration strategies
  2. Perturbator: Apply masks to create perturbed inputs
  3. Attributor: Calculate feature scores from model predictions on perturbed data
  4. Aggregator: Combine attributions to produce final explanations
Note

All Explorer implementations must implement the get_premises_to_explore() method and properly manage the stop flag to indicate when exploration is complete. The explorer follows an iterator protocol and maintains state across exploration iterations through the current_iteration counter.
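The iteration protocol described in this note can be sketched in plain Python. The classes below are a simplified, hypothetical stand-in (`MiniExplorer`, `SimplePremise`, `BatchedIndexExplorer` and `run_exploration` are illustrative names, not part of MUPPET), assuming only the behaviour documented for the base class: `next()` increments the counter, calls `get_premises_to_explore()` and sets each premise's device, and the caller keeps iterating until `stop` is raised.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SimplePremise:
    """Hypothetical minimal premise: a key plus the device it should live on."""
    key: Any
    device: Optional[str] = None


class MiniExplorer(ABC):
    """Hypothetical, simplified stand-in for the Explorer base class."""

    def __init__(self, device: Optional[str] = None) -> None:
        self.stop = False
        self.current_iteration = 0
        self.device = device

    def reinitialize(self) -> None:
        # Same reset as Explorer.reinitialize(): clear the flag and the counter
        self.stop = False
        self.current_iteration = 0

    def next(self) -> List[SimplePremise]:
        # Increment first, so the first call sees current_iteration == 1
        self.current_iteration += 1
        premises = self.get_premises_to_explore()
        for p in premises:
            p.device = self.device
        return premises

    @abstractmethod
    def get_premises_to_explore(self) -> List[SimplePremise]:
        raise NotImplementedError


class BatchedIndexExplorer(MiniExplorer):
    """Hypothetical explorer: one premise per feature index, `batch` premises per iteration."""

    def __init__(self, n_features: int, batch: int, device: Optional[str] = None) -> None:
        super().__init__(device=device)
        self.n_features = n_features
        self.batch = batch

    def get_premises_to_explore(self) -> List[SimplePremise]:
        start = (self.current_iteration - 1) * self.batch
        end = min(start + self.batch, self.n_features)
        premises = [SimplePremise(key=i) for i in range(start, end)]
        # Raise the stop flag once every feature has been covered
        if end >= self.n_features:
            self.stop = True
        return premises


def run_exploration(explorer: MiniExplorer) -> List[SimplePremise]:
    """Drive the explorer as a main explainer might: call next() until stop is set."""
    explored: List[SimplePremise] = []
    while not explorer.stop:
        explored.extend(explorer.next())
    return explored
```

With `n_features=5` and `batch=2`, exploration takes three iterations; calling `reinitialize()` lets the same instance be reused for a new explanation session.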

Classes

Explorer
Explorer(
    example=None, memory=None, model=None, device=None
)

Bases: ABC

Base class for exploration strategies in XAI perturbation-based methods.

An Explorer generates perturbation premises that define how input examples should be perturbed for explanation purposes. It covers the first of the four steps of the perturbation approach: generate masks, apply masks, calculate scores, and aggregate attributions.

The Explorer serves as the foundation for all perturbation-based explanation methods by defining the interface for mask generation strategies. Different exploration approaches can be implemented by extending this base class with specific mask generation logic suitable for various data modalities (images, tabular data, time series) and explanation requirements.

Initialize the Explorer base class.

Parameters:

  • example (Tensor, default: None ) –

    The example to be explained. Will be set at runtime.

  • memory (Memory, default: None ) –

    The memory where premises are saved.

  • model (optional, default: None ) –

    The model being explained.

  • device (optional, default: None ) –

    The device to use. Will be updated from the main explainer.

Source code in muppet/components/explorer/base.py
def __init__(
    self, example=None, memory=None, model=None, device=None
) -> None:
    """Initialize the Explorer base class.

    Args:
        example (torch.Tensor, optional): The example to be explained. Will be set at runtime.
        memory (Memory, optional): The memory where premises are saved.
        model (optional): The model being explained.
        device (optional): The device to use. Will be updated from the main explainer.
    """
    self._stop = False
    self._current_iteration = 0
    self._premise_kwargs = {}

    # Initial arguments. These will usually be None
    self.example = example
    self.memory = memory
    self.model = model
    self.device = device

    super().__init__()
Attributes
stop property writable
stop

Get the exploration stop flag.

current_iteration property writable
current_iteration

Get the current iteration counter.

premise_kwargs property writable
premise_kwargs

Get the premise keyword arguments.

Functions
reinitialize
reinitialize()

Reset the explorer to its initial state.

Clears the stop flag and resets the iteration counter to prepare for a new exploration session.

Source code in muppet/components/explorer/base.py
def reinitialize(self):
    """Reset the explorer to its initial state.

    Clears the stop flag and resets the iteration counter to prepare
    for a new exploration session.
    """
    self.stop = False
    self.current_iteration = 0
next
next()

Generate the next batch of premises for exploration.

Increments the iteration counter, generates premises, and sets their device before returning them.

Returns:

  • Iterable[Premise]: The premises to explore in this iteration.

Source code in muppet/components/explorer/base.py
def next(self):
    """Generate the next batch of premises for exploration.

    Increments the iteration counter, generates premises, and sets
    their device before returning them.

    Returns:
        Iterable[Premise]: The premises to explore in this iteration.
    """
    # increase iteration
    self._current_iteration += 1
    premises_to_explore = self.get_premises_to_explore()

    # Set device for all premises
    for p in premises_to_explore:
        p.device = self.device

    return premises_to_explore
get_premises_to_explore abstractmethod
get_premises_to_explore()

The premise generator. Creates the premise objects and returns them to the main explainer.

Returns:

  • Iterable[Premise]

    An iterable over the created premises.

Source code in muppet/components/explorer/base.py
@abstractmethod
def get_premises_to_explore(self) -> Iterable[Premise]:
    """The premises' generator. Creates the premises objects and sends them back to main explainer.

    Returns:
        Iterable[Premise]: An iterable over the created premises.

    """
    raise NotImplementedError

muppet.components.explorer.feature

Feature-based Explorer Components for Convolutional Neural Networks.

This module provides explorer implementations that leverage convolutional feature maps to generate perturbation masks in the MUPPET XAI framework. These explorers are specifically designed for explaining convolutional neural networks by using the spatial feature activations from the last convolutional layer to guide the perturbation process.

The feature-based exploration strategy generates masks based on individual feature maps (channels) from the final convolutional layer, creating targeted perturbations that help understand which spatial regions and features contribute most to model predictions. This approach is particularly effective for image classification tasks where spatial locality and feature hierarchy are important.

Classes:

  • CAMExplorer

    Generates one premise per feature map in the last convolutional layer, creating class activation map (CAM) style explanations through feature-guided perturbations.

The feature exploration process
  1. Extract Features: Hook into the last convolutional layer during forward pass
  2. Generate Premises: Create one premise per feature channel with spatial upsampling
  3. Enable Perturbation: Each premise contains upsampled feature activations as masks
  4. Single-shot Explanation: Completes exploration in one iteration

Classes

CAMExplorer
CAMExplorer(model)

Bases: Explorer

Class Activation Map (CAM) explorer for convolutional neural networks.

Generates perturbation premises based on feature maps from the last
convolutional layer. Each feature map is used to create a mask premise
for explaining the contribution of that particular feature to the model's
predictions.

This explorer leverages convolutional feature maps to generate perturbation masks
in the MUPPET XAI framework. It is specifically designed for explaining convolutional
neural networks by using the spatial feature activations from the last convolutional
layer to guide the perturbation process.

Technical Details
  • Works with any PyTorch model containing Conv2d layers
  • Automatically finds the last convolutional layer in the model architecture
  • Uses bilinear interpolation to upscale feature maps to input resolution
  • Generates exactly k premises where k is the number of output channels
  • Each premise represents one feature channel's spatial contribution
  • Suitable for models like ResNet, VGG, DenseNet, etc.
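A NumPy sketch of the upsampling and per-channel premise generation described above, assuming the last convolutional layer's activations have already been captured (for example by the forward hook shown in the source below) as an array of shape (k, h', w'). `bilinear_upsample` mimics `torch.nn.functional.interpolate(..., mode="bilinear")` with its default `align_corners=False` convention, and `cam_premise_keys` is a hypothetical helper returning `(upsampled_activations, channel)` pairs, the same key layout `CAMExplorer` passes to `ConvolutionalFeaturePremise`. How that premise turns its key into a mask is not shown on this page, so it is not reproduced here.

```python
import numpy as np


def _source_coords(out_n: int, in_n: int):
    """Map output pixel centres to input coordinates (align_corners=False convention)."""
    scale = in_n / out_n
    src = (np.arange(out_n) + 0.5) * scale - 0.5
    src = np.clip(src, 0, in_n - 1)          # clamp to the valid input range
    lo = np.floor(src).astype(int)
    hi = np.minimum(lo + 1, in_n - 1)
    frac = src - lo                          # interpolation weight of the `hi` neighbour
    return lo, hi, frac


def bilinear_upsample(maps: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Bilinearly resize feature maps of shape (k, h, w) to (k, out_h, out_w)."""
    y0, y1, fy = _source_coords(out_h, maps.shape[1])
    x0, x1, fx = _source_coords(out_w, maps.shape[2])
    # Interpolate along the width for the two neighbouring rows, then along the height
    top = maps[:, y0][:, :, x0] * (1 - fx) + maps[:, y0][:, :, x1] * fx
    bottom = maps[:, y1][:, :, x0] * (1 - fx) + maps[:, y1][:, :, x1] * fx
    return top * (1 - fy)[:, None] + bottom * fy[:, None]


def cam_premise_keys(activations: np.ndarray, input_hw):
    """Return one (upsampled_activations, channel) key per feature channel: k keys in total."""
    upsampled = bilinear_upsample(activations, *input_hw)
    return [(upsampled, channel) for channel in range(activations.shape[0])]
```

As in the source, the upsampling is done once for all channels and every premise shares the same upsampled tensor, differing only by its channel index.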
Note

This explorer requires the model to contain at least one Conv2d layer. The exploration completes in a single iteration, making it computationally efficient for generating pixel-importance-based explanations.

Initialize the CAM explorer.

Parameters:

  • model (Module) –

    The model to explain.

Source code in muppet/components/explorer/feature.py
def __init__(
    self,
    model: torch.nn.Module,
) -> None:
    """Initialize the CAM explorer.

    Args:
        model (torch.nn.Module): The model to explain.
    """
    self.example = None
    self.conv_layer_activation = None
    super().__init__(model=model)
Functions
get_premises_to_explore
get_premises_to_explore()

Generate k * b premises, each corresponding to one perturbation of the input example. Expects a 4D input example of shape (b=1, c, h, w),

where:
  • k is the number of feature maps (channels) in the model's last convolutional layer,
  • b is the batch dimension, expected to be 1 since only one example is explained at a time,
  • c is the number of channels,
  • h is the height,
  • w is the width.

Returns:

  • List[ConvolutionalFeaturePremise]

    Every premise includes the necessary information to generate its mask from its key attribute.

Source code in muppet/components/explorer/feature.py
def get_premises_to_explore(self) -> List[ConvolutionalFeaturePremise]:
    """Generate a `k * b` number of premises where every one corresponds to the perturbation of the input example.
    Expects 4D input example. Shape (b=1, c, h, w).

     where
        - k is the amount of features in last convolutional layer of the model,
        - b is batch dimension, expected to be set to 1 as only one example is being explained for the moment,
        - c is the channel dimensions,
        - w is the width,
        - h is the height,

    Returns:
        List[ConvolutionalFeaturePremise]: Every premise includes the necessary information to generate its masks from the key attributes.

    """
    upscaled_mask_shape = self.example.size()[-2:]

    # Finding last Conv2d layer of the model
    last_conv_layer = None

    for layer in reversed(list(self.model.modules())):
        if isinstance(layer, torch.nn.Conv2d):
            last_conv_layer = layer
            break

    if last_conv_layer is None:
        raise TypeError("given model has no Conv2d layer")

    # Placing a hook on the last Conv2d layer, to register layer activity during forward pass
    def hook(model, input, output):
        """Hook function to register convolutional layer activations.

        Args:
            model: The PyTorch model being hooked.
            input: Input tensor to the convolutional layer.
            output: Output tensor from the convolutional layer.
        """
        self.conv_layer_activation = output.detach()

    h = last_conv_layer.register_forward_hook(hook)

    # Forwarding input in model
    with torch.no_grad():
        self.model(self.example)

    # Removing hook
    h.remove()

    upsampled_activations = torch.nn.functional.interpolate(
        self.conv_layer_activation,
        size=upscaled_mask_shape,
        mode="bilinear",
    ).to(self.device)

    premises = []
    for channel in range(self.conv_layer_activation.size()[1]):
        premise = ConvolutionalFeaturePremise(
            key=(upsampled_activations, channel)
        )
        premises.append(premise)

    # tell the main explainer to stop the exploration
    self.stop = True

    return premises

muppet.components.explorer.gradient

Gradient-based Explorer Components for Optimization-driven exploration.

This module implements explorer components that use gradient descent optimization to iteratively refine perturbation masks in the MUPPET XAI framework. These explorers represent approaches where masks are not randomly generated but instead optimized through backpropagation to maximize or minimize specific attribution objectives.

The gradient-based exploration strategy starts with initialized mask parameters and iteratively updates them using gradient information from the model's predictions and attributed loss. This allows for finding optimal perturbation patterns that reveal the most informative aspects of the model's decision-making process.

These explorers are intended to be used together with differentiable attributors.

Classes:

  • GradientExplorer

    Base gradient-based explorer that optimizes a mask grid with the Adam optimizer over multiple iterations to find the optimal perturbation grid; the grid is then upscaled to the input image shape to perturb the input.

  • GradientCAMExplorer

    Extension of GradientExplorer that incorporates Class Activation Maps (CAM) from convolutional layers to guide the optimization process with spatial feature information. The weights associated with each feature map are optimized over the iterations.

The gradient exploration process
  1. Initialize: Create learnable mask parameters (GradientExplorer initializes them to zeros)
  2. Forward: Generate premises with current mask parameters
  3. Backward: Compute gradients from attribution scores via backpropagation
  4. Optimize: Update mask parameters using gradient descent (Adam optimizer)
  5. Iterate: Repeat until convergence or maximum iterations reached
  6. Clamp: Clip mask values to the [0, 1] range after each update
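The six steps above can be illustrated with a self-contained NumPy sketch. To avoid depending on autograd, it replaces `premise.attribution.backward()` with a hand-derived gradient of a toy objective: a hypothetical linear "model" scores the perturbed input `x * mask`, and the loss `-score + lam * mask.sum()` rewards keeping relevant pixels while penalising mask area. Adam is written out by hand, the grid is upscaled by nearest-neighbour repetition (an assumption; the actual premise classes may upscale differently), and values are clipped to [0, 1] after every step, as in `GradientExplorer`. All names here are illustrative.

```python
import numpy as np


def upscale(grid: np.ndarray, factor: int) -> np.ndarray:
    """Nearest-neighbour upscaling: each grid cell becomes a factor x factor block."""
    return np.kron(grid, np.ones((factor, factor)))


def grid_grad(full_grad: np.ndarray, factor: int) -> np.ndarray:
    """Adjoint of `upscale`: sum the full-resolution gradient inside each block."""
    h, w = full_grad.shape
    return full_grad.reshape(h // factor, factor, w // factor, factor).sum(axis=(1, 3))


def optimize_grid(x, weights, lam=0.1, grid_shape=(2, 2), max_iter=100,
                  lr=0.2, beta1=0.9, beta2=0.999, eps=1e-8):
    """Toy gradient exploration: Adam on a mask grid, clipped to [0, 1] after each step."""
    factor = x.shape[0] // grid_shape[0]   # assumes the same integer factor on both axes
    grid = np.zeros(grid_shape)            # 1. initialize the learnable grid with zeros
    m = np.zeros(grid_shape)               # Adam first-moment estimate
    v = np.zeros(grid_shape)               # Adam second-moment estimate
    losses = []
    for t in range(1, max_iter + 1):       # 5. iterate until max_iter
        mask = upscale(grid, factor)       # 2. forward: build the full-resolution mask
        score = np.sum(weights * x * mask)  # toy linear "model" on the perturbed input
        losses.append(-score + lam * mask.sum())
        # 3. backward: analytic d(loss)/d(mask), pulled back onto the grid
        g = grid_grad(-weights * x + lam, factor)
        # 4. optimize: one bias-corrected Adam step (gradient descent)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        grid -= lr * m_hat / (np.sqrt(v_hat) + eps)
        np.clip(grid, 0.0, 1.0, out=grid)  # 6. clip to the valid [0, 1] range
    return grid, losses
```

With an 8x8 input whose only relevant pixels lie in the top-left quadrant, the 2x2 grid converges to keeping that quadrant and masking the rest. Because Adam minimizes, the sign convention of the attribution loss decides which regions the mask ends up selecting.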

Classes

GradientExplorer
GradientExplorer(
    max_iter=100,
    lr=0.2,
    mask_shape=(28, 28),
    premise_class=GradientPremise,
    nb_premises_at_startup=1,
)

Bases: Explorer

Gradient-based explorer for iterative mask grid optimization.

This explorer implements optimization-driven grid exploration using an Adam optimizer to iteratively refine perturbation masks. The perturbation mask is derived from the grid by upscaling it to the input image shape.

The gradient-based exploration strategy starts with initialized grid parameters and iteratively updates them using gradient information from the model's predictions and attribution loss. This allows for finding optimal perturbation patterns that reveal the most informative aspects of the model's decision-making process.

Technical Details
  • Uses Adam optimizer for stable and efficient mask parameter updates
  • Supports both spatial masks (for images) and feature masks (for tabular data)
  • Gradient flow enabled through premise.attribution.backward() calls
  • Automatic gradient accumulation and parameter clamping to valid ranges
  • Reinitializable for multiple explanation sessions
Optimization Strategy

The gradient-based approach turns the explanation problem into an optimization problem: the goal is to find mask parameters that optimize an attribution loss, usually derived from the model's predictions on the perturbed inputs, in order to reveal the most influential input regions or features. Since each Adam step minimizes the premise's attribution, whether this amounts to maximizing or minimizing the model's response is controlled by the chosen attribution loss (see Differentiable attributors).

Initialize the gradient-based explorer.

Parameters:

  • max_iter (int, default: 100 ) –

    The number of iterations for optimization. Defaults to 100.

  • lr (float, default: 0.2 ) –

    Learning rate for Adam optimizer. Defaults to 0.2.

  • mask_shape (tuple, default: (28, 28) ) –

    The learning mask shape. Defaults to (28, 28).

  • premise_class (Type[Premise], default: GradientPremise ) –

    Premise class to create. Defaults to GradientPremise.

  • nb_premises_at_startup (int, default: 1 ) –

    Number of premises at startup. Defaults to 1.

Source code in muppet/components/explorer/gradient.py
def __init__(
    self,
    max_iter: int = 100,
    lr: float = 0.2,
    mask_shape: tuple = (28, 28),
    premise_class: Type[Premise] = GradientPremise,
    nb_premises_at_startup: int = 1,
) -> None:
    """Initialize the gradient-based explorer.

    Args:
        max_iter (int, optional): The number of iterations for optimization. Defaults to 100.
        lr (float, optional): Learning rate for Adam optimizer. Defaults to 0.2.
        mask_shape (tuple, optional): The learning mask shape. Defaults to (28, 28).
        premise_class (Type[Premise]): Premise class to create. Defaults to GradientPremise.
        nb_premises_at_startup (int, optional): Number of premises at startup. Defaults to 1.
    """
    self.max_iter = max_iter
    self.lr = lr
    self.learning_mask_shape = mask_shape
    self.premise_class = premise_class
    self.nb_premises_at_startup = nb_premises_at_startup

    self.optimizers = []

    super().__init__()
Functions
get_premises_to_explore
get_premises_to_explore()

On the first call, initializes the premises; on every subsequent call, computes the gradients and performs one optimization step for each premise.

The memory holds a list of premises whose length corresponds to the number of input examples.

Returns:

  • List[GradientPremise]

    List of created/updated premises.

Source code in muppet/components/explorer/gradient.py
def get_premises_to_explore(self) -> List[GradientPremise]:
    """Responsible for, at first, initializing the premises and at every subsequent call, calculating the gradients and
        doing one optimization step associated to each premise.

    The memory is list of premises corresponding to the number of input examples.

    Returns:
        List[GradientPremise]: List of created/updated premises.

    """
    # get premises
    premises = self.memory.get_premises()
    key_shape = (1, *self.learning_mask_shape)  # (1, w, h)

    # at first iteration when memory is still empty, generate the premises
    if len(premises) == 0:
        for _ in range(self.nb_premises_at_startup):
            key = torch.zeros(
                key_shape,
                dtype=torch.float32,
                requires_grad=True,
                device=self.device,
            )

            premise = self.premise_class(
                key=key,
                upscaled_mask_shape=self.example.shape[2:],
                **self.premise_kwargs,
            )
            premises.append(premise)
            self.optimizers.append(
                torch.optim.Adam([premise.key], lr=self.lr)
            )

    # subsequent calls, update the premise
    else:
        for premise, optimizer in zip(premises, self.optimizers):
            optimizer.zero_grad()
            premise.attribution.backward(
                retain_graph=True
            )  # retain_graph allows optimization on multiple premises which still share some elements of the compute graph (ie. the model)
            optimizer.step()

            # clamp the learning mask to [0, 1]
            premise.key.data.clamp_(0, 1)

    # tell the main explainer to stop the exploration
    if self.current_iteration == self.max_iter:
        self.stop = True

    return premises
reinitialize
reinitialize()

Reset the gradient explorer to its initial state.

Clears the list of optimizers and calls the parent reinitialize method.

Source code in muppet/components/explorer/gradient.py
def reinitialize(self):
    """Reset the gradient explorer to its initial state.

    Clears the list of optimizers and calls the parent reinitialize method.
    """
    self.optimizers = []
    return super().reinitialize()
GradientCAMExplorer
GradientCAMExplorer(
    max_iter=100,
    lr=0.2,
    mask_shape=(28, 28),
    premise_class=GradientPremise,
    nb_premises_at_startup=1,
)

Bases: GradientExplorer

Gradient explorer with Class Activation Maps integration.

Extends GradientExplorer by incorporating CAM (Class Activation Maps) from convolutional layers to guide mask optimization with spatial feature information. The CAM maps are stored in premise_kwargs and can be accessed by premises after creation. Here the optimized parameters are the weights associated with each CAM map.

Technical Details
  • CAM variant extracts feature maps for initialization and guidance
  • Automatic gradient accumulation and parameter clamping to valid ranges
  • Reinitializable for multiple explanation sessions
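In the CAM variant the learnable key has shape (1, k), one weight per feature map, because `get_cam_maps()` sets `learning_mask_shape` to `(k,)`. The NumPy sketch below shows one plausible way such weights could become a mask, namely a weighted sum of the captured activations, and how a gradient with respect to that mask flows back to the k weights by the chain rule. `GradientPremise`'s actual mask construction is not documented on this page, so `weighted_cam_mask` and `weight_grad` are hypothetical helpers.

```python
import numpy as np


def weighted_cam_mask(weights: np.ndarray, activations: np.ndarray) -> np.ndarray:
    """Combine k activation maps of shape (k, h, w) into one (h, w) mask using k weights."""
    return np.tensordot(weights, activations, axes=1)


def weight_grad(mask_grad: np.ndarray, activations: np.ndarray) -> np.ndarray:
    """Chain rule: dL/dw_j is the sum over pixels of dL/dmask * activations[j]."""
    return np.tensordot(activations, mask_grad, axes=([1, 2], [0, 1]))
```

With a key of shape (1, k), one would pass `key[0]` as `weights`; an optimizer step on those weights followed by clipping to [0, 1] then mirrors the update loop of `GradientExplorer`.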

Initialize the GradientCAMExplorer with Class Activation Maps integration.

Extends GradientExplorer by incorporating CAM (Class Activation Maps) from convolutional layers to guide mask optimization with spatial feature information.

Parameters:

  • max_iter (int, default: 100 ) –

    Number of optimization iterations for the Adam optimizer.

  • lr (float, default: 0.2 ) –

    Learning rate for Adam optimizer.

  • mask_shape (tuple[int, int], default: (28, 28) ) –

    Shape of the learnable mask parameters.

  • premise_class (Type[Premise], default: GradientPremise ) –

    Class of premises to create. Its attribution should support .backward(), which is usually the case when it is computed with differentiable torch operations (default: GradientPremise).

  • nb_premises_at_startup (int, default: 1 ) –

    Number of premises to generate at initialization.

Source code in muppet/components/explorer/gradient.py
def __init__(
    self,
    max_iter: int = 100,
    lr: float = 0.2,
    mask_shape: tuple[int, int] = (28, 28),
    premise_class: Type[Premise] = GradientPremise,
    nb_premises_at_startup: int = 1,
) -> None:
    """Initialize the GradientCAMExplorer with Class Activation Maps integration.

    Extends GradientExplorer by incorporating CAM (Class Activation Maps) from
    convolutional layers to guide mask optimization with spatial feature information.

    Args:
        max_iter (int): Number of optimization iterations for the Adam optimizer.
        lr (float): Learning rate for Adam optimizer.
        mask_shape (tuple[int, int]): Shape of the learnable mask parameters.
        premise_class (Type[Premise]): Class of premises to create,
            should have a backward attribute (usually a torch function computation
            for auto-differentiability) (default: GradientPremise).
        nb_premises_at_startup (int): Number of premises to generate at initialization.
    """
    self.cam_maps_were_obtained = False
    super().__init__(
        max_iter, lr, mask_shape, premise_class, nb_premises_at_startup
    )
Functions
get_cam_maps
get_cam_maps()

Extract CAM (Class Activation Maps) from the model.

Registers a forward hook on the last convolutional layer to capture activations and stores them in premise_kwargs for use by premises.

Source code in muppet/components/explorer/gradient.py
def get_cam_maps(self):
    """Extract CAM (Class Activation Maps) from the model.

    Registers a forward hook on the last convolutional layer to capture
    activations and stores them in premise_kwargs for use by premises.
    """
    if self.cam_maps_were_obtained is False:
        last_conv_layer = None

        for layer in reversed(list(self.model.modules())):
            if isinstance(layer, torch.nn.Conv2d):
                last_conv_layer = layer
                break

        if last_conv_layer is None:
            raise TypeError("given model has no Conv2d layer")

        # Placing a hook on the last Conv2d layer, to register layer activity during forward pass
        def hook(model, input, output):
            """Hook function to register layer activations during forward pass.

            Args:
                model: The PyTorch model being hooked.
                input: Input tensor to the layer.
                output: Output tensor from the layer.
            """
            self.premise_kwargs["activations"] = output.detach()

        h = last_conv_layer.register_forward_hook(hook)

        # Forwarding input in model
        with torch.no_grad():
            self.model(self.example)

        # Removing hook
        h.remove()

        self.learning_mask_shape = (
            self.premise_kwargs["activations"].size()[1],
        )  # tuple

    self.cam_maps_were_obtained = True
reinitialize
reinitialize()

Reset the GradientCAM explorer to its initial state.

Resets the CAM maps flag and calls the parent reinitialize method.

Source code in muppet/components/explorer/gradient.py
def reinitialize(self):
    """Reset the GradientCAM explorer to its initial state.

    Resets the CAM maps flag and calls the parent reinitialize method.
    """
    self.cam_maps_were_obtained = False
    return super().reinitialize()
get_premises_to_explore
get_premises_to_explore()

Get premises with CAM maps included.

First extracts CAM maps, then calls the parent method to generate premises.

Returns:

  • List[GradientPremise]

    List of premises with CAM data available.

Source code in muppet/components/explorer/gradient.py
def get_premises_to_explore(self) -> List[GradientPremise]:
    """Get premises with CAM maps included.

    First extracts CAM maps, then calls the parent method to generate premises.

    Returns:
        List[GradientPremise]: List of premises with CAM data available.
    """
    self.get_cam_maps()
    return super().get_premises_to_explore()

muppet.components.explorer.mask

Mask-based Explorer Components for Perturbation Strategies.

This module provides a set of explorer implementations that generate various types of perturbation masks for the MUPPET XAI framework. These explorers implement different sampling and masking strategies to systematically probe model behavior through diverse perturbation patterns, supporting multiple data modalities and explanation approaches.

The mask-based exploration strategies generate sets of perturbation masks using random sampling, segmentation or distribution-based methods. These approaches provide broad coverage of the input space to understand model sensitivity across different regions or features.

Classes:

  • RandomMasksExplorer

    Generates random binary masks for spatial image perturbations with configurable mask density and grid resolution.

  • SegmentedBinaryRandomMasksExplorer

    Creates segment-based random masks using superpixel segmentation (SLIC) for semantically meaningful image regions.

  • RandomNormalExplorer

    Generates masks from a normal distribution for continuous perturbations, particularly suited to tabular data.

  • BinaryFeaturePermutationsExplorer

    Enumerates all possible binary feature combinations (coalitions) for exhaustive tabular data analysis.

The mask exploration process
  1. Generate: Create mask patterns based on the specific strategy
  2. Sample: Apply probabilistic or deterministic sampling rules
  3. Package: Wrap masks in premises with necessary metadata
  4. Batch: Return complete set of masks for single-iteration exploration
Technical Details
  • Random Masks: Generate binary masks with configurable grid size and density
  • Segmented Masks: Use SLIC superpixel segmentation for semantic coherence
  • Normal Masks: Sample from Gaussian distribution for continuous perturbations
  • Permutation Masks: Enumerate all 2^n feature combinations with limits
  • Reproducible: Seed-based random generation for consistent results
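Exhaustive coalition enumeration, as done by `BinaryFeaturePermutationsExplorer`, grows as 2^n: 10 features already give 1,024 masks and 20 features give 1,048,576. A minimal sketch with `itertools.product`; the `max_masks` cap is a hypothetical stand-in for the limits mentioned above, since that class's parameters are not shown on this page.

```python
from itertools import islice, product
from typing import Iterator, Optional, Tuple


def binary_coalitions(n_features: int,
                      max_masks: Optional[int] = None) -> Iterator[Tuple[int, ...]]:
    """Yield binary masks over n_features (1 = feature kept), optionally capped at max_masks."""
    all_masks = product((0, 1), repeat=n_features)  # 2 ** n_features tuples, generated lazily
    return islice(all_masks, max_masks)             # islice(it, None) yields everything
```

Generating the coalitions lazily keeps memory usage flat even when only the first few masks are consumed.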

Classes

RandomMasksExplorer
RandomMasksExplorer(
    nmasks=800, mask_dim=7, mask_proba=0.5, seed=None
)

Bases: Explorer

Random mask explorer for general perturbation-based explanations.

Generates a specified number of random binary masks for perturbing input examples. Each mask is derived by upscaling a randomly sampled binary grid.

Technical Details
  • Random Masks: Generate binary masks with configurable grid size and density
  • Reproducible: Seed-based random generation for consistent results
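A NumPy sketch of how a premise key `(mask_dim, mask_proba, mask_shape)` could be turned into a binary mask, using the per-mask seed `seed + i` that `get_premises_to_explore()` assigns in the source below. Two choices are assumptions rather than documented behaviour: a cell is masked (set to 0) with probability `mask_proba`, and the grid is upscaled by nearest-neighbour repetition and then cropped. RISE-style methods, for example, upsample bilinearly with a random shift instead, and `BinaryRandomPremise` may do something similar.

```python
import numpy as np


def random_binary_mask(mask_dim, mask_proba, mask_shape, seed=None):
    """Sample a binary grid and upscale it to mask_shape (1 = keep, 0 = masked)."""
    grid_h, grid_w = (mask_dim, mask_dim) if isinstance(mask_dim, int) else mask_dim
    rng = np.random.default_rng(seed)
    # Each cell is masked with probability mask_proba (assumed semantics)
    grid = (rng.random((grid_h, grid_w)) >= mask_proba).astype(np.float32)
    height, width = mask_shape
    cell_h = -(-height // grid_h)  # ceiling division so the grid covers the whole image
    cell_w = -(-width // grid_w)
    upscaled = np.repeat(np.repeat(grid, cell_h, axis=0), cell_w, axis=1)
    return upscaled[:height, :width]


def random_masks(nmasks=800, mask_dim=7, mask_proba=0.5, mask_shape=(28, 28), seed=None):
    """One mask per premise; mask i uses seed + i, mirroring the explorer's seeding."""
    return [
        random_binary_mask(mask_dim, mask_proba, mask_shape,
                           None if seed is None else seed + i)
        for i in range(nmasks)
    ]
```

Offsetting the seed per mask keeps the masks distinct from one another while making the whole set reproducible from a single `seed`.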

Initialize the RandomMasksExplorer.

Parameters:

  • nmasks (int, default: 800 ) –

    Number of random masks to generate. Defaults to 800.

  • mask_dim (int | tuple[int, int], default: 7 ) –

    Grid size for base mask. Defaults to 7.

  • mask_proba (float, default: 0.5 ) –

    Probability of masking each cell. Defaults to 0.5.

  • seed (int | None, default: None ) –

    Random seed for reproducibility. Defaults to None.

Source code in muppet/components/explorer/mask.py
def __init__(
    self,
    nmasks: int = 800,
    mask_dim: int | tuple[int, int] = 7,
    mask_proba: float = 0.5,
    seed: int | None = None,
) -> None:
    """Initialize the RandomMasksExplorer.

    Args:
        nmasks (int): Number of random masks to generate. Defaults to 800.
        mask_dim (int | tuple[int, int]): Grid size for base mask. Defaults to 7.
        mask_proba (float): Probability of masking each cell. Defaults to 0.5.
        seed (int | None): Random seed for reproducibility. Defaults to None.
    """
    self.nmasks = nmasks
    self.mask_dim = mask_dim
    self.mask_proba = mask_proba
    self.seed = seed

    super().__init__()
Functions
get_premises_to_explore
get_premises_to_explore()

Generate nmasks premises, each corresponding to one perturbation of the input example. Expects an input example of shape (b=1, **, h, w).

Here ** stands for any number of dimensions, including none. Usually it is a single channel dimension, giving the 4D shape (b=1, c, h, w),

where:
  • b is the batch dimension, expected to be 1 since only one example is explained at a time,
  • c is the number of channels,
  • h is the height,
  • w is the width.

Returns:

  • List[BinaryRandomPremise]

    Every premise includes the necessary information to generate its random mask from the key attribute.

Source code in muppet/components/explorer/mask.py
def get_premises_to_explore(self) -> List[BinaryRandomPremise]:
    """Generate a `nmasks` number of premises where every one corresponds to the perturbation of the input example.
    Expects 4D input example. Shape (b=1, **, h, w).

    ** can be anything, either no dimension or an arbitrary number of dimensions. Usually it will be 1 dimension, the channel, so shape
    is (b=1,c,h,w).

    where
        - b is batch dimension, expected to be set to 1 as only one example is being explained for the moment,
        - c is the channel dimensions,
        - w is the width,
        - h is the height,

    Returns:
            List[BinaryRandomPremise]: Every premise includes the necessary information to generate its random mask from the key attribute.
    """
    mask_shape = (self.example.shape[-2], self.example.shape[-1])

    premises = []
    for i in range(self.nmasks):
        seed = self.seed
        if seed is not None:
            seed = seed + i
        premise = BinaryRandomPremise(
            key=(self.mask_dim, self.mask_proba, mask_shape),
            seed=seed,
            **self.premise_kwargs,
        )
        premises.append(premise)

    # tell the main explainer to stop the exploration
    self.stop = True

    return premises
SegmentedBinaryRandomMasksExplorer
SegmentedBinaryRandomMasksExplorer(
    nmasks=500, masked_proba=0.5, n_segments=100
)

Bases: Explorer

Segmented random mask explorer for image-based explanations.

Generates random binary masks over superpixels rather than individual pixels. The image is first segmented with SLIC, and each mask then masks whole superpixels (each with probability masked_proba). As a result, semantically coherent regions are perturbed together, object boundaries are preserved, and the masked regions respect the natural structure of the image.

Technical Details
  • Segmented masks: SLIC superpixel segmentation groups pixels into coherent regions
  • Spatial coherence: segmentation-based masks preserve object boundaries
  • Semantically meaningful: masks respect the natural structure of the image

Initialize the SegmentedBinaryRandomMasksExplorer.

Parameters:

  • nmasks (int, default: 500 ) –

    Number of random masks to generate. Defaults to 500.

  • masked_proba (float, default: 0.5 ) –

    Probability of masking each superpixel. Defaults to 0.5.

  • n_segments (int, default: 100 ) –

    Approximate number of superpixels. Defaults to 100.

Source code in muppet/components/explorer/mask.py
def __init__(
    self,
    nmasks: int = 500,
    masked_proba: float = 0.5,
    n_segments: int = 100,
) -> None:
    """Initialize the SegmentedBinaryRandomMasksExplorer.

    Args:
        nmasks (int): Number of random masks to generate. Defaults to 500.
        masked_proba (float): Probability of masking each superpixel. Defaults to 0.5.
        n_segments (int): Approximate number of superpixels. Defaults to 100.
    """
    self.nmasks = nmasks
    self.mask_proba = masked_proba
    self.n_segments = n_segments

    super().__init__()
Functions
get_premises_to_explore
get_premises_to_explore()

Generate nmasks premises, each corresponding to one perturbation of the input example. Expects a 4D input example of shape (b=1, c, h, w), where:

  • b is the batch dimension, expected to be 1 since only one example is explained at a time,
  • c is the channel dimension,
  • h is the height,
  • w is the width.

Returns:

  • List[SegmentedBinaryImagePremise]

    Every premise stores a binary vector over superpixels as its key attribute. Together with the segmented example, this is enough to generate its mask.

Source code in muppet/components/explorer/mask.py
def get_premises_to_explore(self) -> List[SegmentedBinaryImagePremise]:
    """Generate `nmasks` premises, each corresponding to one perturbation of the input example.
    Expects a 4D input example of shape (b=1, c, h, w).

    where
        - b is the batch dimension, expected to be 1 since only one example is explained at a time,
        - c is the channel dimension,
        - h is the height,
        - w is the width.

    Returns:
        List[SegmentedBinaryImagePremise]: Every premise stores a binary vector over superpixels
            as its key; together with the segmented example, this is enough to generate its mask.
    """
    segmented_example = self.get_segmented_tensor_from_example()

    # Add the segmented example to premise_kwargs so that it is passed to every
    # premise at creation time and stored as one of its attributes
    self.premise_kwargs["segmented_example"] = segmented_example

    premises = []
    binary_matrix = (
        torch.rand(self.nmasks, segmented_example.shape[0])
        < self.mask_proba
    ).long()  # shape of binary_matrix = (nmasks, s)

    for i in range(self.nmasks):
        binary_vector = binary_matrix[i]
        premise = SegmentedBinaryImagePremise(
            key=binary_vector,
            **self.premise_kwargs,
        )
        premises.append(premise)

    # tell the main explainer to stop the exploration
    self.stop = True

    return premises
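The sketch below shows how a binary key over superpixels could be expanded into a pixel mask. It assumes that SegmentedBinaryImagePremise combines its key with segmented_example in this way; the premise class is not shown here. compose_pixel_mask is a hypothetical helper. Whether a 1 marks a kept or a masked region is left to the downstream components.

```python
import torch


def compose_pixel_mask(binary_vector: torch.Tensor, segments: torch.Tensor) -> torch.Tensor:
    """Expand a per-superpixel 0/1 vector of shape (s,) into a per-pixel (h, w) mask.

    `segments` has shape (s, h, w) and holds one-hot superpixel masks that
    partition the image, so exactly one slice is 1 at every pixel.
    """
    # Scale each one-hot slice by its superpixel's bit and sum the slices:
    # every pixel inherits the bit of the superpixel it belongs to.
    return (binary_vector.view(-1, 1, 1) * segments).sum(dim=0)


# Toy 2x4 image split into two superpixels: left half and right half.
segments = torch.tensor(
    [
        [[1, 1, 0, 0], [1, 1, 0, 0]],
        [[0, 0, 1, 1], [0, 0, 1, 1]],
    ]
)

# Same sampling rule as the explorer: each entry is 1 with probability mask_proba.
torch.manual_seed(0)
nmasks, mask_proba = 4, 0.5
binary_matrix = (torch.rand(nmasks, segments.shape[0]) < mask_proba).long()
pixel_masks = torch.stack([compose_pixel_mask(v, segments) for v in binary_matrix])

print(pixel_masks.shape)  # torch.Size([4, 2, 4])
print(compose_pixel_mask(torch.tensor([1, 0]), segments).tolist())
# [[1, 1, 0, 0], [1, 1, 0, 0]]
```

Storing only an (s,)-length vector per premise keeps each key small. The (s, h, w) segmentation is shared once through premise_kwargs.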
get_segmented_tensor_from_example
get_segmented_tensor_from_example()

Segment the example into superpixels with SLIC and return one binary mask per superpixel.

Uses the following instance attributes:

  • example (Tensor) –

    Image to explain with shape (b=1, c, h, w).

  • n_segments (int) –

    The approximate number of superpixels requested from SLIC.

Returns:

  • Tensor

    Segmented example of shape (s, h, w), where s is the number of superpixels SLIC actually produced (close to, but not necessarily equal to, n_segments). Each (h, w) slice contains 1 inside its superpixel and 0 elsewhere.

Source code in muppet/components/explorer/mask.py
def get_segmented_tensor_from_example(self):
    """Segment the example into superpixels with SLIC and return one binary mask per superpixel.

    Uses the instance attributes:
        example (torch.Tensor): Image to explain with shape (b=1, c, h, w).
        n_segments (int): The approximate number of superpixels requested from SLIC.

    Returns:
        torch.Tensor: Segmented example of shape (s, h, w), where s is the number of
            superpixels SLIC actually produced (close to, but not necessarily equal to,
            n_segments). Each (h, w) slice contains 1 inside its superpixel and 0 elsewhere.
    """
    # SLIC expects a channels-last NumPy array of shape (h, w, c)
    example_np = self.example[0].permute(1, 2, 0).detach().cpu().numpy()
    labels = torch.from_numpy(
        slic(
            example_np,
            n_segments=self.n_segments,
        )
    )  # shape (h, w): one integer superpixel label per pixel

    # Compare every pixel label with every distinct label: broadcasting
    # (1, h, w) against (s, 1, 1) gives one (h, w) mask per superpixel.
    # Using the actual label values avoids assuming that labels start at 1.
    unique_labels = torch.unique(labels)
    segmented_example = (
        labels.unsqueeze(dim=0) == unique_labels.view(-1, 1, 1)
    ).long()

    return segmented_example
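You can try the one-hot encoding step on a tiny hand-written label map without running SLIC. The label values below are made up for illustration. Each pixel is compared with the distinct labels returned by torch.unique, which yields one (h, w) mask per superpixel.

```python
import torch

# Hypothetical label map standing in for SLIC output (h=3, w=4).
labels = torch.tensor(
    [
        [1, 1, 2, 2],
        [1, 3, 3, 2],
        [3, 3, 3, 2],
    ]
)

unique_labels = torch.unique(labels)  # sorted distinct labels: 1, 2, 3

# Broadcast (1, h, w) against (s, 1, 1) -> (s, h, w) booleans, then cast to 0/1.
segmented = (labels.unsqueeze(0) == unique_labels.view(-1, 1, 1)).long()

print(segmented.shape)  # torch.Size([3, 3, 4])
# Every pixel belongs to exactly one superpixel, so the slices sum to 1 everywhere.
print(segmented.sum(dim=0))
```

Comparing against the actual label values, rather than a generated range, keeps the encoding correct whether the segmentation labels start at 0 or at 1.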
RandomNormalExplorer
RandomNormalExplorer(nmasks=800, seed=1)

Bases: Explorer

Random normal distribution explorer for continuous perturbations.

Generates perturbation premises using random values sampled from a normal distribution. Provides continuous perturbations instead of binary masks, useful for tabular data where smooth noise-based modifications are preferred over binary masking.

Technical Details
  • Normal Masks: Sample from Gaussian distribution for continuous perturbations
  • Continuous perturbation: Smooth noise-based modifications
  • Reproducible: Seed-based sampling for consistent results
Note

This explorer is ideal for tabular data or scenarios where continuous perturbations are more appropriate than binary masking strategies.

Initialize the RandomNormalExplorer.

Parameters:

  • nmasks (int, default: 800 ) –

    Number of random masks to generate. Defaults to 800.

  • seed (int, default: 1 ) –

    Random seed for reproducibility. Defaults to 1.

Source code in muppet/components/explorer/mask.py
def __init__(self, nmasks: int = 800, seed: int = 1) -> None:
    """Initialize the RandomNormalExplorer.

    Args:
        nmasks (int): Number of random masks to generate. Defaults to 800.
        seed (int): Random seed for reproducibility. Defaults to 1.
    """
    self.nmasks = nmasks
    self.seed = seed
    self.stop = False
    self.current_iteration = 0
    super().__init__()
Functions
get_random_normal_key
get_random_normal_key(seed)

Generates a random vector (key) drawn from a standard normal distribution (mean 0, standard deviation 1). This vector is used as the key of a premise.

Parameters:

  • seed (int) –

    Seed for random number generation to ensure reproducibility.

Returns:

  • Tensor

    A random key of shape (n_features,), where n_features = example.shape[1] is the number of input features.

Source code in muppet/components/explorer/mask.py
def get_random_normal_key(self, seed: int) -> torch.Tensor:
    """Generates a random vector (key) drawn from a standard normal distribution.
    This vector is used as the key of a premise.

    Args:
        seed (int): Seed for random number generation to ensure reproducibility.

    Returns:
        torch.Tensor: A random key of shape (n_features,), where
        n_features = self.example.shape[1] is the number of input features.
    """
    torch.manual_seed(seed)  # Set seed for reproducibility (reseeds the global RNG)
    mask_shape = (
        self.example.shape[1],
    )  # Shape based on input feature size
    key = torch.randn(
        mask_shape, device=self.device
    )  # Random key generated
    return key
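The same reproducible seed + i scheme can be obtained without reseeding the global RNG. The sketch below uses a private torch.Generator instead. normal_key is a hypothetical helper, and the explorer itself uses torch.manual_seed. On CPU, the two approaches should draw the same values for the same seed.

```python
import torch


def normal_key(seed: int, n_features: int) -> torch.Tensor:
    """Draw a standard-normal key of shape (n_features,) from a private generator."""
    # A local Generator gives the same reproducibility as torch.manual_seed(seed)
    # without resetting the global RNG used by the rest of the program.
    generator = torch.Generator().manual_seed(seed)
    return torch.randn(n_features, generator=generator)


base_seed, nmasks, n_features = 1, 5, 4
# One key per premise, each from its own seed, as in get_premises_to_explore.
keys = torch.stack([normal_key(base_seed + i, n_features) for i in range(nmasks)])
print(keys.shape)  # torch.Size([5, 4])
```

The design choice matters if other components rely on the global RNG. With a private generator, exploring premises no longer changes what those components draw next.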
get_premises_to_explore
get_premises_to_explore()

Generates the list of premises to be explored. Each premise is created by generating a random key for perturbation using a unique seed.

Returns:

  • List[KeyBasedMaskPremise]

    List[KeyBasedMaskPremise]: A list of premises to explore. Each premise contains a key and seed.

Source code in muppet/components/explorer/mask.py
def get_premises_to_explore(self) -> List[KeyBasedMaskPremise]:
    """Generates the list of premises to be explored. Each premise is created by generating
    a random key for perturbation using a unique seed.

    Returns:
        List[KeyBasedMaskPremise]: A list of premises to explore. Each premise contains a key and seed.
    """
    premises = []
    for i in range(self.nmasks):
        current_seed = self.seed + i  # Ensure each seed is unique
        key = self.get_random_normal_key(current_seed)  # Generate a key
        premise = KeyBasedMaskPremise(
            key=key, seed=current_seed
        )  # Create a premise with the key and seed
        premises.append(premise)

    # Indicate that exploration is complete
    self.stop = True

    return premises
BinaryFeaturePermutationsExplorer
BinaryFeaturePermutationsExplorer(
    n_repeats=1, seed=None, max_permutations=900
)

Bases: Explorer

Binary feature permutation explorer for combinatorial explanations.

Generates binary feature combinations (coalitions), each encoded as a binary mask over the features, for a systematic exploration of feature interactions. This is useful for Kernel SHAP-like explanations, where feature subsets need to be evaluated.

For exhaustive analysis of tabular data, this explorer enumerates every subset of features (up to max_permutations) to understand their individual and collective contributions to model predictions.

Technical Details
  • Permutation masks: enumerate all 2^n feature combinations up to a configurable limit. When the number of possible combinations exceeds the limit, the combinations are randomly sampled.
Note

This explorer is ideal for tabular data with manageable feature counts. For datasets with many features, the number of combinations grows exponentially (2^n), so max_permutations helps limit computational cost.
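The helper _generate_permutations is not shown in this section. The sketch below uses a hypothetical stand-in, all_binary_coalitions, to show the 2^n enumeration and the random capping described above. It draws from a private random.Random instance, whereas the original seeds the global random module.

```python
import itertools
import random

import torch


def all_binary_coalitions(n_features: int) -> list[torch.Tensor]:
    """Enumerate every 0/1 vector of length n_features (2**n_features of them)."""
    return [
        torch.tensor(bits, dtype=torch.long)
        for bits in itertools.product((0, 1), repeat=n_features)
    ]


coalitions = all_binary_coalitions(3)
print(len(coalitions))  # 8 == 2**3

# With more coalitions than the cap, draw a reproducible random subset.
max_permutations = 5
rng = random.Random(0)
subset = rng.sample(coalitions, max_permutations)
print(len(subset))  # 5
```

Enumerating every coalition before sampling costs memory proportional to 2^n. That is why this strategy fits only tabular data with few features.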

Initialize the BinaryFeaturePermutationsExplorer.

Parameters:

  • n_repeats (int, default: 1 ) –

    Number of times the full list of permutations is repeated when fewer than max_permutations are available. Defaults to 1.

  • seed (int | None, default: None ) –

    Random seed for reproducibility. Defaults to None.

  • max_permutations (int, default: 900 ) –

    Maximum number of permutations. Defaults to 900.

Source code in muppet/components/explorer/mask.py
def __init__(
    self,
    n_repeats: int = 1,
    seed: int | None = None,
    max_permutations: int = 900,
) -> None:
    """Initialize the BinaryFeaturePermutationsExplorer.

    Args:
        n_repeats (int): Number of times the full list of permutations is repeated when fewer than max_permutations are available. Defaults to 1.
        seed (int | None): Random seed for reproducibility. Defaults to None.
        max_permutations (int): Maximum number of permutations. Defaults to 900.
    """
    self.n_repeats = n_repeats
    self.seed = seed
    self.max_permutations = max_permutations

    super().__init__()
Functions
get_premises_to_explore
get_premises_to_explore()

Generates all possible binary feature permutations for the given number of features and limits the total number of permutations to max_permutations.

Returns:

  • List[KeyBasedMaskPremise]

    Each premise includes the information (a binary mask) needed to perturb the input example.

Source code in muppet/components/explorer/mask.py
def get_premises_to_explore(self) -> List[KeyBasedMaskPremise]:
    """Generates all possible binary feature permutations for the given number of features.
    Limits the total number of permutations to `max_permutations`.

    Returns:
        List[KeyBasedMaskPremise]: Each premise includes the information (a binary mask)
        needed to perturb the input example.
    """
    # Generate the permutations as PyTorch tensors
    perm_tensors = self._generate_permutations()

    # Total number of possible permutations without repetitions
    num_permutations = len(perm_tensors)

    if num_permutations >= self.max_permutations:
        # Case 1: at least as many permutations as the limit, so draw a random subset
        random.seed(self.seed)
        perm_tensors = random.sample(perm_tensors, self.max_permutations)
    else:
        # Case 2: fewer permutations than the limit, so repeat the whole list
        # n_repeats times (the result may still hold fewer than max_permutations)
        total_needed = self.max_permutations
        perm_tensors = perm_tensors * self.n_repeats  # Apply the repeats

        # Truncate if the repeats overshoot the limit
        if len(perm_tensors) > total_needed:
            perm_tensors = perm_tensors[:total_needed]

    # Initialize the list of premises
    premises = []
    for i, perm_tensor in enumerate(perm_tensors):
        seed = self.seed
        if seed is not None:
            seed = seed + i

        # Create a KeyBasedMaskPremise with the binary mask (a PyTorch tensor) as its key
        premise = KeyBasedMaskPremise(
            key=perm_tensor,
            seed=seed,
            **self.premise_kwargs,  # Additional arguments specific to the premise
        )
        premises.append(premise)

    # Signal that the exploration is complete
    self.stop = True

    return premises
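You can check the cap-or-repeat rule above in isolation. select_permutations is a hypothetical helper that mirrors the branch logic on a plain list instead of tensors. It uses a private random.Random(seed) rather than reseeding the global random module.

```python
import random
from typing import Optional, Sequence


def select_permutations(
    perms: Sequence, max_permutations: int, n_repeats: int, seed: Optional[int]
) -> list:
    """Mirror the cap-or-repeat rule of get_premises_to_explore on a plain list."""
    perms = list(perms)
    if len(perms) >= max_permutations:
        # Enough candidates: draw max_permutations of them without replacement.
        return random.Random(seed).sample(perms, max_permutations)
    # Too few candidates: repeat the whole list, then truncate to the cap.
    return (perms * n_repeats)[:max_permutations]


perms = list(range(8))  # e.g. the 2**3 coalitions of 3 features
print(len(select_permutations(perms, 900, 1, seed=0)))    # 8: one repeat adds nothing
print(len(select_permutations(perms, 900, 200, seed=0)))  # 900: 1600 items, truncated
print(len(select_permutations(perms, 5, 1, seed=0)))      # 5: random subset
```

The result can hold fewer than max_permutations entries. With the default n_repeats=1 and 2^n < max_permutations, every coalition is returned exactly once.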

muppet.components.explorer.timestep

Timestep-based Explorer Components for Sequential Data Explanations.

Provides explorer implementations for explaining sequential and time series models in the MUPPET XAI framework. Focuses on temporal relationships and feature interactions across time steps using perturbation masks.

Strategy: Systematically perturbs (timestep, feature) pairs to assess individual contributions in sequential models (e.g., RNNs, Transformers).

Classes:

  • RepeatedTimestepExplorer

    Generates premises for each (timestep, feature) combination using Monte Carlo sampling (num_sampling).

Technical Summary
  • Explores timesteps \(t \in [1, \text{signal\_length}-1]\); timestep 0 is excluded.
  • Total premises \(= (T-1) \times F \times S\), where \(T\) is the signal length, \(F\) the number of features and \(S = \text{num\_sampling}\).
  • Input Format Assumption: (batch, features, timesteps).
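As a worked example of the count, take a hypothetical input with \(T = 50\) timesteps, \(F = 3\) features and the default \(S = \text{num\_sampling} = 100\):

\[
N = (T-1) \times F \times S = (50-1) \times 3 \times 100 = 49 \times 300 = 14\,700
\]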

Classes

RepeatedTimestepExplorer
RepeatedTimestepExplorer(num_sampling=100)

Bases: Explorer

Timestep explorer for temporal sequence explanations.

Generates premises for explaining time series by perturbing different (timestep, feature) pairs. Each pair is repeated \(\text{num\_sampling}\) times (Monte Carlo sampling) for statistical robustness.

Crucial for understanding how sequential models (RNNs, Transformers) rely on temporal dependencies and feature interactions over time.

Technical Details
  • Coverage: timesteps \(t \in [1, \text{signal\_length}-1]\) and all features.
  • Total number of generated premises: \((T-1) \times F \times \text{num\_sampling}\), where \(T\) is the signal length and \(F\) the number of features.

Initialize the RepeatedTimestepExplorer.

Parameters:

  • num_sampling (int, default: 100 ) –

    Number of Monte Carlo samples per timestep-feature pair. Defaults to 100.

Source code in muppet/components/explorer/timestep.py
def __init__(
    self,
    num_sampling: int = 100,
) -> None:
    """Initialize the RepeatedTimestepExplorer.

    Args:
        num_sampling (int): Number of Monte Carlo samples per timestep-feature pair. Defaults to 100.
    """
    self.num_sampling = num_sampling
    super().__init__()
Functions
get_premises_to_explore
get_premises_to_explore()

Generates all TimeStepPremise objects for exploration in a single pass.

The premises cover all combinations of timesteps \(t \in [1, \text{signal\_length}-1]\) and features, each repeated \(\text{num\_sampling}\) times.

Returns:

  • List[TimeStepPremise]

    List of premises.

Source code in muppet/components/explorer/timestep.py
def get_premises_to_explore(self) -> List[TimeStepPremise]:
    r"""Generates all TimeStepPremise objects for exploration in a single pass.

    The premises cover all combinations of timesteps $t \in [1, \text{signal\_length}-1]$
    and features, each repeated $\text{num\_sampling}$ times.

    Returns:
        List[TimeStepPremise]: List of premises.
    """
    feature_size = self.example.shape[1]
    signal_length = self.example.shape[2]
    premises = []
    # Iterate over timesteps, features, and samples
    for i, j, _ in itertools.product(
        range(signal_length - 1, 0, -1),  # Timesteps T-1 down to 1
        range(feature_size),  # Features
        range(self.num_sampling),  # Samples
    ):
        # TODO eventually support other exploration strategies for masks
        key = (
            {"timestep": i, "feature": j},
            (feature_size, signal_length),
        )
        premises.append(TimeStepPremise(key=key, **self.premise_kwargs))

    # Stop exploration after this single pass
    self.stop = True

    return premises
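To make the iteration order and count concrete, the sketch below reproduces the loop without the TimeStepPremise wrapper, which is not shown here. timestep_keys is a hypothetical helper. It builds the same key tuples the method passes to each premise. Timesteps are visited from T-1 down to 1, and each (timestep, feature) pair appears num_sampling times in a row.

```python
import itertools


def timestep_keys(feature_size: int, signal_length: int, num_sampling: int) -> list:
    """Build the same (timestep, feature) keys that get_premises_to_explore wraps."""
    keys = []
    for t, f, _ in itertools.product(
        range(signal_length - 1, 0, -1),  # timesteps T-1 down to 1; 0 is skipped
        range(feature_size),  # every feature
        range(num_sampling),  # repeated samples of the same pair
    ):
        keys.append(({"timestep": t, "feature": f}, (feature_size, signal_length)))
    return keys


keys = timestep_keys(feature_size=2, signal_length=4, num_sampling=3)
print(len(keys))  # (4 - 1) * 2 * 3 = 18
print(keys[0])    # ({'timestep': 3, 'feature': 0}, (2, 4))
```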