Models
muppet.benchmark.models.base
Base model classes for the MUPPET benchmark framework.
This module provides abstract base classes for creating unified model wrappers that support different frameworks (PyTorch, scikit-learn, XGBoost) with a consistent API. It enables seamless integration across various model types through standardized training and inference interfaces.
Classes:

- AbstractModel – Abstract base class for unified model wrappers
Classes
AbstractModel
Bases: Module, ABC
A flexible model wrapper that unifies training and inference for models from different frameworks: PyTorch, scikit-learn, and XGBoost.
This class supports model configuration through Hydra. If the provided model config contains
a _target_ key (Hydra-style), it loads and instantiates the model dynamically. It allows both
standard training and Optuna-based hyperparameter optimization (for non-PyTorch models).
It behaves like a standard PyTorch module, with .fit() for training and .forward() for inference,
enabling easy integration into a PyTorch-style pipeline.
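As an illustration of this pattern, here is a minimal, self-contained sketch of such a wrapper. `MiniModelWrapper` and its subclass are hypothetical stand-ins, not the real AbstractModel API; only the `fit`/`forward` split mirrors the documented interface:

```python
from abc import ABC, abstractmethod
import torch
import torch.nn as nn

class MiniModelWrapper(nn.Module, ABC):
    """Simplified stand-in for the documented wrapper (illustration only)."""
    def __init__(self, model, name, postprocessing_func=lambda x: x):
        super().__init__()
        self.model = model
        self.name = name
        self.postprocessing_func = postprocessing_func

    @abstractmethod
    def fit(self, train_loader):
        """Train the wrapped model on a DataLoader."""

    @abstractmethod
    def forward(self, x):
        """Run inference on a batch tensor."""

class ToyWrapper(MiniModelWrapper):  # hypothetical subclass
    def fit(self, train_loader):
        pass  # framework-specific training would go here

    def forward(self, x):
        return self.postprocessing_func(self.model(x))

wrapper = ToyWrapper(nn.Linear(4, 2), name="toy")
out = wrapper(torch.zeros(3, 4))  # PyTorch-style __call__ dispatches to forward()
```

Because the wrapper subclasses `nn.Module`, it can be dropped into any PyTorch-style pipeline that calls the model like a function.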
Initialize the AbstractModel wrapper instance.
Parameters:

- model (Union[DictConfig, Any]) – Model configuration or model instance.
- name (str) – Name of the model.
- pretrained (bool, default: True) – Whether to use a pretrained version (if supported). Defaults to True.
- postprocessing_func (Callable, default: lambda x: x) – A function to apply to the model's raw outputs. Defaults to a no-op lambda function.
Source code in muppet/benchmark/models/base.py
Functions
fit abstractmethod
Trains the model on the provided data.
This method must be implemented by subclasses. It should handle the model's training loop and optimization.
Parameters:

- train_loader (DataLoader) – The data loader containing training samples.
Source code in muppet/benchmark/models/base.py
forward abstractmethod
Performs a forward pass through the model for inference.
This method must be implemented by subclasses. It defines the model's behavior during inference.
Parameters:

- x (Tensor) – The input tensor.

Returns:

- torch.Tensor – The output tensor from the model.
Source code in muppet/benchmark/models/base.py
Runs inference using the wrapped model on the provided DataLoader.
This function processes the input data in batches, performs a forward pass through the model without computing gradients, and collects both the input data and the predicted outputs after post-processing.
Parameters:

- dataloader (DataLoader) – A DataLoader providing batches of input data.

Returns:

- tuple[np.ndarray, np.ndarray] – A tuple containing:
  - inputs (np.ndarray): The concatenated input data as a NumPy array.
  - predictions (np.ndarray): The concatenated post-processed predictions as a NumPy array.
Source code in muppet/benchmark/models/base.py
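The batched, gradient-free inference loop described above can be sketched as follows; the model, post-processing function, and DataLoader here are illustrative stand-ins, not the wrapper's actual internals:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(4, 2)                        # stand-in wrapped model
postprocess = lambda logits: torch.softmax(logits, dim=1)

dataset = TensorDataset(torch.randn(10, 4), torch.zeros(10, dtype=torch.long))
loader = DataLoader(dataset, batch_size=4)

inputs, predictions = [], []
with torch.no_grad():                                # no gradients at inference
    for x, _ in loader:
        inputs.append(x.numpy())
        predictions.append(postprocess(model(x)).numpy())

inputs = np.concatenate(inputs)                      # shape (10, 4)
predictions = np.concatenate(predictions)            # shape (10, 2)
```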
muppet.benchmark.models.classifier
Classification model wrappers for the MUPPET benchmark framework.
This module provides unified wrapper classes for classification models across different machine learning frameworks (PyTorch, scikit-learn, XGBoost) with a consistent interface for training and inference.
Classes:

- Classifier – Unified wrapper for classification models with hyperparameter tuning
- GRUClassifier – PyTorch GRU-based classifier for sequence/time series data
Classes
Classifier
Classifier(
model,
name,
pretrained=True,
random_state=105,
n_trials=20,
cv=5,
scoring="balanced_accuracy",
study_name="classifier",
predict_proba_func="predict_log_proba",
params_tuning=None,
n_jobs=1,
)
Bases: AbstractModel
A unified wrapper for classification models, providing a consistent API for models from different frameworks like PyTorch, scikit-learn, and XGBoost.
This class handles both standard training and hyperparameter optimization
(via Optuna for non-PyTorch models). It dynamically loads models from
Hydra-style configurations and provides a PyTorch-like interface with
.fit() and .forward() methods for seamless pipeline integration.
Initialize the Classifier wrapper instance.
Parameters:

- model (Union[DictConfig, Any]) – The model configuration (e.g., from Hydra) or an already instantiated model object.
- name (str) – The name of the model.
- pretrained (bool, default: True) – Whether to use a pretrained version of the model. Defaults to True.
- random_state (int, default: 105) – The random state for reproducibility, particularly in hyperparameter tuning. Defaults to 105.
- n_trials (int, default: 20) – The number of trials for Optuna-based hyperparameter search. Defaults to 20.
- cv (int, default: 5) – The number of cross-validation folds for hyperparameter tuning. Defaults to 5.
- scoring (str, default: 'balanced_accuracy') – The scoring function used to evaluate models during tuning. Defaults to "balanced_accuracy".
- study_name (str, default: 'classifier') – The name for the Optuna study. Defaults to "classifier".
- predict_proba_func (str, default: 'predict_log_proba') – The name of the method to call on non-PyTorch models to get probabilistic outputs. Defaults to "predict_log_proba".
- params_tuning (dict[str, Any] | None, default: None) – A dictionary defining the hyperparameter search space for Optuna. If None, no tuning is performed. Defaults to None.
- n_jobs (int, default: 1) – The number of parallel jobs for hyperparameter tuning. Defaults to 1.
Source code in muppet/benchmark/models/classifier.py
Functions
fit
Trains the classification model using the provided training data.
For PyTorch models, this method currently raises a NotImplementedError, as the training
logic needs to be added (e.g., using a training loop or a library like PyTorch Lightning).
For scikit-learn and XGBoost models, it extracts data from the DataLoader, converts it to
NumPy arrays, and either fits the model directly or, if params_tuning is provided,
performs Optuna-based hyperparameter tuning and uses the best-found estimator.
Parameters:

- train_loader (DataLoader) – The data loader containing the training samples.
Source code in muppet/benchmark/models/classifier.py
forward
Performs a forward pass for inference, returning probabilistic outputs.
For PyTorch models, this method calls the model directly. For scikit-learn
and XGBoost models, it converts the input tensor to a NumPy array, calls the
method specified by predict_proba_func (e.g., .predict_log_proba), and
converts the results back into a PyTorch tensor. This ensures a consistent
tensor output regardless of the underlying framework.
Parameters:

- x (Tensor) – The input tensor.

Returns:

- torch.Tensor – A tensor containing the probabilistic outputs.
Source code in muppet/benchmark/models/classifier.py
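The tensor-to-NumPy bridging described above can be sketched as a standalone function; `bridge_forward` is a hypothetical name, and in the real wrapper this conversion happens inside `Classifier.forward()`:

```python
import torch
from sklearn.linear_model import LogisticRegression

# stand-in scikit-learn model, fitted on toy data
clf = LogisticRegression().fit([[0.0], [1.0], [0.1], [0.9]], [0, 1, 0, 1])

def bridge_forward(x, predict_proba_func="predict_log_proba"):
    out = getattr(clf, predict_proba_func)(x.numpy())  # tensor -> NumPy
    return torch.as_tensor(out)                        # NumPy -> tensor

log_probs = bridge_forward(torch.tensor([[0.0], [1.0]]))
```

The `getattr` lookup mirrors the documented `predict_proba_func` parameter, so the same bridge works for `predict_proba` or `predict_log_proba`.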
GRUClassifier
GRUClassifier(
feature_size,
n_state=2,
hidden_size=200,
rnn="GRU",
regres=True,
bidirectional=False,
return_all=False,
pretrained_model_path=None,
)
Bases: Module
GRU-based neural network classifier for sequential data.
A recurrent neural network classifier that uses GRU (Gated Recurrent Unit) or LSTM layers to process time series or other sequential data, with configurable hidden dimensions, output classes, and other architecture parameters.
Initialize the GRUClassifier instance.
Parameters:

- feature_size – Number of input features per time step.
- n_state – Number of output classes for classification.
- hidden_size – Dimension of the hidden state in RNN layers.
- rnn – Type of RNN to use ("GRU" or "LSTM").
- regres – If True, adds a regression/classification head on top of the RNN.
- bidirectional – If True, uses bidirectional RNN processing.
- return_all – If True, returns outputs from all time steps; if False, returns only the final time step output.
- pretrained_model_path (str, default: None) – Optional path to load pretrained model weights.
Source code in muppet/benchmark/models/classifier.py
Functions
forward
Forward pass through the RNN classifier.
Source code in muppet/benchmark/models/classifier.py
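To illustrate the documented architecture, here is a minimal GRU classifier following the same constructor semantics; `TinyGRUClassifier` is a simplified stand-in (it omits the `rnn`, `regres`, and `pretrained_model_path` options), and its internals may differ from the real GRUClassifier:

```python
import torch
import torch.nn as nn

class TinyGRUClassifier(nn.Module):
    def __init__(self, feature_size, n_state=2, hidden_size=200,
                 bidirectional=False, return_all=False):
        super().__init__()
        self.return_all = return_all
        self.rnn = nn.GRU(feature_size, hidden_size,
                          batch_first=True, bidirectional=bidirectional)
        out_dim = hidden_size * (2 if bidirectional else 1)
        self.head = nn.Linear(out_dim, n_state)  # classification head

    def forward(self, x):                 # x: (batch, time, features)
        states, _ = self.rnn(x)           # (batch, time, hidden)
        if not self.return_all:
            states = states[:, -1]        # keep only the final time step
        return self.head(states)

model = TinyGRUClassifier(feature_size=8, hidden_size=16)
logits = model(torch.randn(4, 5, 8))      # batch of 4 sequences, 5 steps each
```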
muppet.benchmark.models.segmentation
Segmentation model wrappers for the MUPPET benchmark framework.
This module provides unified wrapper classes for semantic segmentation models with a consistent PyTorch-based interface for loading, training, and inference.
Classes:

- SegmentationModel – Unified wrapper for semantic segmentation models
Classes
SegmentationModel
Bases: AbstractModel
A unified wrapper for semantic segmentation models, providing a consistent API for PyTorch models.
This class handles model loading from Hydra-style configurations and offers a
PyTorch-like interface with .fit() and .forward() methods, enabling
seamless pipeline integration for computer vision segmentation tasks.
Since this implementation currently only supports PyTorch models, it omits the hyperparameter tuning logic present in other model wrappers.
Initialize the SegmentationModel wrapper instance.
Parameters:

- model (Union[DictConfig, Any]) – The model configuration (e.g., from Hydra) or an already instantiated PyTorch model object.
- name (str) – The name of the model.
- pretrained (bool, default: True) – Whether to use a pretrained version of the model. Defaults to True.
Note
The postprocessing_func is automatically set to torch.softmax(x, dim=1)
to convert the model's raw logits into class probabilities for each pixel.
Source code in muppet/benchmark/models/segmentation.py
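The post-processing described in the note amounts to a softmax over the class dimension of the logit tensor:

```python
import torch

logits = torch.randn(2, 3, 4, 4)       # raw model output: (batch, classes, H, W)
probs = torch.softmax(logits, dim=1)   # per-pixel class probabilities
```

After the softmax, the class scores at each pixel sum to 1, so `probs.argmax(dim=1)` yields a per-pixel class map.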
Functions
fit
Trains or fine-tunes the segmentation model using the provided training data.
This method currently raises a NotImplementedError, as the training loop
logic for PyTorch models needs to be implemented.
Parameters:

- train_loader (DataLoader) – The data loader containing the training samples.
Source code in muppet/benchmark/models/segmentation.py
forward
Performs a forward pass for inference on the segmentation model.
This method calls the PyTorch model directly and extracts the 'out' key from the output dictionary, which is a common pattern for segmentation models from torchvision or similar libraries.
Parameters:

- x (Tensor) – The input tensor, typically an image.

Returns:

- torch.Tensor – The output tensor from the model, representing the segmentation map.
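The 'out'-key extraction can be sketched as follows; `seg_forward` and the stubbed model are illustrative stand-ins, not the wrapper's actual code:

```python
import torch

def seg_forward(model, x):
    out = model(x)
    # torchvision-style segmentation models return a dict of tensors,
    # with the main segmentation logits under the "out" key
    return out["out"] if isinstance(out, dict) else out

# stub standing in for e.g. a torchvision FCN/DeepLab model (21 classes)
stub = lambda x: {"out": torch.randn(x.shape[0], 21, x.shape[2], x.shape[3])}
seg_map = seg_forward(stub, torch.zeros(1, 3, 4, 4))
```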