Base Explainer Logic

muppet.explainers.base

Base explainer module for the MUPPET XAI framework.

This module provides the fundamental MuppetExplainer class that serves as the foundation for all explainable AI (XAI) methods in the MUPPET library. The base explainer implements the core four-step perturbation-based workflow that all MUPPET explainers follow:

  1. Explorer: Generates masks that define the exploration strategy for perturbation
  2. Perturbator: Applies perturbations to the input data using the generated masks
  3. Attributor: Calculates attribution scores from model predictions on the perturbed data
  4. Aggregator: Combines the individual attributions into final explanations

The MUPPET framework decomposes XAI methods into these modular components that can be mixed and matched to create different explanation techniques. This modular design enables systematic comparison of XAI methods and facilitates the development of new explainers by combining existing components.
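The four-step loop above can be sketched in a few lines of plain Python. All names here (fake_model, explore, perturb, attribute, aggregate) are illustrative stand-ins, not the MUPPET API:

```python
import random

random.seed(0)

# Toy "model": scores an input vector using only features 0 and 2.
def fake_model(x):
    return x[0] + x[2]

# 1. Explorer: generate random binary masks over the input features.
def explore(n_masks, n_features):
    return [[random.randint(0, 1) for _ in range(n_features)]
            for _ in range(n_masks)]

# 2. Perturbator: apply a mask by zeroing the masked-out features.
def perturb(x, mask):
    return [xi * mi for xi, mi in zip(x, mask)]

# 3. Attributor: attribution score = model prediction on the perturbed input.
def attribute(x_perturbed):
    return fake_model(x_perturbed)

# 4. Aggregator: weighted sum of masks, weighted by their attribution scores.
def aggregate(masks, scores, n_features):
    saliency = [0.0] * n_features
    for mask, score in zip(masks, scores):
        for i, mi in enumerate(mask):
            saliency[i] += mi * score
    return saliency

x = [1.0, 2.0, 3.0, 4.0]
masks = explore(n_masks=200, n_features=4)
scores = [attribute(perturb(x, m)) for m in masks]
explanation = aggregate(masks, scores, n_features=4)
# The features the toy model actually uses (0 and 2) end up with the
# highest saliency weights.
```

This RISE-style weighted-sum sketch shows how the four components compose; MUPPET wraps each step in a dedicated class so the pieces can be swapped independently.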

Classes:

  • MuppetExplainer

    The parent class that all specific explainers must inherit from.

Classes

MuppetExplainer
MuppetExplainer(
    model,
    explorer,
    perturbator,
    attributor,
    aggregator,
    memory=None,
)

Base explainer class that orchestrates the four-component MUPPET framework.

This class serves as the foundation for all explainable AI (XAI) methods in the MUPPET library, coordinating Explorer, Perturbator, Attributor, and Aggregator components to generate explanations for PyTorch models using perturbation-based methods. Supports multimodal data including images, tabular data, and time series.

The base explainer handles device management, component coordination, and the standard explanation workflow. Specific XAI methods like RISE, LIME, and SHAP inherit from this class and customize the four components to implement their respective algorithms.

Example

Creating a custom explainer by combining components:

from muppet.components.explorer.mask import RandomMasksExplorer
from muppet.components.perturbator.simple import SetToZeroPerturbator
from muppet.components.attributor import ClassScoreAttributor
from muppet.components.aggregator.mask import WeightedSumAggregator
from muppet.explainers.base import MuppetExplainer

# Define custom explainer by combining components
class CustomExplainer(MuppetExplainer):
    def __init__(self, model, nmasks=100):
        explorer = RandomMasksExplorer(nmasks=nmasks)
        perturbator = SetToZeroPerturbator()
        attributor = ClassScoreAttributor()
        aggregator = WeightedSumAggregator()
        super().__init__(model, explorer, perturbator, attributor, aggregator)
Note

This base class handles device management, component coordination, and the standard explanation workflow by implementing the main call shared by all explainers. Because this method carries the core logic of the MUPPET XAI framework, it should not be overridden.

Initialize the base MUPPET explainer with the four core components.

Parameters:

  • model (Module) –

    The black-box torch model to be explained.

  • explorer (Explorer) –

    Generates the so-called "premises" that represent the perturbation elements.

  • perturbator (Perturbator) –

    Responsible for perturbing the input example.

  • attributor (Attributor) –

    Responsible for calculating the attribution, which may be the model output directly or any other quantity that will be aggregated into the final explanation.

  • aggregator (Aggregator) –

    Responsible for aggregating the calculated attributions and providing the final explanation.

  • memory (Memory, default: None ) –

    A memory class in which the "premises" are stored, e.g., a tree, list, or set.

Source code in muppet/explainers/base.py
def __init__(
    self,
    model: torch.nn.Module,
    explorer: Explorer,
    perturbator: Perturbator,
    attributor: Attributor,
    aggregator: Aggregator,
    memory: Memory | None = None,
) -> None:
    """Initialize the base MUPPET explainer with the four core components.

    Args:
        model (torch.nn.Module): The black-box torch model to be explained.
        explorer (Explorer): Generates the "premises" that represent the perturbation elements.
        perturbator (Perturbator): Responsible for perturbing the input example.
        attributor (Attributor): Responsible for calculating the attribution, which may be the model output directly
            or any other quantity that will be aggregated into the final explanation.
        aggregator (Aggregator): Responsible for aggregating the calculated attributions and providing the final explanation.
        memory (Memory): A memory class in which the "premises" are stored, e.g., a tree, list, or set.
    """
    # register the custom components
    self.model = model
    self.explorer = explorer
    self.perturbator = perturbator
    self.attributor = attributor
    self.aggregator = aggregator
    if memory is None:
        self.memory = PremiseList()
    else:
        self.memory = memory

    # use the CUDA device if available, otherwise fall back to the CPU
    self.device = DEVICE
    self.model.to(self.device)

    # share used device with all components
    self.explorer.device = self.device
    self.perturbator.device = self.device
    self.attributor.device = self.device
    self.aggregator.device = self.device
    self.memory.device = self.device

    # Prepare a spot for the premise_kwargs dict.
    # This dict can be updated to hold elements passed to Premises at their
    # creation: typically by overriding the __call__ method to compute it
    # before calling the original __call__. It could also have been populated
    # by overriding __init__ to store parameters directly; hence, we only
    # create it if it does not already exist.
    if not hasattr(self, "premise_kwargs"):
        self.premise_kwargs = dict()
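The hasattr guard on the last lines matters when a subclass populates premise_kwargs before delegating to the base initializer. A self-contained sketch of that pattern (the class names and the sigma parameter are illustrative, not part of MUPPET):

```python
class Base:
    def __init__(self):
        # Only create the dict if a subclass has not already set it.
        if not hasattr(self, "premise_kwargs"):
            self.premise_kwargs = dict()

class Child(Base):
    def __init__(self, sigma):
        # Store a premise parameter *before* calling the base initializer...
        self.premise_kwargs = {"sigma": sigma}
        # ...which must therefore not overwrite it.
        super().__init__()

child = Child(sigma=0.5)
```

Without the guard, Base.__init__ would silently discard the subclass's premise parameters.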
Functions
reinitialize
reinitialize()

Return the explainer to its original state.

Source code in muppet/explainers/base.py
def reinitialize(self):
    """Return the explainer to its original state."""
    self.memory.reinitialize()
    self.explorer.reinitialize()
    self.perturbator.reinitialize()
    self.attributor.reinitialize()
    self.aggregator.reinitialize()
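The cascade above assumes every component exposes its own reinitialize() that clears mutable state. A minimal sketch of what a stateful component's reset might look like (illustrative only, not the MUPPET implementation):

```python
class CountingExplorer:
    """Toy stateful component: remembers how many masks it has generated."""

    def __init__(self):
        self.masks_generated = 0

    def generate(self):
        # Produce a (fixed) mask and record that we did so.
        self.masks_generated += 1
        return [1, 0, 1]

    def reinitialize(self):
        # Reset mutable state so the explainer can be reused on a new input.
        self.masks_generated = 0

explorer = CountingExplorer()
explorer.generate()
explorer.generate()
explorer.reinitialize()
```

Because the base explainer simply calls reinitialize() on each component in turn, any component following this convention can be reset between explanations.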