# Source code for autoqild.mi_estimators.pytorch_utils

"""Utilities for running the PC-softmax and MINE mutual-information (MI)
estimators, such as loss functions and optimizer configuration."""

import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam, RMSprop, SGD, Adagrad, Adamax, AdamW, Adadelta

# Maps the user-facing optimizer name to its torch.optim class.
optimizers = {
    "RMSprop": RMSprop,
    "sgd": SGD,
    "adam": Adam,
    "AdamW": AdamW,
    "Adagrad": Adagrad,
    "Adamax": Adamax,
    "Adadelta": Adadelta,
}
# Default constructor keyword arguments per optimizer name. These act as
# read-only templates: get_optimizer_and_parameters returns a *copy* with
# "lr" and "weight_decay" overridden.
optimizer_parameters = {
    "RMSprop": {
        "lr": 0.01,
        "alpha": 0.99,
        "eps": 1e-08,
        "weight_decay": 0,
        "momentum": 0,
        "centered": False,
    },
    "sgd": {"lr": 0.001, "momentum": 0.7, "weight_decay": 0},
    "adam": {"lr": 1e-4, "betas": (0.5, 0.999), "weight_decay": 0, "amsgrad": False},
    "AdamW": {
        "lr": 1e-4,
        "betas": (0.5, 0.999),
        "eps": 1e-08,
        "weight_decay": 0.01,
        "amsgrad": False,
    },
    "Adagrad": {
        "lr": 0.01,
        "lr_decay": 0,
        "weight_decay": 0,
        "initial_accumulator_value": 0,
        "eps": 1e-10,
    },
    "Adamax": {"lr": 0.002, "betas": (0.9, 0.999), "eps": 1e-08, "weight_decay": 0},
    "Adadelta": {"lr": 1.0, "rho": 0.9, "eps": 1e-06, "weight_decay": 0},
}


def get_optimizer_and_parameters(optimizer_str, learning_rate, reg_strength):
    """Get the optimizer class and a fresh configuration dict for it.

    Parameters
    ----------
    optimizer_str : {"RMSprop", "sgd", "adam", "AdamW", "Adagrad", "Adamax", "Adadelta"}
        Optimizer type to use for training the neural network. Must be one of:

        - "RMSprop": Root Mean Square Propagation, an adaptive learning rate method.
        - "sgd": Stochastic Gradient Descent (with momentum 0.7 by default here).
        - "adam": Adaptive Moment Estimation, combining momentum and RMSProp.
        - "AdamW": Adam with decoupled weight decay.
        - "Adagrad": Adaptive Gradient Algorithm, per-parameter learning rates.
        - "Adamax": Variant of Adam based on the infinity norm.
        - "Adadelta": An extension of Adagrad that reduces its aggressive
          learning rate decay.
    learning_rate : float
        The learning rate for the optimizer (stored under ``"lr"``).
    reg_strength : float
        The regularization strength (stored under ``"weight_decay"``).

    Returns
    -------
    optimizer : type
        The optimizer class (a ``torch.optim.Optimizer`` subclass), not an
        instance.
    optimizer_config : dict
        A new dict containing the defaults from ``optimizer_parameters`` with
        ``"lr"`` and ``"weight_decay"`` overridden. It is a copy, so mutating
        it does not affect the module-level defaults or configs returned by
        other calls.

    Raises
    ------
    ValueError
        If ``optimizer_str`` is not one of the recognized optimizer names.
    """
    try:
        optimizer = optimizers[optimizer_str]
        defaults = optimizer_parameters[optimizer_str]
    except KeyError:
        valid = ", ".join(optimizers)
        raise ValueError(
            f"unrecognized optimizer {optimizer_str!r}; expected one of: {valid}"
        ) from None
    # Copy so the returned config never aliases the shared module-level dict.
    optimizer_config = dict(defaults)
    optimizer_config["lr"] = learning_rate
    optimizer_config["weight_decay"] = reg_strength
    return optimizer, optimizer_config


def init(m):
    """Initialize the weights and bias of a linear layer in place.

    Suitable as the callback for ``torch.nn.Module.apply``.

    Parameters
    ----------
    m : torch.nn.Module
        The module to initialize. Only modules whose type is exactly
        ``nn.Linear`` are modified; every other module (including subclasses
        of ``nn.Linear``) is left untouched.

    Notes
    -----
    The weight matrix receives orthogonal initialization and the bias, when
    the layer has one, is set to zero. ``nn.Linear(..., bias=False)`` still
    has a ``bias`` attribute (it is ``None``), so the check must be against
    ``None`` rather than ``hasattr``.
    """
    # Exact type match (not isinstance) keeps the original behavior of
    # skipping Linear subclasses.
    if type(m) is nn.Linear:
        nn.init.orthogonal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def log_mean_exp(inputs, dim=None, keepdim=False):
    """Compute ``log(mean(exp(inputs)))`` in a numerically stable way.

    Parameters
    ----------
    inputs : torch.Tensor
        Input tensor. Non-floating-point tensors are promoted to the default
        floating dtype.
    dim : int or tuple of ints, optional
        The dimension or dimensions to reduce. If None, the tensor is
        flattened and all elements are reduced.
    keepdim : bool, optional
        Whether the reduced dimension(s) are retained with size 1. When
        ``dim`` is None the input is flattened first, so ``keepdim=True``
        yields a tensor of shape ``(1,)``.

    Returns
    -------
    outputs : torch.Tensor
        The logarithm of the mean of the exponentials of ``inputs`` over the
        reduced dimension(s).

    Raises
    ------
    ValueError
        If the reduced dimension(s) contain no elements (the mean is
        undefined).

    Notes
    -----
    Evaluated as ``logsumexp(inputs) - log(N)`` where ``N`` is the number of
    reduced elements. ``logsumexp`` handles large magnitudes and all ``-inf``
    slices (giving ``-inf`` instead of ``nan``).
    """
    if dim is None:
        # reshape (unlike view) also works for non-contiguous tensors.
        inputs = inputs.reshape(-1)
        dim = 0
    dims = tuple(dim) if isinstance(dim, (tuple, list)) else (dim,)
    count = math.prod(inputs.shape[d] for d in dims)
    if count == 0:
        raise ValueError("log_mean_exp is undefined over an empty reduction")
    if not torch.is_floating_point(inputs):
        inputs = inputs.to(torch.get_default_dtype())
    return torch.logsumexp(inputs, dim=dims, keepdim=keepdim) - math.log(count)


def get_mine_loss(preds_xy, preds_xy_tilde, metric):
    """Compute the MINE lower bound on mutual information, expressed in bits.

    Each metric yields a "joint" term (from samples of the joint
    distribution) minus a "marginal" term (from samples of the product of
    marginals). The difference is computed in nats and then multiplied by
    ``log2(e)`` to convert it to bits.

    Parameters
    ----------
    preds_xy : torch.Tensor
        Statistics-network outputs for joint-distribution samples.
    preds_xy_tilde : torch.Tensor
        Statistics-network outputs for product-of-marginals samples.
    metric : {"donsker_varadhan", "donsker_varadhan_softplus", "fdivergence"}
        Which bound to evaluate:

        - "donsker_varadhan": ``mean(T_xy) - log(mean(exp(T_xy_tilde)))``.
        - "donsker_varadhan_softplus": the same shape of bound with the
          network outputs passed through ``softplus`` and logged (a small
          constant guards the logarithms).
        - "fdivergence": ``mean(T_xy) - mean(exp(T_xy_tilde - 1))``.

        All means are taken over dimension 0.

    Returns
    -------
    loss : torch.Tensor
        The bound for the chosen metric, in bits.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the recognized names.
    """
    eps = 1e-8
    if metric == "donsker_varadhan":
        joint_term = preds_xy.mean(dim=0)
        marginal_term = log_mean_exp(preds_xy_tilde, dim=0)
    elif metric == "donsker_varadhan_softplus":
        joint_term = torch.log(F.softplus(preds_xy) + eps).mean(dim=0)
        marginal_term = torch.log(F.softplus(preds_xy_tilde).mean(dim=0) + eps)
    elif metric == "fdivergence":
        joint_term = preds_xy.mean(dim=0)
        marginal_term = torch.exp(preds_xy_tilde - 1).mean(dim=0)
    else:
        raise ValueError(f"unrecognized metric {metric}")
    # log2(e) rescales a natural-log quantity to base 2.
    nats_to_bits = torch.log2(torch.exp(torch.tensor(1.0)))
    return (joint_term - marginal_term) * nats_to_bits


def own_softmax(x, label_proportions, device):
    """Softmax variant whose denominator is weighted by class proportions.

    For every sample ``i`` and class ``c`` this returns::

        out[i, c] = exp(x[i, c]) / sum_k(label_proportions[k] * exp(x[i, k]))

    Only the denominator is weighted, so rows of the output generally do NOT
    sum to one; instead ``sum_c label_proportions[c] * out[i, c] == 1``.
    (This appears to be the probability-corrected form used by the PC-softmax
    estimator -- TODO confirm against the estimator code.)

    Parameters
    ----------
    x : torch.Tensor
        Logits of shape ``(n_samples, n_classes)``.
    label_proportions : list, numpy.ndarray, or torch.Tensor
        Proportion of each class, shape ``(n_classes,)``. Converted with
        ``torch.as_tensor`` and placed on ``device`` (tensors are moved there
        too).
    device : torch.device or str
        The device holding ``x`` (e.g. ``"cpu"`` or ``"cuda"``).

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(n_samples, n_classes)`` as defined above.

    Notes
    -----
    Logits are shifted by their row-wise maximum before exponentiating. The
    ratio above is invariant to that shift, and it keeps ``exp`` from
    overflowing to ``inf`` (which would otherwise produce ``nan``).
    """
    label_proportions = torch.as_tensor(label_proportions, device=device)
    # Row-wise max shift for numerical stability; detached because the
    # result does not depend on it.
    shift = torch.amax(x, dim=1, keepdim=True).detach()
    x_exp = torch.exp(x - shift)
    weighted_sum = torch.sum(x_exp * label_proportions, dim=1, keepdim=True)
    return x_exp / weighted_sum