Memory
muppet.components.memory.base
Base Memory Components for MUPPET XAI Framework.
Provides the abstract base classes for memory management in the MUPPET XAI framework. Memory components store and manage premises, which represent perturbation data (masks, keys, attribution results) throughout the explanation process.
Memory bridges the four-step process (exploration, perturbation, attribution, aggregation) by maintaining state via Premise objects.
Classes:
-
Premise–Abstract base class for a single perturbation point. Manages key, mask, attribution, and provides lazy evaluation with caching.
-
Memory–Abstract base class defining the interface for storing and retrieving Premise collections.
-
PremiseList–Simple, list-based implementation of Memory.
Mask Convention (Consistent Binary Format):
- 0: Preserve the original input value (no perturbation)
- 1: Perturb the input value (apply perturbation strategy)
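To make the convention concrete, the sketch below applies a binary mask to a small input. It is illustrative only: `apply_mask` and the constant `baseline` replacement are hypothetical stand-ins for what a perturbator might do, and plain Python lists stand in for tensors so the example has no dependencies.

```python
def apply_mask(values, mask, baseline=0.0):
    """Return a perturbed copy of `values` (hypothetical helper, not MUPPET API).

    mask[i] == 0 keeps values[i]; mask[i] == 1 replaces it with `baseline`.
    """
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    return [baseline if m == 1 else v for v, m in zip(values, mask)]


original = [0.2, 0.5, 0.9, 0.4]
mask = [0, 1, 0, 1]  # perturb positions 1 and 3, preserve the rest
perturbed = apply_mask(original, mask)
print(perturbed)  # [0.2, 0.0, 0.9, 0.0]
```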
Technical Note
Premises implement key-changed detection to trigger dynamic recomputation of cached masks and explanations, supporting both static and trainable methods.
Classes
Premise
Bases: ABC
Abstract base class for a single perturbation premise.
A premise is the fundamental unit of exploration, containing the key for deterministic mask generation, the mask itself, and attribution results. It utilizes lazy evaluation and caching for efficiency.
The key-changed mechanism ensures recomputation when parameters are modified. Mask Convention: 0 = preserve input, 1 = perturb input.
Initialize the Premise with a perturbation key.
Parameters:
-
key(object) –The key used to generate the mask deterministically.
-
kwargs(dict[str, Any], default:{}) –Additional premise-specific arguments.
Source code in muppet/components/memory/base.py
Attributes
property
Retrieve mask from premise
Returns:
-
–
torch.Tensor: The mask, following the mask convention: 1 for perturbed features and 0 for unperturbed features.
Functions
Check if the key has changed since last computation.
Generate explanation from the premise key (default: returns mask).
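The lazy evaluation and key-changed behaviour described for Premise can be sketched as follows. This is a minimal illustration under stated assumptions, not the MUPPET implementation: the class names, the `key_changed` and `generate_mask` method names, and the private attributes are all hypothetical, and keys are plain tuples (tensor keys would need a tensor-aware comparison instead of `!=`).

```python
import copy
from abc import ABC, abstractmethod


class SketchPremise(ABC):
    """Illustrative stand-in for Premise; all names here are assumptions."""

    def __init__(self, key):
        self.key = key
        self._cached_mask = None
        self._key_at_last_compute = None

    def key_changed(self):
        # True if no mask was computed yet or the key was modified since.
        return self._cached_mask is None or self.key != self._key_at_last_compute

    @property
    def mask(self):
        # Lazy evaluation: compute only on first access or after a key change.
        if self.key_changed():
            self._cached_mask = self.generate_mask()
            self._key_at_last_compute = copy.deepcopy(self.key)
        return self._cached_mask

    @abstractmethod
    def generate_mask(self):
        """Build the mask deterministically from self.key."""


class OneHotPremise(SketchPremise):
    """Toy subclass: key = (index, length); the mask perturbs one position."""

    calls = 0  # counts how often the mask is actually recomputed

    def generate_mask(self):
        OneHotPremise.calls += 1
        index, length = self.key
        return [1 if i == index else 0 for i in range(length)]


premise = OneHotPremise(key=(1, 4))
first = premise.mask    # computed on first access
second = premise.mask   # served from the cache
premise.key = (2, 4)    # changing the key invalidates the cache
third = premise.mask    # recomputed from the new key
```

Storing a copy of the key at computation time is what lets the sketch detect in-place modifications of mutable keys.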
Memory
Bases: ABC
Abstract base class for memory structures in XAI exploration.
Manages the storage and retrieval of Premise objects, bridging the perturbation process by maintaining the state (premises) between phases.
Initialize the basic Memory structure.
Functions
abstractmethod
PremiseList
Bases: Memory
Simple list-based memory implementation for storing premises.
Provides basic, in-memory sequential storage with efficient retrieval. Premises are replaced on new registration.
Initialize the PremiseList memory structure.
Functions
Register a collection of premises, replacing any existing premises.
Parameters:
-
premises(Iterable[Premise]) –The premises to store in memory.
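The replace-on-registration behaviour can be sketched with a small list-backed class. The class and method names (`SketchPremiseList`, `register`, `get_premises`) are assumptions made for illustration and may not match the actual PremiseList interface.

```python
from typing import Any, Iterable, List


class SketchPremiseList:
    """Illustrative list-backed memory; method names are assumptions."""

    def __init__(self) -> None:
        self._premises: List[Any] = []

    def register(self, premises: Iterable[Any]) -> None:
        # Materialize the iterable and replace (not extend) the stored list.
        self._premises = list(premises)

    def get_premises(self) -> List[Any]:
        # Return a shallow copy so callers cannot mutate the internal list.
        return list(self._premises)


memory = SketchPremiseList()
memory.register(["premise_a", "premise_b"])
memory.register(p for p in ["premise_c"])  # any iterable works, e.g. a generator
stored = memory.get_premises()
print(stored)  # ['premise_c']
```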
muppet.components.memory.premise
Concrete Premise Implementations for MUPPET XAI Framework.
Provides specialized Premise classes for different perturbation strategies and data modalities in the MUPPET framework. Each class generates a concrete perturbation mask from an abstract key, supporting various XAI methods (e.g., LIME, SHAP, gradient-based).
Classes:
-
TimeStepPremise–Generates masks for time series data by masking a specific feature at a specific timestep; future timesteps are set to NaN in the mask (to enable different padding strategies).
-
BinaryRandomPremise–Creates random binary masks with configurable preservation probability, supporting both image and tabular data modalities.
-
SegmentedBinaryImagePremise–Generates masks based on image segmentation, allowing segment-wise perturbation using pre-computed segmentation maps.
-
GradientPremise–Provides optimizable masks for optimization-based perturbation strategies.
-
ConvolutionalFeaturePremise–Creates masks from convolutional neural network feature activations, enabling feature-map-based explanations.
-
FeaturesCombinationPremise–Generates masks from linear combinations of CNN features, supporting learnable feature selection and combination.
-
KeyBasedMaskPremise–Direct key-to-mask mapping for cases where keys are already in mask format, providing minimal processing overhead.
Supported Modalities
- Time Series: Timestep-based perturbations with temporal masking
- Images: Pixel-wise, segment-wise, and feature-based perturbations
- Tabular Data: Feature-wise random perturbations
- CNN Features: Activation-based and gradient-optimized perturbations
Mask Convention
All premise types follow the MUPPET binary mask convention:
- 0: Preserve the original input value (no perturbation)
- 1: Perturb the input value (apply perturbation strategy)
Technical Implementation Details
- Lazy evaluation with caching for computational efficiency
- Device-aware tensor management for GPU acceleration
- Deterministic mask generation from keys for reproducibility
- Support for static, trainable, and optimizable perturbation strategies
- Automatic mask resizing and interpolation for different input sizes
- Integration with scikit-image for advanced image processing operations
The premise implementations support the full MUPPET four-step process:
- Explorer: Generates diverse premise keys for systematic perturbation
- Perturbator: Uses premise masks to create perturbed input variations
- Attributor: Stores model predictions and attribution results in premises
- Aggregator: Combines premise attributions into final explanations
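The four roles listed above can be illustrated with a toy, dependency-free pipeline. Everything here is an assumption made for the sake of the example: the function names, the random-mask exploration, the zero baseline, and the RISE-style averaging in the aggregation step are not taken from the MUPPET code base.

```python
import random


def explore(num_premises, num_features, p_preserve, seed=0):
    """Explorer: produce one key per premise (here, the mask itself)."""
    rng = random.Random(seed)
    return [
        [0 if rng.random() < p_preserve else 1 for _ in range(num_features)]
        for _ in range(num_premises)
    ]


def perturb(x, mask, baseline=0.0):
    """Perturbator: 0 keeps the input value, 1 replaces it with a baseline."""
    return [baseline if m == 1 else v for v, m in zip(x, mask)]


def attribute(model, x, masks):
    """Attributor: record the model output for each perturbed input."""
    return [model(perturb(x, m)) for m in masks]


def aggregate(masks, scores, num_features):
    """Aggregator: average score over the premises that preserved each feature."""
    heatmap = []
    for j in range(num_features):
        kept = [s for m, s in zip(masks, scores) if m[j] == 0]
        heatmap.append(sum(kept) / len(kept) if kept else 0.0)
    return heatmap


def toy_model(x):
    # Hypothetical model whose output depends only on feature 2.
    return 3.0 * x[2]


x = [1.0, 1.0, 1.0, 1.0]
masks = explore(num_premises=200, num_features=4, p_preserve=0.5)
scores = attribute(toy_model, x, masks)
heatmap = aggregate(masks, scores, num_features=4)
```

In this toy setting, feature 2 receives the highest aggregated score because the model output drops to zero whenever that feature is perturbed.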
Classes
TimeStepPremise
Bases: Premise
Premise for temporal perturbations in time series data.
Creates masks that perturb time series data by marking the current time step for perturbation (mask=1). Future values are set to NaN to enable various padding strategies. Useful for explaining sequential models where temporal dependencies are important.
A timestep premise object.
Attributes:
-
key(tuple) –Holds the timestep and feature indices together with the mask shape, which is enough to generate the corresponding mask. Expected structure: ({"timestep": int, "feature": int}, (num_features, sequence_length))
Source code in muppet/components/memory/premise.py
Functions
Generates a mask for a specific feature at a specific timestep.
The mask is a 2D tensor with an added batch dimension.
- The value at mask[feature, time_step] is set to 1.
- All values for timesteps greater than time_step are set to torch.nan.
- All other values are 0.
The key attribute is expected to be a tuple:
- self.key[0] (dict): A dictionary containing 'timestep' and 'feature' indices.
- self.key[1] (tuple): The shape of the mask (e.g., (number_of_features, sequence_length)).
Returns:
-
Tensor–torch.Tensor: A 3D tensor representing the mask, with shape (1, num_features, sequence_length). The first dimension is the batch size (1).
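The rules above can be traced with a small dependency-free sketch. Nested Python lists stand in for the tensor, and `timestep_mask` is a hypothetical helper name; only the key structure and the 0 / 1 / NaN rules come from the description above.

```python
import math


def timestep_mask(key):
    """Build the mask described above as nested lists (batch, feature, time)."""
    indices, (num_features, sequence_length) = key
    feature, time_step = indices["feature"], indices["timestep"]
    mask = [[0.0] * sequence_length for _ in range(num_features)]
    for row in mask:
        # Every feature is NaN for timesteps after the selected one.
        for t in range(time_step + 1, sequence_length):
            row[t] = math.nan
    mask[feature][time_step] = 1.0
    return [mask]  # leading batch dimension of size 1


key = ({"timestep": 2, "feature": 0}, (2, 5))
mask = timestep_mask(key)
# mask[0][0] -> [0.0, 0.0, 1.0, nan, nan]
# mask[0][1] -> [0.0, 0.0, 0.0, nan, nan]
```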
BinaryRandomPremise
Bases: Premise
Premise for generating random binary perturbation masks.
Creates binary masks with random patterns based on specified probability distributions. Supports both image and tabular data modalities with appropriate dimension handling.
Creates random binary masks with a probability \(p\) of not perturbing the corresponding pixel. When \(p\) is close to 1, the mask contains mostly 0s and most of the input example is preserved; when \(p\) is close to 0, the mask contains mostly 1s and most of the input is perturbed.
Mask Convention
- 0: Preserve the input value
- 1: Perturb the input value
Parameters:
-
key(tuple) –Holds enough information in order to generate the mask.
-
modality(bool, default:'image') –Whether to unsqueeze one more dimension at the beginning. Defaults to True.
Interpretation
- key[0]: Mask dimension before up-sampling.
- key[1]: Probability \(p\) of preserving the input pixels; each binary mask entry is 0 with probability \(p\).
- key[2]: The final mask shape (w, h) after up-scaling, matching the input example without batch or channel dimensions.
Functions
Create the premise's random mask.
Returns:
-
Tensor–torch.Tensor: The final random mask of shape (b=1, c=1, w, h)
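A dependency-free sketch of this key interpretation follows. It assumes that key[0] is a (small_w, small_h) pair and uses nearest-neighbour up-sampling; both are assumptions, as are the function name and the `seed` argument, and the real implementation may interpolate differently.

```python
import random


def random_binary_mask(key, seed=0):
    """Draw a coarse 0/1 grid and up-sample it by nearest neighbour."""
    (small_w, small_h), p_preserve, (w, h) = key
    rng = random.Random(seed)
    # Each coarse cell is 0 (preserve) with probability p_preserve, else 1.
    coarse = [
        [0 if rng.random() < p_preserve else 1 for _ in range(small_h)]
        for _ in range(small_w)
    ]
    # Nearest-neighbour up-sampling to the final (w, h) resolution.
    fine = [
        [coarse[i * small_w // w][j * small_h // h] for j in range(h)]
        for i in range(w)
    ]
    return [[fine]]  # nested lists of shape (b=1, c=1, w, h)


mask = random_binary_mask(((2, 2), 0.5, (4, 4)))
all_kept = random_binary_mask(((2, 2), 1.0, (4, 4)))       # p = 1: all zeros
all_perturbed = random_binary_mask(((2, 2), 0.0, (4, 4)))  # p = 0: all ones
```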
SegmentedBinaryImagePremise
Bases: Premise
Premise for segment-based binary masks on images.
Generates binary masks based on image segmentation, where entire segments (superpixels) are either preserved or perturbed together. Provides more meaningful perturbations for image explanations.
Creates random binary masks with a probability \(p\) of keeping the corresponding pixel. When \(p\) is set to 1, everything in the input example is masked; when it is set to 0, everything in the input is kept.
Mask Convention
- 0: Preserve the input value
- 1: Perturb the input value
Parameters:
-
key(Tensor) –Holds enough information in order to generate the mask.
-
segmented_example(Tensor) –The segmented tensor of the explained example, of shape (s, *example.shape).
-
modality(bool, default:'image') –Modality of the data.
Functions
Creates the premise's random binary mask.
Returns:
-
Tensor–torch.Tensor: The final random mask of shape (b=1, c=1, h, w)
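Segment-wise masking can be sketched as follows. The sketch assumes that the key is a per-segment 0/1 vector and that `segmented_example` holds one binary indicator map per segment; neither detail is confirmed by the documentation above, and `segment_mask` is a hypothetical helper operating on nested lists rather than tensors.

```python
def segment_mask(key, segmented_example):
    """Combine per-segment indicator maps into one pixel-level mask."""
    height = len(segmented_example[0])
    width = len(segmented_example[0][0])
    mask = [[0] * width for _ in range(height)]
    for perturb_segment, segment_map in zip(key, segmented_example):
        if perturb_segment != 1:
            continue  # segment preserved: its pixels stay 0
        for r in range(height):
            for c in range(width):
                if segment_map[r][c] == 1:
                    mask[r][c] = 1
    return [[mask]]  # nested lists of shape (b=1, c=1, h, w)


# Two segments on a 2x4 image: left half and right half.
segments = [
    [[1, 1, 0, 0], [1, 1, 0, 0]],
    [[0, 0, 1, 1], [0, 0, 1, 1]],
]
mask = segment_mask(key=[0, 1], segmented_example=segments)
print(mask[0][0])  # [[0, 0, 1, 1], [0, 0, 1, 1]]
```

All pixels of a segment share the same mask value, which is what makes the perturbation segment-wise rather than pixel-wise.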
GradientPremise
Bases: Premise
Premise for gradient-based trainable masks.
Provides learnable mask parameters that can be optimized through gradient descent. Useful for iterative mask refinement methods that learn optimal perturbation patterns.
Provides a trainable mask.
Mask Convention
- 0: Preserve the input value
- 1: Perturb the input value
Parameters:
-
key(Tensor) –The premise's key. The key is expected to have gradients enabled.
-
upscaled_mask_shape(tuple) –The dimension to which the key is upscaled when transformed into a mask. 2-dimensional.
Interpretation
- key: The down-scaled trainable mask, of shape (1, small_w, small_h).
- mask: The key up-scaled to its final shape (= x.shape): (1, small_w, small_h) => (1, w, h).
Attributes
Functions
Upscale the key's mask from (1, small_w, small_h) to the shape (1, w, h), where w and h are the width and height of x, respectively (for example, w=224 and h=224).
Returns:
-
Tensor–torch.Tensor: The up-scaled mask
ConvolutionalFeaturePremise
Bases: Premise
Premise for creating masks from convolutional feature maps.
Generates discrete masks based on convolutional neural network feature activations. Each feature map is converted into a binary mask for targeted perturbations of specific learned features.
Responsible for creating discrete masks from convolutional features of the example.
Mask Convention
- 0: Preserve the input value
- 1: Perturb the input value
Parameters:
-
key(tuple) –Holds enough information in order to generate the mask.
Interpretation
- key[0]: Up-sampled activation of the convolutional layer (torch.Tensor).
- key[1]: Channel (int).
Functions
Create the premise's heatmap from convolutional features.
Extracts a specific channel from the upsampled activations to create a visual explanation heatmap.
Returns:
-
Tensor–torch.Tensor: The up-sampled activation of shape (b=1, c=1, w, h) as a heatmap.
Creates the premise's mask.
Returns:
-
Tensor–torch.Tensor: The final feature mask of shape (b=1, c=1, w, h)
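The heatmap and mask steps can be sketched on nested lists. The key layout (activations, channel) follows the interpretation above, but the discretization rule is an assumption: thresholding at the channel mean and marking above-mean activations as 1 is chosen purely for illustration, and both function names are hypothetical.

```python
def channel_heatmap(key):
    """Select one channel of the (already up-sampled) activations."""
    activations, channel = key
    return activations[channel]


def channel_mask(key):
    """Binarize the selected channel (assumed rule: above-mean -> 1)."""
    heatmap = channel_heatmap(key)
    flat = [v for row in heatmap for v in row]
    threshold = sum(flat) / len(flat)
    binary = [[1 if v > threshold else 0 for v in row] for row in heatmap]
    return [[binary]]  # nested lists of shape (b=1, c=1, w, h)


# Two 2x2 feature maps; channel 1 is selected by the key.
activations = [
    [[0.0, 0.0], [0.0, 0.0]],
    [[0.1, 0.9], [0.8, 0.2]],
]
mask = channel_mask((activations, 1))
print(mask[0][0])  # [[0, 1], [1, 0]]
```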
FeaturesCombinationPremise
Bases: GradientPremise
Premise for linear combinations of convolutional features.
Creates masks from weighted combinations of multiple convolutional feature maps. Allows for more complex perturbation strategies by combining different learned feature representations.
Responsible for creating discrete masks from a linear combination of convolutional features of the example.
Mask Convention
- 0: Preserve the input value
- 1: Perturb the input value
Parameters:
-
key(Tensor) –linear combination coefficients.
-
activations(Tensor) –activation of the last convolutional layer of the model.
-
upscaled_mask_shape(tuple[int]) –tuple of the form (w, h) with w the width and h the height of the input.
Functions
Creates the premise's mask by linearly combining the previously acquired feature maps, using the premise's key as the combination coefficients.
Recall that the key is the vector that is being optimized by the Explorer.
Returns:
-
Tensor–torch.Tensor: The final feature mask of shape (b=1, c=1, w, h)
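The linear combination can be sketched on nested lists as a weighted sum of feature maps. The min-max normalization to [0, 1] is an assumption added for illustration, the up-scaling to `upscaled_mask_shape` is omitted, and `combine_feature_maps` is a hypothetical helper rather than the library's method.

```python
def combine_feature_maps(key, activations):
    """Weighted sum of feature maps, then min-max scaling to [0, 1]."""
    height, width = len(activations[0]), len(activations[0][0])
    combined = [[0.0] * width for _ in range(height)]
    for weight, feature_map in zip(key, activations):
        for r in range(height):
            for c in range(width):
                combined[r][c] += weight * feature_map[r][c]
    flat = [v for row in combined for v in row]
    low, high = min(flat), max(flat)
    span = (high - low) or 1.0  # avoid division by zero for constant maps
    normalized = [[(v - low) / span for v in row] for row in combined]
    return [[normalized]]  # nested lists of shape (b=1, c=1, w, h)


activations = [
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0], [0.0, 1.0]],
]
key = [2.0, 4.0]  # combination coefficients, one per feature map
mask = combine_feature_maps(key, activations)
print(mask[0][0])  # [[0.5, 0.0], [0.0, 1.0]]
```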
KeyBasedMaskPremise
Bases: Premise
Premise for generating masks based on tensor keys.
Creates perturbation masks using tensor keys as the basis for random mask generation. Provides reproducible mask patterns based on the key values and optional random seed.
Create the premise's random mask based on the key.
Mask Convention
- 0: Preserve the input value
- 1: Perturb the input value
Parameters:
-
key(Tensor) –Tensor serving as the base for generating masks.
-
seed(int) –Seed for the random number generator.
Functions
Create the premise's random mask based on the key.
Returns:
-
Tensor–torch.Tensor: The final random mask, which is a copy of the key.