
Aggregator

muppet.components.aggregator.base

Base aggregator component for MUPPET XAI.

This module defines the abstract base class for all aggregators in the MUPPET XAI framework. Aggregators are responsible for combining individual feature attributions calculated by attributors to produce the final local explanations (heatmaps, feature importance scores, etc.).

The aggregation process is the final step in the four-step perturbation-based XAI approach: generate masks → apply perturbations → calculate attributions → aggregate results.

Classes:

  • Aggregator

    Abstract base class defining the interface for all aggregator components.

Classes

Aggregator
Aggregator()

Bases: ABC

Abstract base class for aggregator components in MUPPET XAI.

A global aggregator component responsible for aggregating the attributions calculated by the attributor to generate the final heatmap.

Attributes:

  • device

    The device in use. It is updated by the main explainer after initialization.

  • convention

    The attribution convention (constructive or destructive).

  • allowed_conventions

    The allowed convention types from AttributionConvention.

Example

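Typical usage involves subclassing the Aggregator base class and setting the convention before calling the base initializer:

import torch

from muppet.components.aggregator.base import Aggregator

class CustomAggregator(Aggregator):
    def __init__(self, convention="destructive"):
        self.convention = convention  # set before super().__init__(), which checks it
        super().__init__()

    def get_explanation(self, memory):
        # Custom aggregation logic, e.g. averaging the perturbation masks stored in memory
        masks = torch.stack([premise.heatmap for premise in memory.get_premises()])
        final_heatmap = masks.mean(dim=0)
        return final_heatmap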

Initialize the aggregator component.

Sets up the default device and checks the attribution convention. Subclasses are expected to set convention before calling super().__init__(); if it is not set, it defaults to 'destructive' and a warning is logged.

Source code in muppet/components/aggregator/base.py
def __init__(self) -> None:
    """Initialize the aggregator component.

    Sets up the default device and attribution convention for the aggregator.
    If no convention is set, defaults to 'destructive' with a warning.
    """
    self.device = DEVICE

    # Check that the convention has been set, if not set it with warning
    if not hasattr(self, "convention"):
        logger.warning(
            "No convention set for the Aggregator; defaulting to 'destructive'"
        )
        self.convention = "destructive"

    super().__init__()
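The convention check above only works if a subclass assigns convention before calling super().__init__(). The following minimal sketch reproduces that ordering rule in plain Python. AggregatorSketch, ConstructiveSketch and NoConventionSketch are hypothetical names; the real class additionally exposes convention as a writable property and reads its device from a global DEVICE, which the sketch replaces with the string "cpu".

```python
import logging
from abc import ABC, abstractmethod

logger = logging.getLogger("aggregator_sketch")


class AggregatorSketch(ABC):
    """Hypothetical stand-in for Aggregator that reproduces only the convention check."""

    def __init__(self) -> None:
        self.device = "cpu"  # assumption: the real class reads a global DEVICE instead
        # hasattr only sees a convention assigned *before* super().__init__() runs
        if not hasattr(self, "convention"):
            logger.warning("No convention set; defaulting to 'destructive'")
            self.convention = "destructive"

    @abstractmethod
    def get_explanation(self, memory):
        raise NotImplementedError


class ConstructiveSketch(AggregatorSketch):
    def __init__(self) -> None:
        self.convention = "constructive"  # set first, so the base class keeps it
        super().__init__()

    def get_explanation(self, memory):
        return memory


class NoConventionSketch(AggregatorSketch):
    # Never sets convention: the base class logs a warning and falls back
    def get_explanation(self, memory):
        return memory
```

Instantiating NoConventionSketch() logs the warning and leaves convention set to 'destructive', while ConstructiveSketch() keeps 'constructive'. The abstract base itself cannot be instantiated because get_explanation is abstract.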

Attributes

convention (property, writable)

Get or set the attribution convention.

Functions
reinitialize
reinitialize()

Reset the aggregator to its initial state.

This method restores the aggregator to its original configuration, clearing any internal state or accumulated data that may affect subsequent aggregation operations. The base implementation does nothing; subclasses that keep internal state should override it.

Source code in muppet/components/aggregator/base.py
def reinitialize(self):
    """Reset the aggregator to its initial state.

    This method restores the aggregator to its original configuration,
    clearing any internal state or accumulated data that may affect
    subsequent aggregation operations.
    """
    pass
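Because the base reinitialize() does nothing, an aggregator that accumulates data between calls has to override it. A minimal sketch, assuming a hypothetical RunningMeanSketch class that keeps a running mean of attributions (it does not subclass Aggregator, so that it stays self-contained):

```python
class RunningMeanSketch:
    """Hypothetical aggregator-like class that accumulates attributions across calls."""

    def __init__(self) -> None:
        self.reinitialize()

    def reinitialize(self) -> None:
        # Restore the initial state: forget everything accumulated so far
        self._total = 0.0
        self._count = 0

    def update(self, attribution: float) -> float:
        # Accumulate one attribution and return the running mean so far
        self._total += attribution
        self._count += 1
        return self._total / self._count
```

Calling the reset from __init__ keeps the initial state and the reset state defined in one place.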
get_explanation (abstract method)
get_explanation(memory)

Calculate the final explanation as a heatmap with the same shape as the input, using the attributions stored in memory.

Parameters:

  • memory (Memory) –

    The memory structure in which premises are stored.

Returns:

  • Tensor

    torch.Tensor: The final explanation, a heatmap with the same shape as the input.

Source code in muppet/components/aggregator/base.py
@abstractmethod
def get_explanation(
    self,
    memory: Memory,
) -> torch.Tensor:
    """Calculate the final explanation as a heatmap with the same shape as the input, using the attributions stored in memory.

    Args:
        memory (Memory): The memory structure in which premises are stored.

    Returns:
        torch.Tensor: The final explanation, a heatmap with the same shape as the input.

    """
    raise NotImplementedError

muppet.components.aggregator.distribution

Distribution-based aggregator for time series classification explanations.

This module provides aggregators that work with probability distributions over class predictions in time series classification tasks. It implements Monte Carlo aggregation methods to compute feature importance scores from KL divergence measurements between original and perturbed predictions.

The aggregator groups Monte Carlo samples by time steps and features, calculating final attribution scores through statistical aggregation of the KL divergences.

Classes:

  • MonteCarloKLAggregator

    Aggregates attributions using Monte Carlo sampling and KL divergence for time series classification tasks.

Classes

MonteCarloKLAggregator
MonteCarloKLAggregator(num_sampling)

Bases: Aggregator

Monte Carlo aggregator using KL divergence for time series classification.

This aggregator works on probability distributions over classes for classification tasks. It aggregates attributions using Monte Carlo sampling and KL divergence measurements to compute feature importance scores for time series data.

The aggregator groups Monte Carlo samples by time steps and features, calculating final attribution scores through statistical aggregation of KL divergences between original and perturbed predictions.
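The grouping step can be sketched without PyTorch. This sketch assumes, as the use of the private _splitter helper in the source suggests, that premises arrive in consecutive blocks of num_sampling sharing the same (feature, time step). PremiseSketch and monte_carlo_kl_heatmap are hypothetical names, and the result is a nested list of shape (f, t) rather than a tensor of shape (b, f, t).

```python
import math
from dataclasses import dataclass


@dataclass
class PremiseSketch:
    """Hypothetical premise: which (feature, timestep) was perturbed, plus its KL divergence."""
    feature: int
    timestep: int
    kl_div: float


def monte_carlo_kl_heatmap(premises, num_sampling, nb_features, signal_length):
    # Assumption: consecutive blocks of num_sampling premises share (feature, timestep)
    heatmap = [[0.0] * signal_length for _ in range(nb_features)]
    for start in range(0, len(premises), num_sampling):
        group = premises[start:start + num_sampling]
        # Monte Carlo estimate of the expected KL divergence for this group
        e_div = sum(p.kl_div for p in group) / len(group)
        # Squash the non-negative expectation into [0, 1)
        score = 2.0 / (1.0 + math.exp(-5.0 * e_div)) - 1.0
        heatmap[group[0].feature][group[0].timestep] = score
    return heatmap
```

A group whose perturbations never change the predicted distribution (all KL divergences zero) gets a score of exactly 0.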

Initialize the Monte Carlo KL divergence aggregator.

Parameters:

  • num_sampling (int) –

    The number of Monte-Carlo sampling iterations.

Source code in muppet/components/aggregator/distribution.py
def __init__(
    self,
    num_sampling: int,
) -> None:
    """Initialize the Monte Carlo KL divergence aggregator.

    Args:
        num_sampling: The number of Monte-Carlo sampling iterations.
    """
    self.num_sampling = num_sampling
    # check that it is the case
    self.convention = "destructive"
    super().__init__()
Functions
get_explanation
get_explanation(memory)

Calculate the final heatmap by grouping the Monte Carlo sample premises that share the same time step and feature, then aggregating their attributions to compute that time step's score.

Parameters:

  • memory (PremiseList) –

    List of premises.

Returns:

  • Tensor

    torch.Tensor: The final heatmap holding the score \(score(t, S)\) at every time step, with shape (b, f, t).

Source code in muppet/components/aggregator/distribution.py
def get_explanation(
    self,
    memory: PremiseList,
) -> torch.Tensor:
    """Calculate the final heatmap by grouping the Monte Carlo sample premises that share the same time step and feature,
    then aggregating their attributions to compute that time step's score.

    Args:
        memory (PremiseList): List of premises.

    Returns:
        torch.Tensor: The final heatmap holding the score $score(t, S)$ at every time step. Shape (b, f, t)

    """
    temp_premise = memory.get_premises()[0]

    nb_features, signal_length = temp_premise.key[1]  # key[1] holds the signal shape (f, t)

    heatmap = torch.zeros((nb_features, signal_length))
    for monte_carlo_premises in self._splitter(
        memory.get_premises(), self.num_sampling
    ):
        kl_divs = torch.stack(
            [premise.attribution for premise in monte_carlo_premises]
        ).to(self.device)  # (num_sampling, b)
        E_div = torch.mean(kl_divs, dim=0)  # (b,)

        score_timestep = 2.0 / (1 + torch.exp(-5 * E_div)) - 1

        time_step = monte_carlo_premises[0].key[0]["timestep"]
        feature = monte_carlo_premises[0].key[0]["feature"]

        heatmap[feature, time_step] = score_timestep

    return heatmap.unsqueeze(dim=0)  # (b, f, t)
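The squashing applied to the Monte Carlo mean can be written in closed form. In the notation below (ours, not the library's), \(M\) is num_sampling and \(D^{(m)}_{\mathrm{KL}}\) is the KL divergence stored as the attribution of the \(m\)-th premise in a group; \(\sigma\) is the logistic sigmoid.

$$
\bar{D} = \frac{1}{M}\sum_{m=1}^{M} D^{(m)}_{\mathrm{KL}}, \qquad
\text{score}(t, S) = \frac{2}{1 + e^{-5\bar{D}}} - 1 = 2\,\sigma\!\left(5\bar{D}\right) - 1 = \tanh\!\left(\frac{5\bar{D}}{2}\right)
$$

Since KL divergences are non-negative, \(\bar{D} \ge 0\), so every score lies in \([0, 1)\): it is 0 when the perturbations leave the predicted distribution unchanged and approaches 1 as the divergence grows.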

muppet.components.aggregator.mask

Mask-based aggregators for image and spatial data explanations.

This module provides aggregators that work with mask-based perturbations for generating explanations for image classification models. It implements weighted aggregation methods that combine multiple perturbation masks with their corresponding attribution weights to produce final saliency maps.

The module supports both weighted sum aggregation (for methods like RISE) and learned mask aggregation (for gradient-based optimization methods) with support for different attribution conventions (constructive vs destructive).

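Classes:

  • WeightedSumAggregator

    Weighted sum aggregator for mask-based explanations.

  • LearntMaskAggregator

    Learned mask aggregator for optimization-based explanation methods.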

Classes

WeightedSumAggregator
WeightedSumAggregator(
    post_proc=None, convention="destructive"
)

Bases: Aggregator

Weighted sum aggregator for mask-based explanations.

This aggregator multiplies the mask of every perturbation by its weight and sums the weighted masks. The weight is equal to the model's class probability. It is commonly used in methods like RISE.
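A NumPy sketch of this weighted sum, assuming masks of shape (N, H, W) (the single batch and channel dimensions dropped) and one scalar weight per mask. It mirrors the flatten, matrix-product and min-max normalization steps of get_explanation, but weighted_sum_heatmap is a hypothetical helper, not part of MUPPET.

```python
import numpy as np


def weighted_sum_heatmap(masks: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """RISE-style aggregation sketch.

    masks:   (N, H, W) perturbation masks
    weights: (N,) model class probability obtained for each perturbed input
    """
    n, h, w = masks.shape
    # (N,) @ (N, H*W) -> (H*W,): every mask scaled by its weight, then summed
    heatmap = (weights @ masks.reshape(n, h * w)).reshape(h, w)
    lo, hi = heatmap.min(), heatmap.max()
    if lo == hi:
        return np.zeros_like(heatmap)  # a constant map carries no information
    return (heatmap - lo) / (hi - lo)  # min-max normalise to [0, 1]
```

Pixels kept by masks that led to a high class probability accumulate the largest values.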

Initialize the weighted sum aggregator.

Parameters:

  • post_proc (Union[Callable[[Tensor], Tensor], None], default: None ) –

    Function applied to the computed heatmap (for example, ReLU). Defaults to None, meaning no post-processing is done. If post_proc returns an all-zero map, the unprocessed heatmap is kept.

  • convention (Union[AttributionConvention, str], default: 'destructive' ) –

    The attribution convention to use, either "constructive" or "destructive".

Source code in muppet/components/aggregator/mask.py
def __init__(
    self,
    post_proc: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
    convention: Union[AttributionConvention, str] = "destructive",
) -> None:
    """Initialize the weighted sum aggregator.

    Args:
        post_proc: Apply the post_proc function to the calculated heatmap (example: ReLU).
            None by default, meaning no post processing is done.
        convention: The attribution convention to use, either "constructive" or "destructive".
    """
    self.post_proc = post_proc
    self.convention = convention
    super().__init__()
Functions
get_explanation
get_explanation(memory)

Calculate the final heatmap by multiplying the mask of every perturbation by its weight and summing all the masks.

Parameters:

  • memory (PremiseList) –

    A simple list where premises are saved. The attribution of every premise stores the weight of its mask.

Returns:

  • Tensor

    torch.Tensor: Final heatmap with the same shape as the input x, (b=1, c, w, h), highlighting the most important parts of the input example, where b is the batch dimension (expected to be 1, as only one example is explained at a time), c is the channel dimension (expected to be 1), w is the width and h is the height.

Source code in muppet/components/aggregator/mask.py
def get_explanation(
    self,
    memory: PremiseList,
) -> torch.Tensor:
    """Calculate the final heatmap by multiplying the mask of every perturbation by its weight and summing all the masks.

    Args:
        memory (PremiseList): A simple list where premises are saved. The attribution of every premise stores the weight of its mask.

    Returns:
        torch.Tensor: Final heatmap with the same shape as the input x, (b=1, c, w, h), highlighting the most important parts of the input example,
            where b is the batch dimension (expected to be 1, as only one example is explained at a time),
            c is the channel dimension (expected to be 1), w is the width and h is the height.

    """
    masks = torch.stack(
        [premise.heatmap for premise in memory.get_premises()]
    ).to(self.device)  # (N, b=1, c=1, w, h)
    weights = torch.stack(
        [premise.attribution for premise in memory.get_premises()]
    ).to(self.device)  # (N, b=1)

    example_shape = masks.shape[1:]

    N = masks.size(0)
    b = masks.size(1)  # ==1
    w = masks.size(-2)
    h = masks.size(-1)

    masks = masks.view(N, h * w)  # (N, w*h)
    weights = weights.transpose(0, 1)  # (b=1, N)
    # weights = torch.nn.functional.softmax(weights, dim=-1)

    score_saliency_map = torch.matmul(
        weights, masks
    )  # (b=1, N) x (N, w*h) = (b=1, w*h)
    score_saliency_map = score_saliency_map.view(
        (b, w, h)
    )  # (b=1, w*h) => (b=1, w, h)

    score_saliency_map = score_saliency_map.view(
        *example_shape
    )  # (b=1, c=1, w, h)

    if self.post_proc is not None:
        processed = self.post_proc(score_saliency_map)
        # When post_proc returns an all-zero map (e.g. ReLU applied to a map with
        # only negative values), bypass it and keep the unprocessed map
        if not torch.all(processed == 0):
            score_saliency_map = processed
        assert score_saliency_map.shape == (
            b,
            1,
            w,
            h,
        ), "Post processing operation mustn't change saliency map shape."

    min_ssm = score_saliency_map.min()
    max_ssm = score_saliency_map.max()

    if min_ssm == max_ssm:
        return (
            score_saliency_map * 0
        )  # if heatmap is constant, set it to 0 everywhere

    else:
        return (score_saliency_map - min_ssm) / (
            max_ssm - min_ssm
        )  # else, normalise it between 0 and 1
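The post-processing fallback can be isolated as follows. This is a NumPy sketch with hypothetical names (apply_post_proc, relu): the post-processed map is used unless it is identically zero, in which case the raw map is kept so that min-max normalization still has something to work with.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    # Example post-processing function: clip negative saliency to zero
    return np.maximum(x, 0.0)


def apply_post_proc(saliency: np.ndarray, post_proc) -> np.ndarray:
    """Sketch of the post-processing fallback used by WeightedSumAggregator."""
    processed = post_proc(saliency)
    if processed.shape != saliency.shape:
        raise ValueError("Post processing operation mustn't change saliency map shape.")
    # An all-zero result (e.g. ReLU on an all-negative map) would normalise to nothing,
    # so keep the raw map in that case
    return saliency if np.all(processed == 0) else processed
```

With this rule, ReLU removes negative evidence when some positive evidence exists, and is skipped otherwise.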
LearntMaskAggregator
LearntMaskAggregator(convention='destructive')

Bases: Aggregator

Learned mask aggregator for optimization-based explanation methods.

This aggregator returns the normalized learned mask from gradient-based optimization methods. It normalizes the mask between 0 and 1 and applies convention-based transformations as needed.
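A NumPy sketch of the normalization and convention handling, assuming the learnt mask has already been detached and stripped of its batch dimension. normalize_learnt_mask is a hypothetical helper; it also guards explicitly against a constant mask, for which min-max normalization would otherwise divide by zero.

```python
import numpy as np


def normalize_learnt_mask(mask: np.ndarray, convention: str = "destructive") -> np.ndarray:
    """Min-max normalise a learnt mask to [0, 1], inverting it for the constructive convention."""
    lo, hi = mask.min(), mask.max()
    if hi == lo:
        heatmap = np.zeros_like(mask, dtype=float)  # guard: avoid 0/0 on a constant mask
    else:
        heatmap = (mask - lo) / (hi - lo)
    if convention == "constructive":
        heatmap = 1.0 - heatmap  # constructive convention: reverse the heatmap as 1 - heatmap
    return heatmap
```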

Initialize the learnt mask aggregator.

Parameters:

  • convention (str, default: 'destructive' ) –

    The attribution convention, either 'constructive' or 'destructive'. If 'constructive', the heatmap is inverted as 1 - heatmap.

Source code in muppet/components/aggregator/mask.py
def __init__(self, convention: str = "destructive") -> None:
    """Initialize the learnt mask aggregator.

    Args:
        convention (str): The attribution convention, either 'constructive' or 'destructive'.
            If "constructive", the heatmap is reversed using 1-heatmap.
    """
    self.convention = convention
    super().__init__()
Functions
get_explanation
get_explanation(memory)

Return the learnt mask, min-max normalized to [0, 1] and inverted (1 - heatmap) under the constructive convention.

Parameters:

  • memory (PremiseList) –

    List of one premise with the mask to be optimized.

Returns:

  • Tensor

    torch.Tensor: The final heatmap. Shape (b=1, *x.shape[1:]).

Source code in muppet/components/aggregator/mask.py
def get_explanation(
    self,
    memory: PremiseList,
) -> torch.Tensor:
    """Returns the learnt mask.

    Args:
        memory (PremiseList): List of one premise with the mask to be optimized.

    Returns:
        torch.Tensor: The final heatmap. Shape (b=1, *x.shape[1:]).

    """
    heatmap = memory.get_premises()[0].mask.detach()  # x.shape
    heatmap = heatmap[0]  # b=1 => (x.shape[1:]) (c=1, w, h)

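    # get mask's min and max values
    minn = heatmap.min()
    maxx = heatmap.max()
    max_min = maxx - minn

    if max_min == 0:
        # constant mask: avoid a 0/0 division and return an all-zero heatmap
        heatmap = torch.zeros_like(heatmap)
    else:
        heatmap = (heatmap - minn) / max_min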

    if self.convention == "constructive":
        heatmap = 1 - heatmap

    return heatmap.unsqueeze(dim=0)  # (b=1, c=1, w, h)

muppet.components.aggregator.model

Model-based aggregators using surrogate models for local explanations.

This module provides aggregators that use surrogate models (like Ridge regression) to fit local linear approximations of the complex model's behavior. This approach is fundamental to LIME-style explanations, where interpretable models explain individual predictions by learning from perturbations in the local neighborhood.

The aggregators support both tabular data (returning coefficients directly) and segmented image data (mapping coefficients back to pixel space using superpixels as in LIME-image). The surrogate models are fitted using weighted samples based on similarity to the original input.

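Classes:

  • ModelAggregator

    Base aggregator using surrogate models for local explanations.

  • SegmentedImageModelAggregator

    Specialized aggregator for image data using superpixels and surrogate models.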

Classes

ModelAggregator
ModelAggregator(
    surrogate_model=Ridge(alpha=1, fit_intercept=True),
)

Bases: Aggregator

Base aggregator using surrogate models for local explanations.

This aggregator fits a surrogate model to provide local linear approximations of the complex model's behavior. It is fundamental to LIME-style explanations, where interpretable models explain individual predictions by learning from perturbations in the local neighborhood.

Initialize the model aggregator.

Parameters:

  • surrogate_model

    The model to use in explanation. Defaults to Ridge regression. After fitting, it must expose a coef_ attribute, and its fit() method must accept a sample_weight argument. It must be an inherently interpretable model, specifically a (regularized) linear or logistic regression model.
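Any object with this interface can serve as the surrogate. As an illustration, the hypothetical WeightedRidgeSketch below solves weighted ridge regression in closed form with NumPy instead of using scikit-learn's Ridge. It centers X and y with weighted means so that the intercept is not penalized, and exposes coef_ after fit(X, y, sample_weight=...). It is a sketch of the interface, not a drop-in replacement validated against scikit-learn.

```python
import numpy as np


class WeightedRidgeSketch:
    """Hypothetical surrogate exposing the interface ModelAggregator relies on:
    fit(X, y, sample_weight=...) and a coef_ attribute after fitting."""

    def __init__(self, alpha: float = 1.0) -> None:
        self.alpha = alpha

    def fit(self, X, y, sample_weight=None):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        sw = np.ones(len(X)) if sample_weight is None else np.asarray(sample_weight, dtype=float)
        # Centre X and y with weighted means so that the intercept is not penalised
        x_mean = sw @ X / sw.sum()
        y_mean = sw @ y / sw.sum()
        Xc, yc = X - x_mean, y - y_mean
        # Solve (Xc^T W Xc + alpha * I) beta = Xc^T W yc
        A = Xc.T @ (sw[:, None] * Xc) + self.alpha * np.eye(X.shape[1])
        b = Xc.T @ (sw * yc)
        self.coef_ = np.linalg.solve(A, b)
        self.intercept_ = y_mean - x_mean @ self.coef_
        return self
```

With alpha=0 and noise-free linear targets, coef_ recovers the generating coefficients for any positive sample weights; a positive alpha shrinks them toward zero, which is what keeps LIME-style explanations stable when there are few perturbations.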

Source code in muppet/components/aggregator/model.py
def __init__(
    self, surrogate_model=Ridge(alpha=1, fit_intercept=True)
) -> None:
    """Initialize the model aggregator.

    Args:
        surrogate_model: The model to use in explanation. Defaults to Ridge regression.
            After fitting, it must expose a coef_ attribute, and its fit() method
            must accept a sample_weight argument. It must be an inherently interpretable
            model, specifically a (regularized) linear or logistic regression model.
    """
    self.surrogate_model = surrogate_model
    self.convention = "perturbed_input_similarity"
    super().__init__()
Functions
fit
fit(list_premises)

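Fit the surrogate model on the premises: each premise's prepared key is used as a feature vector, attribution["attribution"] as the regression target and attribution["similarity"] as the sample weight.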

Source code in muppet/components/aggregator/model.py
def fit(self, list_premises: List[Premise]) -> None:
    """Fits the surrogate model with premises."""
    assert "similarity" in list_premises[0].attribution.keys(), (
        "The attribution of each premise must be a dictionary of the form {'attribution': attribution, 'similarity': similarity}"
    )

    keys, perturbed_scores, similarities = zip(
        *[
            (
                self._prepare_key(p),
                p.attribution["attribution"],
                p.attribution["similarity"],
            )
            for p in list_premises
        ]
    )
    # for p in list_premises:
    #     print(f'{self._prepare_key(p)}, {p.attribution["similarity"] }, {p.attribution["attribution"]}')

    # scikit-learn expects NumPy arrays on the CPU; .cpu() is a no-op for CPU tensors
    perturbed_scores = torch.stack(perturbed_scores).squeeze(dim=1).detach().cpu()
    keys = torch.stack(keys).detach().cpu()

    self.surrogate_model.fit(
        keys.numpy(),
        perturbed_scores.numpy(),
        sample_weight=similarities,
    )
get_coefs
get_coefs(memory)

Fit the surrogate model on the premises stored in memory and return its learned coefficients.

Source code in muppet/components/aggregator/model.py
def get_coefs(self, memory: PremiseList) -> torch.Tensor:
    """Fit the surrogate model on the premises stored in memory
    and return its learned coefficients.
    """
    list_premises = memory.get_premises()
    self.fit(list_premises)
    return torch.tensor(self.surrogate_model.coef_)
get_explanation
get_explanation(memory)

Calculate final heatmap.

This method is meant to be overridden by subclasses to handle different types of data.

Parameters:

  • memory (PremiseList) –

    A PremiseList where premises are saved.

Returns:

  • Tensor

    torch.Tensor: The heatmap or coefficients depending on the data type.

Source code in muppet/components/aggregator/model.py
def get_explanation(self, memory: PremiseList) -> torch.Tensor:
    """Calculate final heatmap.

    This method is meant to be overridden by subclasses to handle different types of data.

    Args:
        memory (PremiseList): A PremiseList where premises are saved.

    Returns:
        torch.Tensor: The heatmap or coefficients depending on the data type.
    """
    # By default, return coefficients directly for tabular data
    return self.get_coefs(memory).clone().detach().unsqueeze(0)
SegmentedImageModelAggregator
SegmentedImageModelAggregator(
    surrogate_model=Ridge(alpha=1, fit_intercept=True),
)

Bases: ModelAggregator

Specialized aggregator for image data using superpixels and surrogate models.

This aggregator extends ModelAggregator to handle segmented image data by mapping surrogate model coefficients back to pixel space using superpixels. It transforms the learned feature importance values from the superpixel level back to a spatial heatmap that highlights important image regions.
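Mapping superpixel coefficients back to pixels amounts to a matrix product between the coefficient vector and the flattened segmentation. This NumPy sketch assumes that segmented_example is a stack of S binary masks of shape (S, H, W), one per superpixel, as the view(s, h * w) in the source suggests; superpixel_heatmap and one_hot_segments are hypothetical helpers.

```python
import numpy as np


def superpixel_heatmap(coefs: np.ndarray, segments: np.ndarray) -> np.ndarray:
    """Map per-superpixel coefficients back to pixel space.

    coefs:    (S,) surrogate coefficients, one per superpixel
    segments: (S, H, W) binary masks, segments[s] == 1 where a pixel belongs to superpixel s
    """
    s, h, w = segments.shape
    # (S,) @ (S, H*W) -> (H*W,): each pixel receives the coefficient of its superpixel
    heatmap = coefs @ segments.reshape(s, h * w).astype(float)
    lo, hi = heatmap.min(), heatmap.max()
    heatmap = np.zeros_like(heatmap) if lo == hi else (heatmap - lo) / (hi - lo)
    return heatmap.reshape(1, 1, h, w)  # (b=1, c=1, H, W)


def one_hot_segments(label_map: np.ndarray) -> np.ndarray:
    """Turn an (H, W) integer label map into (S, H, W) binary masks (hypothetical helper)."""
    labels = np.unique(label_map)
    return (label_map[None, :, :] == labels[:, None, None]).astype(float)
```

Each pixel thus receives the coefficient of the superpixel it belongs to, and the resulting map is min-max normalized.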

Source code in muppet/components/aggregator/model.py
def __init__(
    self, surrogate_model=Ridge(alpha=1, fit_intercept=True)
) -> None:
    """Initialize the model aggregator.

    Args:
        surrogate_model: The model to use in explanation. Defaults to Ridge regression.
            After fitting, it must expose a coef_ attribute, and its fit() method
            must accept a sample_weight argument. It must be an inherently interpretable
            model, specifically a (regularized) linear or logistic regression model.
    """
    self.surrogate_model = surrogate_model
    self.convention = "perturbed_input_similarity"
    super().__init__()
Functions
get_explanation
get_explanation(memory)

Calculate final heatmap for segmented image data.

Parameters:

  • memory (PremiseList) –

    A PremiseList where premises are saved. Every premise provides the attribution where the mask's weight is stored.

Returns:

  • Tensor

    torch.Tensor: Final heatmap with the same shape as the input x, (b=1, c=1, h, w), highlighting the most important parts of the input example, where b is the batch dimension (expected to be 1, as only one example is explained at a time), c is the channel dimension, h is the height and w is the width.

Source code in muppet/components/aggregator/model.py
def get_explanation(self, memory: PremiseList) -> torch.Tensor:
    """Calculate final heatmap for segmented image data.

    Args:
        memory (PremiseList): A PremiseList where premises are saved. Every premise provides the attribution where the mask's weight is stored.

    Returns:
        torch.Tensor: Final heatmap with the same shape as the input x, (b=1, c=1, h, w), highlighting
            the most important parts of the input example, where b is the batch dimension (expected
            to be 1, as only one example is explained at a time), h is the height and w is the width.

    """
    segmented_example = memory.get_premises()[0].segmented_example.to(
        self.device
    )
    coefs = self.get_coefs(memory).clone().detach().to(self.device)

    s, h, w = segmented_example.shape
    heatmap = torch.matmul(
        coefs,
        segmented_example.view(
            s,
            h * w,
        ).double(),  # (1, s)*(s, h * w) => (1, h*w)
    )
    heatmap = heatmap.unsqueeze(0)
    min_values = heatmap.min(dim=1).values
    max_values = heatmap.max(dim=1).values
    mask_diff_max_min = min_values != max_values
    heatmap[~mask_diff_max_min] = 0

    # Adds a dimension for broadcasting
    min_values = min_values.unsqueeze(dim=1)
    max_values = max_values.unsqueeze(dim=1)

    heatmap[mask_diff_max_min] = (
        ((heatmap - min_values)[mask_diff_max_min])
        / (max_values - min_values)[mask_diff_max_min]
    )

    heatmap = heatmap.view(
        1, 1, *segmented_example.shape[1:]
    )  # (1, 1, h, w)
    return heatmap