
Attributor

muppet.components.attributor.base

Base attributor component for MUPPET XAI.

This module defines the abstract base class for all attributors in the MUPPET XAI framework. Attributors are responsible for calculating attribution scores from model predictions on perturbed inputs. These attributions quantify how much each perturbation affects the model's output and serve as the basis for feature importance calculations.

The attribution process is the third step in the four-step perturbation-based XAI approach: generate masks → apply perturbations → calculate attributions → aggregate results.
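As a rough illustration of the four steps, here is a minimal end-to-end sketch in plain PyTorch. It does not use MUPPET's classes. The toy model, the random binary masks, the zero-filling perturbation and the RISE-style weighted average used as the aggregator are all assumptions chosen for brevity. The attribution step uses the raw class probability (a constructive-style score), so masks that preserve the prediction receive the largest weight.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy classifier: flattens a (1, 1, 4, 4) image into 3 logits
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 3))
x = torch.rand(1, 1, 4, 4)

# 1. Generate masks: N random binary masks (1 = keep pixel, 0 = remove it)
N = 8
masks = (torch.rand(N, 1, 1, 4, 4) > 0.5).float()  # (N, b=1, c, h, w)

# 2. Apply perturbations: zero out the removed pixels
perturbed_inputs = x.unsqueeze(0) * masks  # (N, 1, 1, 4, 4)

# 3. Calculate attributions: probability of the class predicted for x
with torch.no_grad():
    true_class = model(x).argmax(dim=1)  # (1,)
    probs = torch.stack(
        [torch.softmax(model(p), dim=1)[0, true_class[0]] for p in perturbed_inputs]
    )  # (N,)

# 4. Aggregate: average of the masks weighted by their attribution (RISE-style)
heatmap = (probs.view(N, 1, 1, 1, 1) * masks).sum(dim=0) / N  # (1, 1, 4, 4)
```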

Classes:

  • Attributor

    Abstract base class defining the interface for all attributor components.

Classes

Attributor
Attributor()

Bases: ABC

Abstract base class for attributor components in MUPPET XAI.

A global component that defines the 'calculate_attribution' method, which is responsible for filling in each premise's attribution. An attribution can be the model's output or any other quantity that the aggregator uses to build the final heatmap.

Attributes:

  • device

    The device in use. It is updated by the main explainer after initialization.

  • convention

    The attribution convention (constructive or destructive).

  • allowed_conventions

    The allowed convention types from AttributionConvention.

Example

Typical usage involves subclassing the Attributor base class:

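import torch
from muppet.components.attributor.base import Attributor


class CustomAttributor(Attributor):
    def __init__(self, convention="destructive"):
        # Set the convention before calling the base initializer;
        # otherwise it falls back to 'destructive' with a warning.
        self.convention = convention
        super().__init__()

    def calculate_attribution(self, x, perturbed_inputs, model, memory):
        # Score each perturbation and store the result in its premise (in place)
        for idx, premise in enumerate(memory.get_premises()):
            with torch.no_grad():
                predictions = model(perturbed_inputs[idx])
            # Placeholder rule: replace with your own attribution logic
            premise.attribution = predictions.max(dim=1).values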

Initialize the attributor component.

Sets up the default device and attribution convention for the attributor. If no convention is set, defaults to 'destructive' with a warning.

Source code in muppet/components/attributor/base.py
def __init__(self) -> None:
    """Initialize the attributor component.

    Sets up the default device and attribution convention for the attributor.
    If no convention is set, defaults to 'destructive' with a warning.
    """
    self.device = DEVICE

    # Check that the convention has been set, if not set it with warning
    if not hasattr(self, "convention"):
        logger.warning(
            "Convention should be set for the Attributor; defaulting to 'destructive'."
        )
        self.convention = "destructive"

    super().__init__()
Attributes
convention property writable
convention

Get the attribution convention.

Functions
reinitialize
reinitialize()

Reset the attributor to its initial state.

The base implementation does nothing; subclasses override it to clear any internal state or cached data (such as a cached true class) that would otherwise affect subsequent attribution calculations.

Source code in muppet/components/attributor/base.py
def reinitialize(self):
    """Reset the attributor to its initial state.

    This method restores the attributor to its original configuration,
    clearing any internal state or cached data that may affect
    subsequent attribution calculations.
    """
    pass
calculate_attribution abstractmethod
calculate_attribution(x, perturbed_inputs, model, memory)

Calculates the attribution based on the example's perturbations (x'), the model and the original example (x) if needed.

It is generally expected that the most impactful masks will get the highest attribution scores: at the end of the pipeline, most Aggregators use a mask's attribution as a direct proxy for its importance.

Parameters:

  • x (Tensor) –

    The original input example. Shape (b=1, *x.shape[1:]), where b is the batch size and x.shape[1:] are the input data dimensions, e.g. (c, w, h) for images: channels, width and height.

  • perturbed_inputs (Tensor) –

    The perturbations calculated by the Perturbator.

  • model (Module) –

    The black-box model.

  • memory (Memory) –

    The memory structure in use.

Returns:

  • None ( None ) –

    It fills up the memory in place.

Source code in muppet/components/attributor/base.py
@abstractmethod
def calculate_attribution(
    self,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    model: torch.nn.Module,
    memory: Memory,
) -> None:
    """Calculates the attribution based on the example's perturbations (x'), the model and the original example (x) if needed.

    It is generally expected that the most impactful masks will get the highest attribution score: indeed, at the end of the pipeline,
    most Aggregators will use a mask's attribution as a direct proxy for its importance.

    Args:
        x (torch.Tensor): The original input example. Shape (b=1, *x.shape[1:]):
            - b is the batch size
            - x.shape[1:] is the input data dimensions. E.g images (c, w, h) channels, width and height.

        perturbed_inputs (torch.Tensor): The perturbations calculated by the Perturbator.

        model (torch.nn.Module): The black-box model.

        memory (Memory): The used memory structure.

    Returns:
        None: It fills up the memory in place.

    """
    raise NotImplementedError

muppet.components.attributor.classification

Classification-based attributors for MUPPET XAI.

This module provides attribution methods for classification models. These attributors calculate attribution scores based on class probabilities, making them ideal for explaining image classification, text classification, and other discrete classification tasks.

Classes:

  • ClassScoreAttributor

    Calculates attributions from the probability that the model assigns, on each perturbed input, to the class predicted for the unperturbed input, measuring how much each perturbation affects the model's confidence in that prediction. Supports both destructive and constructive attribution conventions.

Technical Details

The ClassScoreAttributor computes attributions by:

  1. Determining the true class from the original input.
  2. Evaluating the model's probability for this class on each perturbation.
  3. Converting probabilities to attribution scores based on the convention:
       • Destructive: higher scores for perturbations that reduce class confidence.
       • Constructive: higher scores for perturbations that maintain class confidence.

This method only needs forward passes (one for the original input and one per perturbation) and applies to any classifier that outputs class scores (logits), regardless of domain or architecture.
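The three steps above can be condensed into a small function. This is a sketch, not the library implementation: it returns the attributions as a tensor instead of writing them into premises, and the toy linear model and inputs are assumptions. As in the source, the model is assumed to return logits, to which softmax is applied.

```python
import torch
import torch.nn.functional as F


def class_score_attributions(x, perturbed_inputs, model, convention="destructive"):
    """Sketch of the ClassScoreAttributor steps. Returns an (N,) tensor instead of
    filling premises in a memory structure (a simplification for illustration)."""
    with torch.no_grad():
        # Step 1: class predicted for the original input, shape (1,)
        true_class = F.softmax(model(x), dim=1).argmax(dim=1)
        scores = []
        for p in perturbed_inputs:  # each p has the shape of x, i.e. (1, ...)
            # Step 2: probability of that class on the perturbed input
            probs = F.softmax(model(p.float()), dim=1)
            scores.append(probs[0, true_class[0]])
    scores = torch.stack(scores)  # (N,)
    # Step 3: the destructive convention negates the probability
    return -scores if convention == "destructive" else scores


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)  # toy classifier returning logits
x = torch.randn(1, 4)
perturbed = torch.stack([x * 0, x, x * 0.5])  # (N=3, 1, 4)
destructive = class_score_attributions(x, perturbed, model)
constructive = class_score_attributions(x, perturbed, model, convention="constructive")
```

The unperturbed copy at index 1 scores the maximum softmax probability, so under the destructive convention it receives the most negative attribution of any perturbation that leaves the prediction untouched.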

Classes

ClassScoreAttributor
ClassScoreAttributor(convention='destructive')

Bases: Attributor

Attribution based on probability score of the true class for classification tasks.

This attributor computes the probability of the true class (the class the model predicts for the original example) on each perturbed input and stores it as the premise's attribution. Since in most cases the most impactful perturbations are expected to have the highest attribution, by default (destructive convention) the attribution is MINUS the probability of the true class.

Attributes:

  • true_class

    True class index determined from the original input.

Initialize the class score attributor.

Parameters:

  • convention (Union[AttributionConvention, str], default: 'destructive' ) –

    The attribution convention, either 'destructive' or 'constructive'.

Source code in muppet/components/attributor/classification.py
def __init__(
    self, convention: Union[AttributionConvention, str] = "destructive"
) -> None:
    """Initialize the class score attributor.

    Args:
        convention: The attribution convention, either 'destructive' or 'constructive'.
    """
    self.true_class = None
    self.convention = convention
    super().__init__()
Functions
reinitialize
reinitialize()

Reinitialize the classification attributor.

Source code in muppet/components/attributor/classification.py
def reinitialize(self):
    """Reinitialize the classification attributor."""
    self.true_class = None
    return super().reinitialize()
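Because `true_class` is computed on the first call and then cached, an attributor reused for a new example has to be reinitialized first; otherwise attributions would be measured against the previous example's class. The sketch below shows the effect with a hypothetical stand-in class (`CachingAttributor`, not part of MUPPET) and an identity "model" whose logits are simply its inputs.

```python
import torch


class CachingAttributor:
    """Hypothetical stand-in mimicking how ClassScoreAttributor caches
    `true_class` on the first call and clears it in `reinitialize()`."""

    def __init__(self):
        self.true_class = None

    def reinitialize(self):
        self.true_class = None

    def predicted_class(self, x, model):
        # Computed only once; later calls reuse the cached value
        if self.true_class is None:
            with torch.no_grad():
                self.true_class = model(x).argmax(dim=1)
        return self.true_class


model = torch.nn.Identity()  # logits are the input values themselves
attributor = CachingAttributor()
first = attributor.predicted_class(torch.tensor([[0.0, 1.0]]), model)  # class 1
stale = attributor.predicted_class(torch.tensor([[1.0, 0.0]]), model)  # still class 1 (cached)
attributor.reinitialize()
fresh = attributor.predicted_class(torch.tensor([[1.0, 0.0]]), model)  # class 0
```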
calculate_attribution
calculate_attribution(x, perturbed_inputs, model, memory)

Calculates the attribution of perturbed inputs.

Parameters:

  • x (Tensor) –

    Example to explain. Shape (b=1, *x.shape[1:]).

  • perturbed_inputs (Tensor) –

    The example's perturbations. Shape (N, *x.shape).

  • model (Module) –

    The black-box model.

  • memory (PremiseList) –

    Premises' memory where to save the attributions.

Where b is the batch size (=1), N is the number of generated masks.

Source code in muppet/components/attributor/classification.py
def calculate_attribution(
    self,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    model: torch.nn.Module,
    memory: PremiseList,
) -> None:
    """Calculates the attribution of perturbed inputs.

    Args:
        x (torch.Tensor): Example to explain. (b=1, *x.shape[1:]).
        perturbed_inputs (torch.Tensor): The example's perturbations. Shape (N, *x.shape).
        model (torch.nn.Module): The black-box model.
        memory (PremiseList): Premises' memory where to save the attributions.

    Where b is the batch size (=1), N is the number of generated masks.

    """
    # calculate the true class if not already done
    if self.true_class is None:
        with torch.no_grad():
            true_output = F.softmax(model(x).detach(), dim=1)
        self.true_class = torch.argmax(true_output, dim=1)  # (b=1)

    # calculate the true class prediction of every perturbation
    for idx, premise in enumerate(memory.get_premises()):
        with torch.no_grad():
            logits = model(perturbed_inputs[idx].float()).detach()

        probs = F.softmax(logits, dim=1)  # (b=1, nclasses)
        premise.attribution = probs[:, self.true_class].squeeze(
            dim=-1
        )  # (b=1, 1) => (b)

        if self.convention == AttributionConvention.DESTRUCTIVE:
            premise.attribution = -premise.attribution

    return

muppet.components.attributor.differentiable

Gradient-based and differentiable attributors for MUPPET XAI.

This module provides attribution methods that leverage gradient-based optimization and differentiable loss functions. These attributors are designed to work with optimization-based exploration techniques, where the attribution serves as a loss that is backpropagated to the perturbation masks.

Classes:

  • DifferentiableAttributor

    Abstract base class for differentiable attributors that use customizable loss functions to compute attributions through backpropagation.

  • MaskRegularizedScoreAttributor

    Concrete implementation that combines classification scores with mask regularization terms (L1 and Total Variation) to find minimal and smooth explanatory masks.

Classes

DifferentiableAttributor
DifferentiableAttributor()

Bases: Attributor

Base class for gradient-based attribution methods.

This class is used alongside gradient-based exploration methods. It loops through the premises to fill up their attributions by calling a customizable loss function. All needed arguments for the loss calculation must be initialized within the child class.

This base class sets up the true class placeholder for gradient-based attribution methods that can benefit from backpropagation and differentiable loss functions.
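To see why the attribution must stay differentiable here, consider a minimal optimization loop in plain PyTorch. It is not MUPPET's exploration code: the toy model, the mask parameterization (1 = keep, 0 = delete), SGD with a learning rate of 0.1 and ten steps are illustrative assumptions. The true class is computed once under `torch.no_grad()`, as in `calculate_attribution`, while the per-step class probability is used directly as the loss.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 3))  # toy classifier
x = torch.rand(1, 1, 4, 4)

# The true class is computed once, without gradients
with torch.no_grad():
    true_class = model(x).argmax(dim=1)  # (1,)

# Mask optimized by gradient descent: 1 keeps a pixel, 0 deletes it
mask = torch.full((1, 1, 4, 4), 0.5, requires_grad=True)
optimizer = torch.optim.SGD([mask], lr=0.1)

for _ in range(10):
    perturbed = x * mask  # differentiable perturbation
    prob = torch.softmax(model(perturbed), dim=1)[0, true_class[0]]
    # The attribution doubles as the loss, so it must NOT be detached:
    # backward() needs the graph from prob back to mask
    loss = prob
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        mask.clamp_(0.0, 1.0)  # keep the mask in [0, 1]
```

Without regularization this objective is trivially minimized by deleting everything, which is what penalty terms such as those of MaskRegularizedScoreAttributor are meant to counter.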

Attributes:

  • true_class

    The true class index calculated once from the original input x.

Initialize the differentiable attributor.

Source code in muppet/components/attributor/differentiable.py
def __init__(self) -> None:
    """Initialize the differentiable attributor."""
    self.true_class = None
    super().__init__()
Functions
reinitialize
reinitialize()

Return DifferentiableAttributor to its original state.

Source code in muppet/components/attributor/differentiable.py
def reinitialize(self):
    """Return DifferentiableAttributor to its original state."""
    self.true_class = None
calculate_attribution
calculate_attribution(x, perturbed_inputs, model, memory)

Calculates the loss of an objective function defined by calculate_attribution_loss method.

Parameters:

  • x (Tensor) –

    Example to explain. Shape (1, *x.shape[1:]), e.g. (b=1, c, w, h) for images, where b is the number of input examples, c the number of channels, w the width and h the height.

  • perturbed_inputs (Tensor) –

    Perturbed versions of the example. Shape (N, *x.shape), e.g. (N, b, c, w, h), where N is the number of perturbations applied to the example.

  • model (Module) –

    The black-box model.

  • memory (FlatList) –

    Structure holding the premises where attributions will be saved.

Source code in muppet/components/attributor/differentiable.py
def calculate_attribution(
    self,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    model: torch.nn.Module,
    memory: Memory,
) -> None:
    """Calculates the loss of an objective function defined by `calculate_attribution_loss` method.

    Args:
        x (torch.Tensor): Example to explain. Shape (1, *x.shape[1:]) E.g (b=1, c, w, h) for images
            - b is number of input examples,
            - c is the channel dimensions,
            - w is the width,
            - h is the height,
        perturbed_inputs (torch.Tensor): Perturbed versions of the example. Shape (N, x.shape) E.g (N, b, c, w, h)
            - N is the number of applied perturbations on the example.
        model (torch.nn.Module): The black-box model.
        memory (FlatList): Structure holding the premises where attributions will be saved.
    """
    # get premises from memory
    premises = memory.get_premises()

    # calculate the true class once
    if self.true_class is None:
        with torch.no_grad():
            true_output = torch.nn.Softmax(dim=1)(model(x)).to(
                self.device
            )  # (b, nclasses)

        self.true_class = torch.argmax(true_output, dim=1)  # (b)

    probs = torch.nn.Softmax(dim=1)(
        model(perturbed_inputs[:, 0, :])  # number of masks per input is 1
    )  # (N, nclasses)

    for i in range(len(premises)):
        # b=1, so a single true class is shared by all premises
        loss = self.calculate_attribution_loss(
            premise=premises[i], output=probs[i, self.true_class[0]]
        )

        # save the loss to premise's attribution
        premises[i].attribution = loss.to(self.device)

    return
calculate_attribution_loss abstractmethod
calculate_attribution_loss(premise, output)

Calculates the optimization loss from a premise and the model's output for the class predicted on the original input example.

Parameters:

  • premise (Premise) –

    The memory element that represents the perturbation.

  • output (Tensor) –

    The model's output for the true class on the corresponding perturbed input.

Raises:

  • NotImplementedError

    Must be implemented in child classes.

Source code in muppet/components/attributor/differentiable.py
@abstractmethod
def calculate_attribution_loss(
    self, premise: Premise, output: torch.Tensor
) -> torch.Tensor:
    """Calculates the optimization loss using premise element and model's output corresponding to the predicted class from original input example.

    Args:
        premise (Premise): The memory's element that represent the perturbation.

        output (torch.Tensor): The model's output for the corresponding input example.

    Raises:
        NotImplementedError: Must be implemented in child classes.

    """
    raise NotImplementedError
MaskRegularizedScoreAttributor
MaskRegularizedScoreAttributor(
    l1_coeff=0,
    tv_coeff=0,
    tv_beta=0,
    convention="destructive",
)

Bases: DifferentiableAttributor

Regularized mask attribution using L1 and total variation loss.

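This attributor calculates a loss combining a minimal-mask penalty, a total variation denoising term and the true-class probability on the perturbed input: \(\mathcal{L} = \lambda\,\mathrm{mean}(|m|) + \lambda'\,\mathrm{TV}_{\beta}(1-m) + f(x')\), where \(\lambda\) is l1_coeff, \(\lambda'\) is tv_coeff, \(\beta\) is tv_beta and \(f(x')\) is the true-class probability (negated under the constructive convention). By default no regularization is applied on the mask.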
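The loss can be sketched in plain PyTorch as follows. The total variation definition used here (mean of absolute neighbouring differences raised to the power β, over both spatial axes) is a common choice but an assumption, since the library's `_tv_norm` is not shown; `tv_norm` and `regularized_loss` are hypothetical names. Defaults mirror the constructor's (all coefficients 0).

```python
import torch


def tv_norm(m, beta):
    """Assumed TV definition for a (..., H, W) tensor: mean |neighbour difference| ** beta."""
    dh = (m[..., 1:, :] - m[..., :-1, :]).abs().pow(beta).mean()
    dw = (m[..., :, 1:] - m[..., :, :-1]).abs().pow(beta).mean()
    return dh + dw


def regularized_loss(mask, class_prob, l1_coeff=0.0, tv_coeff=0.0, tv_beta=0.0, destructive=True):
    mask_loss = l1_coeff * mask.abs().mean()              # lambda * mean(|m|)
    tv_loss = tv_coeff * tv_norm(1 - mask, tv_beta)        # lambda' * TV_beta(1 - m)
    score = class_prob if destructive else -class_prob     # f(x'), sign set by the convention
    return mask_loss + tv_loss + score


mask = torch.ones(1, 1, 4, 4)  # keep-everything mask: 1 - mask is constant, so TV is 0
prob = torch.tensor(0.3)       # made-up true-class probability
loss = regularized_loss(mask, prob, l1_coeff=1.0, tv_coeff=1.0, tv_beta=2.0)
```

With the default coefficients only the score term remains, which matches the note that no regularization is applied by default.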

Attributes:

  • l1_coeff

    L1 regularization coefficient for mask sparsity.

  • tv_coeff

    Total Variation coefficient for smoothness regularization.

  • tv_beta

    Degree of the Total Variation denoising norm.

  • convention

    The attribution convention (constructive or destructive).

  • true_class

    The true class index calculated once from the original input x.

Initialize the mask regularized score attributor.

Parameters:

  • l1_coeff (float, default: 0 ) –

    L1 regularization coefficient for the mask.

  • tv_coeff (float, default: 0 ) –

    Total variation regularization coefficient for the mask.

  • tv_beta (float, default: 0 ) –

    Beta parameter for total variation calculation.

  • convention (Union[AttributionConvention, str], default: 'destructive' ) –

    Attribution convention, either 'destructive' or 'constructive'.

Source code in muppet/components/attributor/differentiable.py
def __init__(
    self,
    l1_coeff: float = 0,
    tv_coeff: float = 0,
    tv_beta: float = 0,
    convention: Union[AttributionConvention, str] = "destructive",
) -> None:
    """Initialize the mask regularized score attributor.

    Args:
        l1_coeff: L1 regularization coefficient for the mask.
        tv_coeff: Total variation regularization coefficient for the mask.
        tv_beta: Beta parameter for total variation calculation.
        convention: Attribution convention, either 'destructive' or 'constructive'.
    """
    self.l1_coeff = l1_coeff
    self.tv_coeff = tv_coeff
    self.tv_beta = tv_beta

    self.true_class = None
    self.convention = convention
    super().__init__()
Functions
calculate_attribution_loss
calculate_attribution_loss(premise, output)

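Calculates the attribution (loss) as the sum of the weighted mean absolute value of the premise's mask, the weighted TV norm of its complement (1 - mask) and the predicted probability of the true class (negated under the constructive convention).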

Parameters:

  • premise

    The premise element representing the perturbation.

  • output

    The predicted probability of the true class.

Returns:

  • loss ( Tensor ) –

    The calculated loss.

Source code in muppet/components/attributor/differentiable.py
def calculate_attribution_loss(self, premise, output) -> torch.Tensor:
    """Calculates the attribution/loss from the sum of the premise's mask mean, mask's TV norm and probability prediction
    corresponding to the true class.

    Args:
        premise: The premise element representing the perturbation.

        output: The true class predicted probability.

    Returns:
        loss: The calculated loss.

    """
    mask_loss = 0
    tv_denoise_loss = 0

    if self.l1_coeff != 0:
        mask_loss = self.l1_coeff * torch.mean(
            torch.abs(premise.key)
        )  # (b=1)

    if self.tv_coeff != 0:
        tv_denoise_loss = self.tv_coeff * self._tv_norm(
            (1 - premise.key), self.tv_beta
        )  # (b=1)

    if self.convention == AttributionConvention.DESTRUCTIVE:
        score_impact = output
    else:
        score_impact = -output

    return mask_loss + tv_denoise_loss + score_impact

muppet.components.attributor.distribution

Probability distribution-based attributors for MUPPET XAI.

This module provides attribution methods that analyze changes in probability distributions over time, particularly designed for time series and sequential data explanation. These attributors measure how perturbations affect the model's distributional predictions and temporal dynamics.

Classes:

  • ProbaShiftAttributor

    Calculates attributions based on the difference between temporal distribution shifts and perturbation-induced distribution changes, implementing the FIT (Feature Importance in Time) methodology for time series explanation.

Classes

ProbaShiftAttributor
ProbaShiftAttributor(padding)

Bases: Attributor

Attribution based on probability distribution shifts for time series classification.

The distribution-based attributors are especially valuable for understanding sequential models where the temporal evolution of predictions is as important as the final output. They quantify feature importance by analyzing distributional shifts caused by perturbations.

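This attributor works on probability distributions over classes for classification tasks and expects the model to output class probabilities (their logarithm is taken directly). It first computes the FIT score as the difference between \(KL(P(y|X_{0:t}) || P(y|X_{0:t-1}))\) and \(KL(P(y|X_{0:t}) || P(y|X'_{0:t}))\), each summed over all classes, where \(X'_{0:t}\) denotes the input whose selected features are perturbed at time t. The stored attribution is the negation of this score, so it is high when the perturbation changes the model's prediction.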

This attributor calculates feature importance based on probability distribution shifts over temporal sequences, implementing the FIT methodology for time series explanation.
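The two KL terms and the final sign flip can be reproduced on made-up distributions. This sketch uses `torch.nn.functional.kl_div`, the functional form of the `KLDivLoss(reduction="batchmean", log_target=True)` used in the source. With `log_target=True`, `kl_div(input, target)` computes \(KL(\text{target} || \text{input})\) from log-probabilities, so the log of \(P(y|X_{0:t})\) is always passed as `target`. The probability values and the helper name `kl_batchmean` are illustrative.

```python
import torch
import torch.nn.functional as F

eps = 1e-6  # plays the role of `kl_epsilon` in the source: avoids log(0)

p_t = torch.tensor([[0.7, 0.2, 0.1]])       # P(y|X_{0:t}),   shape (b=1, c)
p_t_1 = torch.tensor([[0.4, 0.4, 0.2]])     # P(y|X_{0:t-1}), shape (b=1, c)
p_t_hat = torch.tensor([[[0.6, 0.3, 0.1]],  # P(y|X'_{0:t}) for N=2 Monte Carlo
                        [[0.5, 0.3, 0.2]]]) # samples, shape (N, b, c)


def kl_batchmean(log_q, log_p):
    """KL(p || q) summed over classes and divided by log_q.size(0)."""
    return F.kl_div(log_q, log_p.expand_as(log_q), reduction="batchmean", log_target=True)


log_p_t = torch.log(p_t + eps)
temporal_shift = kl_batchmean(torch.log(p_t_1 + eps), log_p_t)       # KL(p_t || p_{t-1})
unexplained_shift = kl_batchmean(torch.log(p_t_hat + eps), log_p_t)  # mean over N of KL(p_t || p'_t)

fit_score = temporal_shift - unexplained_shift
attribution = -fit_score  # the value stored in each premise
```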

Attributes:

  • outputs

    Cache mapping each timestep t to the model's output P(y|X_{0:t}).

  • padding

    The padding strategy for time series inputs.

  • convention

    The attribution convention (destructive).

Initialize the ProbaShiftAttributor.

Parameters:

  • padding (str) –

    Padding strategy for sequences. Options:
      - "left": zero-pad sequences on the left (common for RNNs).
      - "right": zero-pad sequences on the right.
      - None: no padding, for models that handle variable-length inputs.
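The "left" and "right" strategies restore the truncated prefix X_{0:t} to the original sequence length so that fixed-length models can score it. A sketch using `torch.nn.functional.pad`; the helper name `pad_time` is hypothetical and the library's own `_padd` may differ in its details.

```python
import torch
import torch.nn.functional as F


def pad_time(x, num_pad_values, side="left", value=0.0):
    """Hypothetical helper: pad the last (time) axis of a (..., f, t) tensor."""
    pad = (num_pad_values, 0) if side == "left" else (0, num_pad_values)
    return F.pad(x, pad, mode="constant", value=value)


x = torch.arange(1.0, 7.0).reshape(1, 1, 6)  # (b=1, f=1, t=6)
time_step = 2
xt = x[:, :, : time_step + 1]                # X_{0:t}: the first 3 steps
padded = pad_time(xt, x.shape[2] - (time_step + 1), side="left")
```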

Source code in muppet/components/attributor/distribution.py
def __init__(self, padding: str) -> None:
    """Initialize the ProbaShiftAttributor.

    Args:
        padding: Padding strategy for sequences. Options:
            - "left": Zero-pad sequences on the left (common for RNNs)
            - "right": Zero-pad sequences on the right
            - None: No padding for models handling variable lengths
    """
    self.padding = padding
    self.outputs = {}  # key: t, value: p(y/ x0:t)
    self.convention = "destructive"
    # The attribution computed as the difference explained in the doc string is high
    # when the perturbation impacts the model score on the selected class.
    super().__init__()
Functions
calculate_attribution
calculate_attribution(
    x, perturbed_inputs, model, memory=PremiseList
)

For every premise stored in the memory, computes the FIT score

\[\sum_{\text{over all classes}}KL(P(y|X_{0:t}) || P(y|X_{0:t-1})) - \sum_{\text{over all classes}} KL(P(y|X_{0:t}) || P(y|X'_{0:t}))\]

and stores its negation as the premise's attribution.

Parameters:

  • x (Tensor) –

    The input example to be explained. Shape (b=1, f, t)

  • perturbed_inputs (Tensor) –

    The calculated perturbations by the Perturbator. Shape (N, *x.shape)

  • model (Module) –

    The black-box model.

  • memory (Memory, default: PremiseList ) –

    The simple list memory structure.

Source code in muppet/components/attributor/distribution.py
def calculate_attribution(
    self,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    model: torch.nn.Module,
    memory: Memory = PremiseList,
) -> None:
    """For every premise stored in the memory, fills up its attribution with the negation of

    $$\sum_{\text{over all classes}}KL(P(y/X_{0:t}) || P(y/X_{0:t-1})) - \sum_{\text{over all classes}} KL(P(y/X_{0:t}) || P(y/X'_{0:t}))$$

    Args:
        x (torch.Tensor): The input example to be explained. Shape (b=1, f, t)
        perturbed_inputs (torch.Tensor): The calculated perturbations by the Perturbator. Shape (N, *x.shape)
        model (torch.nn.Module): The black-box model.
        memory (Memory, optional): The simple list memory structure.

    """
    idx_groups = {}
    premises_groups = {}

    for idx, premise in enumerate(memory.get_premises()):
        group_key = premise.key[0]["timestep"], premise.key[0]["feature"]
        idx_groups[group_key] = idx_groups.get(group_key, []) + [idx]
        premises_groups[group_key] = premises_groups.get(group_key, []) + [
            premise
        ]

    for ((time_step, feature), indexes), (_, premises) in zip(
        idx_groups.items(), premises_groups.items()
    ):
        # get original x0:t
        xt = x[:, :, : time_step + 1]  # (b, f, t)
        # get original X0:t-1
        xt_1 = xt[:, :, :-1]
        # get perturbed x at t (0:t-1 values of x, at t:=sampled value, then t+1:end Nans ignore them)
        xt_hat = perturbed_inputs[
            indexes, :, :, : time_step + 1
        ]  # (N, b, f, t)

        # if padding is provided ==> pad
        if self.padding:
            # number of values to pad for xt and xt_hat
            num_pad_values = abs(time_step + 1 - x.shape[2])

            xt = self._padd(x=xt, num_pad_values=num_pad_values, value=0)
            xt_1 = self._padd(
                x=xt_1, num_pad_values=num_pad_values + 1, value=0
            )
            xt_hat = self._padd(
                x=xt_hat, num_pad_values=num_pad_values, value=0
            )

        with torch.no_grad():
            try:
                model_output = model(xt)
            except ValueError as e:
                if "This classifier cannot handle unequal length" in str(
                    e
                ):  # AEON model exception
                    logger.exception(
                        "The model cannot handle unequal length, consider using a padding strategy. "
                        "Note that the FIT explainer might not be the most appropriate for these models. "
                    )
                raise e
            kl_epsilon = 1e-6
            output_xt = self.outputs.get(
                time_step, model_output.detach()
            )  # p(y/ x0:t) (b, c) c:number of output classes
            output_xt_1 = self.outputs.get(
                time_step - 1, model(xt_1).detach()
            )  # p(y/ x0:t-1) (b, c)
            output_xt_log = torch.log(output_xt + kl_epsilon)
            output_xt_1_log = torch.log(output_xt_1 + kl_epsilon)
            # mean here for the monte carlo estimation of $p(y|X_{O..t-1}, X_{S,t})$
            try:
                output_xt_hat = model(xt_hat)
            except TypeError:
                logger.debug(
                    "The prediction model does not handle batched data"
                )
                output_xt_hat = torch.stack(list(map(model, xt_hat)))

            output_xt_hat_log = torch.log(
                output_xt_hat + kl_epsilon
            )  # p(y/ x'0:t)

        # fill the attributions dicts if not already done
        if time_step not in self.outputs.keys():
            self.outputs[time_step] = output_xt

        if (time_step - 1) not in self.outputs.keys():
            self.outputs[time_step - 1] = output_xt_1

        # TODO in terms of logic these KL computations should be moved to the proba aggregator (which should also be renamed).
        # calculate KL(p(xt)||p(y/xt-1)) as term1
        temporal_distribution_shift = torch.nn.KLDivLoss(
            reduction="batchmean", log_target=True
        )(output_xt_1_log, output_xt_log)

        # calculate KL(p(xt)||p(y/x't)) as term2
        unexplained_distribution_shift = torch.nn.KLDivLoss(
            reduction="batchmean", log_target=True
        )(output_xt_hat_log, output_xt_log)

        # fill up the premise's attribution
        FIT_importance_score_for_S = (
            temporal_distribution_shift - unexplained_distribution_shift
        )  # (b)

        # The FIT importance score represents the ability of the complement of
        # the perturbed feature to explain the temporal shift of the prediction;
        # by taking the opposite we obtain the importance score for the perturbed feature.
        # FIT score > 0: the complement of the {perturbed feature} explains the temporal shift
        # FIT score = 0: the {all_features} - {perturbed feature} explains the temporal shift
        # FIT score < 0: the all_features - {perturbed feature} explains the temporal shift
        for premise in premises:
            # Apply the distributional aggregated score to all premises for the given (feature, timestep) couple
            premise.attribution = -FIT_importance_score_for_S

    return

muppet.components.attributor.embedding

Embedding-based attributors for MUPPET XAI.

This module provides attribution methods that work with embedding spaces and latent representations. These attributors are designed for models that output vector embeddings rather than discrete classifications, making them ideal for explaining representation learning models, autoencoders, and embedding-based systems.

Classes:

  • EmbeddingDistanceAttributor

    Calculates attributions based on the L2 distance between original and perturbed embeddings, measuring how much each perturbation changes the model's internal representation of the input.

  • DiceScoreAttributor

    Specialized attributor for semantic segmentation tasks that uses the Dice coefficient, instead of the L2 distance between embeddings, to measure changes in segmentation quality caused by perturbations.

Classes

EmbeddingDistanceAttributor
EmbeddingDistanceAttributor()

Bases: Attributor

Attribution based on distance between original and perturbed embeddings.

A perturbation's value is equal to how much it changed the original embedding (i.e., original model output). The end goal is to find perturbations that make the perturbed embedding as far away from the original embedding as possible.

This attributor measures how perturbations affect the model's embedding representations by calculating L2 distances between original and perturbed embeddings. The model output is expected to have shape (batch, *embedding_dims).

The EmbeddingDistanceAttributor computes attributions by:

  1. Reference computation: computing the original input's embedding E₀ = model(x).
  2. Distance measurement: for each perturbation xᵢ, computing Eᵢ = model(xᵢ).
  3. Attribution scoring: computing the distance d(E₀, Eᵢ) = ||E₀ - Eᵢ||₂.
  4. Sign adjustment: applying a negative sign to maximize the distance (destructive convention).

The L2 distance provides a natural measure of representation change:

Attribution = -||embedding_original - embedding_perturbed||₂

The method works with any model that outputs continuous vector representations, regardless of the embedding dimensionality or architecture (CNNs, transformers, autoencoders, etc.).
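A minimal sketch of the distance-based attribution, assuming (as the description states) that the distance is the Euclidean norm of the difference between flattened embeddings; the library's `_compo_similarity` helper is not shown, and `embedding_attributions` is a hypothetical name. The perturbed embeddings stay attached to the autograd graph, mirroring the source's note not to detach them. Note the sign: a larger change in the embedding gives a more negative attribution.

```python
import torch


def embedding_attributions(model, x, perturbed_inputs):
    """Sketch: attribution_i = -||model(x) - model(x_i)||_2 over flattened embeddings."""
    with torch.no_grad():
        reference = model(x).reshape(1, -1)  # E_0, shape (1, D)
    attributions = []
    for p in perturbed_inputs:  # each p has the shape of x
        emb = model(p).reshape(1, -1)  # E_i, left differentiable on purpose
        attributions.append(-torch.linalg.vector_norm(emb - reference, dim=1))
    return torch.cat(attributions)  # (N,)


x0 = torch.zeros(1, 2, 3)                  # (1, nb_rows, embedding_dim)
perturbed0 = torch.stack([x0, x0 + 1])     # (N=2, 1, 2, 3)
attr = embedding_attributions(torch.nn.Identity(), x0, perturbed0)
```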

Attributes:

  • input_embedding

    Stores the true embedding output for comparison.

Initialize the EmbeddingDistanceAttributor.

Source code in muppet/components/attributor/embedding.py
def __init__(self) -> None:
    """Initialize the EmbeddingDistanceAttributor."""
    self.convention = "destructive"

    self.input_embedding = None  # Will be used to store true output later
    super().__init__()
Functions
calculate_attribution
calculate_attribution(x, perturbed_inputs, model, memory)

Calculate, as the attribution, the L2 distance between the embedding of x and the embedding of each of its perturbations. Note that x and each element of perturbed_inputs are expected to have shape (1, nb_rows, embedding_dim), with 1 for the batch dimension.

Parameters:

  • x (Tensor) –

    The input example to be explained.

  • perturbed_inputs (Tensor) –

    The calculated perturbations by the Perturbator.

  • model (Module) –

    The black-box model.

  • memory (Memory) –

    The simple list memory structure.

Source code in muppet/components/attributor/embedding.py
def calculate_attribution(
    self,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    model: torch.nn.Module,
    memory: Memory,
) -> None:
    """Calculate the L2 distance as the attribution between x and its perturbations.
    Note that x and each element of perturbed_inputs are expected to have shape (1, nb_rows, embedding_dim), with 1 for the batch dimension.

    Args:
        x (torch.Tensor): The input example to be explained.

        perturbed_inputs (torch.Tensor): The calculated perturbations by the Perturbator.

        model (torch.nn.Module): The black-box model.

        memory (Memory): The simple list memory structure.

    """
    # Calculate the original example's embedding if not already done
    if self.input_embedding is None:
        with torch.no_grad():
            self.input_embedding = model(x).detach()

    # For each premise we currently focus on in this step,
    # save the distance btwn example and its perturbed version
    for idx, premise in enumerate(memory.get_premises()):
        input_reshaped = perturbed_inputs[
            idx
        ].float()  # perturbed_inputs (N, *x.shape)

        embedding = model(
            input_reshaped
        )  # .detach() # Careful not to detach

        dist = self.similarity(embedding, self.input_embedding)
        premise.attribution = dist

    return
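A minimal standalone sketch of the loop above, under two assumptions: the model maps a (1, nb_rows, embedding_dim) input to an embedding tensor, and the distance is a plain Euclidean norm over the flattened embedding (the class itself delegates to its `similarity` method and a `_compo_similarity` helper that is not shown here). The function name `l2_embedding_attributions` is hypothetical. As in the source, gradients are disabled only for the reference embedding.

```python
import torch
from torch import nn


def l2_embedding_attributions(
    model: nn.Module, x: torch.Tensor, perturbed_inputs: torch.Tensor
) -> list[torch.Tensor]:
    """Hypothetical helper: L2 distance between the embedding of x and of each perturbation.

    x: (1, nb_rows, embedding_dim); perturbed_inputs: (N, 1, nb_rows, embedding_dim).
    """
    # The reference embedding is computed once, without gradients
    with torch.no_grad():
        reference = model(x).detach()

    attributions = []
    for idx in range(perturbed_inputs.shape[0]):
        # Each slice keeps the batch dimension of x: (1, nb_rows, embedding_dim)
        embedding = model(perturbed_inputs[idx].float())
        # Flatten all non-batch dimensions, then take the row-wise L2 norm -> (1,)
        diff = (embedding - reference).reshape(embedding.shape[0], -1)
        attributions.append(torch.linalg.vector_norm(diff, dim=1))
    return attributions


# Toy check: an identity "model", one unchanged and one fully changed perturbation
model = nn.Identity()
x = torch.zeros(1, 2, 3)
perturbed = torch.stack([torch.zeros(1, 2, 3), torch.ones(1, 2, 3)])
print(l2_embedding_attributions(model, x, perturbed))
```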
similarity
similarity(embedding, true_embedding)

Calculate similarity between perturbed and original embeddings.

Computes the negative L2 distance between embeddings to maximize the distance (higher score for more different embeddings).

Parameters:

  • embedding (Tensor) –

    The perturbed embedding.

  • true_embedding (Tensor) –

    The original input embedding.

Returns:

  • torch.Tensor: Negative L2 distance (higher values indicate more difference).
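The sign convention can be checked on toy tensors. This sketch assumes the distance term is a plain row-wise Euclidean norm on the flattened embeddings; the actual `_compo_similarity` helper is not shown in this section, so treat that as an assumption. `negative_l2` is a hypothetical name.

```python
import torch


def negative_l2(embedding: torch.Tensor, true_embedding: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for `similarity`: negated row-wise L2 distance."""
    assert embedding.shape == true_embedding.shape, "shapes must match"
    # Flatten all non-batch dimensions: (b, ...) -> (b, D)
    flat = embedding.reshape(embedding.shape[0], -1)
    flat_true = true_embedding.reshape(true_embedding.shape[0], -1)
    distance = torch.linalg.vector_norm(flat - flat_true, dim=1)  # (b,)
    return -distance


ref = torch.zeros(1, 2)
print(negative_l2(torch.tensor([[3.0, 4.0]]), ref))  # tensor([-5.])
print(negative_l2(torch.tensor([[0.0, 1.0]]), ref))  # tensor([-1.])
```

Under this assumption, the embedding that moved farther from the reference (distance 5) receives the lower score (-5).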

Source code in muppet/components/attributor/embedding.py
def similarity(self, embedding, true_embedding):
    """Calculate similarity between perturbed and original embeddings.

    Computes the negative L2 distance between embeddings to maximize
    the distance (higher score for more different embeddings).

    Args:
        embedding (torch.Tensor): The perturbed embedding.
        true_embedding (torch.Tensor): The original input embedding.

    Returns:
        torch.Tensor: Negative L2 distance (higher values indicate more difference).
    """
    assert true_embedding.shape == embedding.shape, (
        f"True embedding and embedding must have the same shape. Instead we got {true_embedding.shape} and {embedding.shape}"
    )

    inaccuracy_term = self._compo_similarity(
        embedding.reshape(embedding.shape[0], -1),
        true_embedding.reshape(true_embedding.shape[0], -1),  # (b, 1)
    )

    final_dist = (
        -1 * inaccuracy_term
    )  # We want to MAXIMIZE the distance, so multiply inaccuracy term by -1 !
    return final_dist
DiceScoreAttributor
DiceScoreAttributor()

Bases: Attributor

Attribution based on Dice score between probability distributions.

This attributor compares the segmentation predicted for a perturbed input with the segmentation predicted for the original example. Both predictions are turned into one-hot label maps (argmax over the class dimension), and the Dice score, which measures the overlap between two masks, is computed for each class.

It is designed for segmentation tasks: the attribution is 1 minus the mean Dice score over all classes except class 0 (typically the background), so it grows as a perturbation changes the predicted segmentation.

Attributes:

  • true_class

    The true class index calculated from the original input.

Initialize the DiceScoreAttributor.

Inherits from Attributor; true_class is initialized to None.

Source code in muppet/components/attributor/embedding.py
def __init__(self) -> None:
    """Initialize the DiceScoreAttributor.

    Inherits from `Attributor`; `true_class` is initialized to None.
    """
    self.true_class = None
    super().__init__()
Functions
reinitialize
reinitialize()

Reset the attributor to its initial state.

Clears the cached true class to ensure fresh calculations for new inputs.

Source code in muppet/components/attributor/embedding.py
def reinitialize(self):
    """Reset the attributor to its initial state.

    Clears the cached true class to ensure fresh calculations
    for new inputs.
    """
    self.true_class = None
    return super().reinitialize()
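Why `reinitialize` matters can be shown with a minimal caching sketch (the class `CachedReference` and its attributes are hypothetical): once a reference output is cached, it is reused on every later call until it is cleared, so forgetting to reset between inputs silently compares perturbations of a new input against the old reference.

```python
import torch
from torch import nn


class CachedReference:
    """Hypothetical minimal class that caches the model's prediction on the original input."""

    def __init__(self) -> None:
        self.reference = None
        self.calls = 0  # counts reference computations, for illustration only

    def get_reference(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor:
        # Computed once, then reused for every later batch of perturbations
        if self.reference is None:
            with torch.no_grad():
                self.reference = model(x).argmax(dim=1)
            self.calls += 1
        return self.reference

    def reinitialize(self) -> None:
        # Must be called before explaining a new input
        self.reference = None


model = nn.Identity()  # toy "model": the input already plays the role of logits
cache = CachedReference()
print(cache.get_reference(model, torch.tensor([[0.0, 1.0]])))  # tensor([1])
print(cache.get_reference(model, torch.tensor([[1.0, 0.0]])))  # tensor([1]), stale
cache.reinitialize()
print(cache.get_reference(model, torch.tensor([[1.0, 0.0]])))  # tensor([0])
```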
calculate_attribution
calculate_attribution(x, perturbed_inputs, model, memory)

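Calculates the attribution of each perturbed input: 1 minus the mean Dice score, over all classes except class 0, between its predicted segmentation and that of the original input.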

Parameters:

  • x (Tensor) –

    Example input. Shape: (1, C, H, W)

  • perturbed_inputs (Tensor) –

    Perturbed inputs. Shape: (N, 1, C, H, W)

  • model (Module) –

    Model to explain.

  • memory (PremiseList) –

    Memory structure to attach attributions to.

Source code in muppet/components/attributor/embedding.py
def calculate_attribution(
    self,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    model: torch.nn.Module,
    memory: PremiseList,
) -> None:
    """Calculates the attribution (1 - mean Dice score over classes except 0) for perturbed inputs.

    Args:
        x (torch.Tensor): Example input. Shape: (1, C, H, W)
        perturbed_inputs (torch.Tensor): Perturbed inputs. Shape: (N, 1, C, H, W)
        model (torch.nn.Module): Model to explain.
        memory (PremiseList): Memory structure to attach attributions to.
    """
    if self.true_class is None:
        with torch.no_grad():
            reference_output = model(x)  # (1, C, H, W)
        # Cache the reference segmentation so later calls reuse it
        self.true_class = torch.argmax(reference_output, dim=1)  # (1, H, W)

    for idx, premise in enumerate(memory.get_premises()):
        with torch.no_grad():
            logits = model(perturbed_inputs[idx].float())  # (1, C, H, W)
        num_classes = logits.shape[1]
        pred_class = torch.argmax(logits, dim=1)  # (1, H, W)
        pred_one_hot = (
            F.one_hot(pred_class, num_classes=num_classes)
            .permute(0, 3, 1, 2)
            .float()
        )  # (1, C, H, W)
        true_one_hot = (
            F.one_hot(self.true_class, num_classes=num_classes)
            .permute(0, 3, 1, 2)
            .float()
        )  # (1, C, H, W)

        intersection = (true_one_hot * pred_one_hot).sum(
            dim=[0, 2, 3]
        )  # sum over batch & spatial dims => (C,)
        union = true_one_hot.sum(dim=[0, 2, 3]) + pred_one_hot.sum(
            dim=[0, 2, 3]
        )  # (C,)

        perturbation_importance = (
            2 * (intersection) / (union + 10e-12)
        )  # per-class Dice score, (C,)
        attribution = (
            1 - perturbation_importance[1:].mean()
        )  # average over classes, excluding class 0

        premise.attribution = attribution.unsqueeze(0)

    return
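A standalone sketch of the Dice attribution, assuming the reference and perturbed predictions are already argmax label maps of shape (1, H, W). `dice_attribution` is a hypothetical name, and the epsilon is written as 1e-12 rather than the 10e-12 used in the source.

```python
import torch
import torch.nn.functional as F


def dice_attribution(
    true_labels: torch.Tensor, pred_labels: torch.Tensor, num_classes: int
) -> torch.Tensor:
    """1 - mean Dice over classes 1..C-1, for (1, H, W) integer label maps."""
    # One-hot encode: (1, H, W) -> (1, H, W, C) -> (1, C, H, W)
    true_oh = F.one_hot(true_labels, num_classes).permute(0, 3, 1, 2).float()
    pred_oh = F.one_hot(pred_labels, num_classes).permute(0, 3, 1, 2).float()
    # Per-class overlap and total mask size, summed over batch and spatial dims -> (C,)
    intersection = (true_oh * pred_oh).sum(dim=(0, 2, 3))
    union = true_oh.sum(dim=(0, 2, 3)) + pred_oh.sum(dim=(0, 2, 3))
    dice = 2 * intersection / (union + 1e-12)
    # Skip class 0 and turn overlap into a "change" score
    return 1 - dice[1:].mean()


reference = torch.tensor([[[0, 1], [2, 1]]])  # (1, 2, 2) label map
unchanged = reference.clone()
all_background = torch.zeros_like(reference)
print(dice_attribution(reference, unchanged, num_classes=3))
print(dice_attribution(reference, all_background, num_classes=3))
```

One consequence worth knowing: a class absent from both maps has intersection and union 0, so it contributes a Dice of 0 and pulls the mean toward an attribution of 1. The source code above behaves the same way.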

muppet.components.attributor.similarity

Similarity-based attributors for MUPPET XAI.

This module provides attribution methods that incorporate similarity measures between original and perturbed inputs. These attributors are essential for local explanation methods like LIME and SHAP, where the importance of perturbations is weighted by their similarity to the original input.

Classes:

  • SimilarityAttributor

    Generic attributor that combines model predictions with configurable similarity functions for flexible local explanation methods.

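Functions:

  • lime_similarity

    Example LIME similarity function using a Gaussian kernel; computes similarities for all perturbations at once.

  • kernel_shap_similarity

    Calculates similarity based on kernel SHAP.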

Classes

SimilarityAttributor
SimilarityAttributor(similarity_fun)

Bases: Attributor

Attribution based on similarity measures between original and perturbed inputs.

This attributor pairs the model's response to each perturbation with a similarity weight computed by a user-provided similarity function. This makes it suitable for LIME-style explanations, where samples are weighted by their distance from the original input. The bundled functions (lime_similarity and kernel_shap_similarity) return higher values for perturbations closer to the original input.

Similarity-based attribution combines two key components:

  1. Model response: how the model's prediction changes with the perturbation.
  2. Input similarity: how similar the perturbation is to the original input.

The SimilarityAttributor stores both values:

premise.attribution = {
    "attribution": model_prediction_change,
    "similarity": similarity_score
}

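LIME Similarity: Uses a Gaussian kernel on the Euclidean distance d between the original and perturbed inputs, min-max normalized to [0, 1] across the perturbations, with kernel width σ (0.25 in lime_similarity):

similarity = sqrt(exp(-d²/σ²)) = exp(-d²/(2σ²))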

SHAP Kernel: Weights each perturbation by the size of its coalition, using the Shapley kernel from Kernel SHAP:

weight = (M-1) / (C(M,|S|) × |S| × (M-|S|))
where M is the total number of features and |S| is the coalition size.

Dice Score: For segmentation, measures overlap between predictions:

Dice = 2×|intersection| / (|pred| + |true|)

These methods are particularly effective for:

  • Local explanations: LIME- and SHAP-style interpretability.
  • Faithful approximations: ensuring explanations reflect local model behavior.
  • Segmentation analysis: understanding model performance on different regions.
  • Coalition-based methods: game-theoretic explanation approaches.

Attributes:

  • predicted_class

    The predicted class from the original input.

  • similarity_fun

    The similarity function used for calculations.

  • convention

    The attribution convention (perturbed_input_similarity).

Example

Using LIME-style similarity weighting:

# Initialize with LIME similarity function
attributor = SimilarityAttributor(similarity_fun=lime_similarity)

# Use in LIME explainer
explainer = LIMEExplainer(
    model=image_classifier,
    attributor=attributor,
    # ... other components
)

explanation = explainer.explain(image_tensor)

Initialize the SimilarityAttributor.

Parameters:

  • similarity_fun (Callable[[Tensor, Tensor, Premise], Tensor]) –

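    Function that calculates the similarity between the original and perturbed inputs. Takes (original_tensor, perturbed_tensor, premise) and returns similarity scores, either one value per perturbation or a single value for the current premise. The bundled lime_similarity and kernel_shap_similarity return higher values for perturbations closer to the original input.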

Source code in muppet/components/attributor/similarity.py
def __init__(
    self,
    similarity_fun: Callable[
        [torch.Tensor, torch.Tensor, "Premise"], torch.Tensor
    ],
) -> None:
    """Initialize the SimilarityAttributor.

    Args:
        similarity_fun: Function that calculates the similarity between the original
            and perturbed inputs. Takes (original_tensor, perturbed_tensor, premise)
            and returns similarity scores, either one per perturbation or a single
            value for the current premise.
    """
    self.similarity_fun = similarity_fun
    self.predicted_class = None

    # Bundled similarity functions return higher values for inputs closer to x
    self.convention = "perturbed_input_similarity"
    super().__init__()
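Any callable with the (original_tensor, perturbed_tensor, premise) signature can be plugged in. As an illustration, here is a hypothetical `cosine_similarity_fun` that scores every flattened perturbation against the flattened original input; it is not part of MUPPET, and it ignores the premise argument.

```python
import torch
import torch.nn.functional as F


def cosine_similarity_fun(
    x_vector: torch.Tensor, perturbed_vectors: torch.Tensor, premise: object
) -> torch.Tensor:
    """Hypothetical similarity_fun: cosine similarity to the original input.

    x_vector: (1, D); perturbed_vectors: (N, D). Returns one score per perturbation, (N,).
    `premise` is accepted to match the expected signature but is unused here.
    """
    # x_vector broadcasts against every row of perturbed_vectors
    return F.cosine_similarity(x_vector.float(), perturbed_vectors.float(), dim=1)


x = torch.tensor([[1.0, 0.0]])
perturbations = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]])
print(cosine_similarity_fun(x, perturbations, None))
```

Because it returns one value per perturbation, calculate_attribution picks the entry matching each premise by index, and it would be passed as `SimilarityAttributor(similarity_fun=cosine_similarity_fun)`.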
Functions
calculate_attribution
calculate_attribution(x, perturbed_inputs, model, memory)

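Calculates the attribution of each perturbed input: the softmax probability it assigns to the class predicted on the original input, stored together with the premise's similarity weight.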

Parameters:

  • x (Tensor) –

    Example to explain. Shape (b=1, *x.shape[1:]).

  • perturbed_inputs (Tensor) –

    The example's perturbations. Shape (N, *x.shape).

  • model (Module) –

    The model to explain.

  • memory (Memory) –

    Memory where the premises are stored.

Where b is the batch size (=1), N is the number of generated masks.
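A sketch of the per-premise bookkeeping, assuming a classifier that returns logits of shape (1, nclasses) and similarities already computed as one value per perturbation. `premise_records` is hypothetical and returns plain dictionaries instead of writing to `premise.attribution`.

```python
import torch
import torch.nn.functional as F
from torch import nn


def premise_records(
    model: nn.Module,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    similarities: torch.Tensor,
) -> list[dict]:
    """Hypothetical helper. x: (1, *shape); perturbed_inputs: (N, 1, *shape); similarities: (N,)."""
    with torch.no_grad():
        # Class predicted on the unperturbed input, shape (1,)
        predicted_class = F.softmax(model(x), dim=1).argmax(dim=1)
        records = []
        for idx in range(perturbed_inputs.shape[0]):
            probs = F.softmax(model(perturbed_inputs[idx].float()), dim=1)  # (1, nclasses)
            records.append(
                {
                    # Probability the perturbed input still gives to that class, (1,)
                    "attribution": probs[:, predicted_class].squeeze(dim=-1),
                    # Stored as a Python float, like the .item() call above
                    "similarity": similarities[idx].item(),
                }
            )
    return records


model = nn.Identity()  # toy classifier: the input is used directly as logits
x = torch.tensor([[2.0, 0.0]])
perturbed = torch.stack([torch.tensor([[2.0, 0.0]]), torch.tensor([[0.0, 0.0]])])
print(premise_records(model, x, perturbed, torch.tensor([1.0, 0.25])))
```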

Source code in muppet/components/attributor/similarity.py
def calculate_attribution(
    self,
    x: torch.Tensor,
    perturbed_inputs: torch.Tensor,
    model: torch.nn.Module,
    memory: Memory,
) -> None:
    """Calculates the attribution of perturbed inputs.

    Args:
        x (torch.Tensor): Example to explain. Shape (b=1, *x.shape[1:]).
        perturbed_inputs (torch.Tensor): The example's perturbations. Shape (N, *x.shape).
        model (torch.nn.Module): The model to explain.
        memory (Memory): Memory where the premises are stored.

    Where b is the batch size (=1), N is the number of generated masks.
    """
    perturbed_inputs_vectors = perturbed_inputs.view(
        perturbed_inputs.shape[0], -1
    )
    x_vector = x.view(1, -1)

    # Calculate the true class if not already done
    if self.predicted_class is None:
        with torch.no_grad():
            true_output = F.softmax(
                model(x).detach(), dim=1
            )  # (b, nclasses)
        self.predicted_class = torch.argmax(true_output, dim=1)  # (b=1)

    # Calculate the true class prediction of every perturbation
    for idx, premise in enumerate(memory.get_premises()):
        with torch.no_grad():
            logits = model(perturbed_inputs[idx].float()).detach()
        probs = F.softmax(logits, dim=1)  # (b=1, nclasses)

        final = probs[:, self.predicted_class].squeeze(dim=-1)

        # Compute similarity for the current premise
        similarities = self.similarity_fun(
            x_vector, perturbed_inputs_vectors, premise
        )

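        # A single value (shape [] or [1]) applies to this premise directly;
        # otherwise there is one value per perturbation. numel() also
        # handles 0-d tensors, for which shape[0] would raise IndexError.
        if similarities.numel() <= 1:
            similarity = similarities
        else:
            similarity = similarities[idx]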

        premise.attribution = {
            "attribution": final,
            "similarity": similarity.item(),  # Ensure it's a scalar for consistency
        }

    return

Functions

lime_similarity
lime_similarity(x_vector, perturbed_vector, premise)

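Example LIME similarity function: a Gaussian kernel (width 0.25) applied to the Euclidean distances between the original input and every perturbation, after min-max normalizing those distances to [0, 1]. It computes the similarities for all perturbations at once and returns one value per perturbation.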

Source code in muppet/components/attributor/similarity.py
def lime_similarity(
    x_vector: torch.Tensor, perturbed_vector: torch.Tensor, premise: "Premise"
) -> torch.Tensor:
    """Example similarity function using a Gaussian kernel, as in LIME.
    This function computes similarities for all perturbations at once.
    """
    distances = torch.cdist(x_vector, perturbed_vector).view(-1)
    min_value = distances.min()
    max_value = distances.max()

    if min_value == max_value:
        distances = distances * 0
    else:
        distances = (distances - min_value) / (max_value - min_value)

    kernel = torch.sqrt(torch.exp(-(distances**2) / 0.25**2))
    return kernel
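The function can be reproduced to see the effect of the normalization. In this sketch, the name `lime_kernel` and the `width` parameter (hard-coded as 0.25 above) are hypothetical.

```python
import torch


def lime_kernel(
    x_vector: torch.Tensor, perturbed_vectors: torch.Tensor, width: float = 0.25
) -> torch.Tensor:
    """Gaussian kernel on min-max-normalized Euclidean distances, one value per perturbation."""
    distances = torch.cdist(x_vector, perturbed_vectors).view(-1)  # (N,)
    span = distances.max() - distances.min()
    # Rescale to [0, 1]; if all distances are equal, treat them all as 0
    if span == 0:
        normalized = torch.zeros_like(distances)
    else:
        normalized = (distances - distances.min()) / span
    # sqrt(exp(-d^2 / w^2)) == exp(-d^2 / (2 w^2))
    return torch.sqrt(torch.exp(-(normalized**2) / width**2))


x = torch.tensor([[0.0, 0.0]])
perturbations = torch.tensor([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])  # distances 0, 5, 10
print(lime_kernel(x, perturbations))  # approximately exp(0), exp(-2), exp(-8)
```

Because the distances are min-max normalized within each batch, the closest perturbation always receives 1 and the farthest always receives exp(-8) ≈ 3.4e-4, whatever the absolute distances are.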
kernel_shap_similarity
kernel_shap_similarity(x_vector, perturbed_vector, premise)

Calculates similarity based on kernel SHAP.

Parameters:

  • x_vector (Tensor) –

    The original input tensor of shape (1, N).

  • perturbed_vector (Tensor) –

    The perturbed input tensor of shape (1, N).

  • premise (Premise) –

    The premise object containing the key tensor of shape (1, f).

Returns:

  • Tensor –

    A tensor of shape [1] containing the similarity score.

Source code in muppet/components/attributor/similarity.py
def kernel_shap_similarity(
    x_vector: torch.Tensor, perturbed_vector: torch.Tensor, premise: "Premise"
) -> torch.Tensor:
    """Calculates similarity based on kernel SHAP.

    Args:
        x_vector (torch.Tensor): The original input tensor of shape (1, N).
        perturbed_vector (torch.Tensor): The perturbed input tensor of shape (1, N).
        premise (Premise): The premise object containing the key tensor of shape (1, f).

    Returns:
        torch.Tensor: A tensor of shape [1] containing the similarity score.
    """
    N = x_vector.shape[1]  # Total number of features
    key_tensor = premise.key.squeeze(0)  # Convert (1, f) to (f,)
    S = (key_tensor == 1).sum().item()  # Number of 1 features in the key

    # Handle case where S is 0 (totally perturbed vector, no similarity)
    if S == 0:
        return torch.tensor([0.0], dtype=torch.float32)

    # Handle case where S equals N (no perturbation, full similarity)
    if S == N:
        return torch.tensor([1.0], dtype=torch.float32)

    # Kernel SHAP weight: (N - 1) / (C(N, S) * S * (N - S))
    binomial_coefficient = math.comb(N, S)
    similarity = (N - 1) / (binomial_coefficient * S * (N - S))

    # Return similarity as a tensor with shape [1]
    return torch.tensor([similarity], dtype=torch.float32)
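The weights for a small number of features can be tabulated with a hypothetical helper, `shap_kernel_weight`, that takes the coalition size directly instead of a `Premise`. It keeps the same edge cases as kernel_shap_similarity and uses `math.comb` (Python 3.8+) for the binomial coefficient.

```python
import math


def shap_kernel_weight(num_features: int, coalition_size: int) -> float:
    """Kernel SHAP weight (M - 1) / (C(M, |S|) * |S| * (M - |S|)), with fixed edge cases."""
    if coalition_size == 0:
        return 0.0  # fully perturbed input
    if coalition_size == num_features:
        return 1.0  # no perturbation
    return (num_features - 1) / (
        math.comb(num_features, coalition_size)
        * coalition_size
        * (num_features - coalition_size)
    )


for s in range(5):
    print(s, shap_kernel_weight(4, s))
# prints one pair per line: 0 0.0, 1 0.25, 2 0.125, 3 0.25, 4 1.0
```

Between the edge cases, the weights are symmetric and U-shaped: very small and very large coalitions weigh more than mid-sized ones. In the Kernel SHAP formula itself the weight is unbounded at |S| = 0 and |S| = M (the denominator is zero), so the 0.0 and 1.0 returned there are choices made by this function.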