Source code for MEDfl.rw.model

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def _leaky_relu():
    """Build a LeakyReLU module with a negative slope of 0.01."""
    return nn.LeakyReLU(0.01)


# Registry of supported activations: name -> zero-argument factory.
# Every value is called with no arguments to obtain a fresh module, so
# each layer gets its own activation instance. Key order is kept stable
# because the key list is shown in the "unsupported activation" error.
_ACTS = {
    "relu": nn.ReLU,
    "gelu": nn.GELU,
    "leaky_relu": _leaky_relu,
    "tanh": nn.Tanh,
    "elu": nn.ELU,
    "silu": nn.SiLU,
    "identity": nn.Identity,
}

def _make_activation(name):
    """Return a new activation module for the registry key ``name``.

    Raises:
        ValueError: if ``name`` is not a key of ``_ACTS``; the message
            lists the supported names.
    """
    try:
        factory = _ACTS[name]
    except KeyError:
        raise ValueError(
            f"Unsupported activation '{name}'. Choose from {list(_ACTS)}"
        ) from None
    # Call outside the try so errors raised by the factory are not masked.
    return factory()

class Net(nn.Module):
    """Flexible MLP with a task-aware output layer.

    Each hidden block is Linear -> (BatchNorm1d) -> activation -> (Dropout),
    followed by a final Linear layer whose width depends on ``task``.

    Args:
        input_dim: number of input features of the first linear layer.
        hidden_dims: widths of the hidden layers, e.g. ``[128, 64, 32]``.
            ``None`` (the default) means ``[64, 32]``.
        activation: activation name passed to ``_make_activation``
            ('relu' | 'gelu' | 'leaky_relu' | ...).
        batchnorm: apply BatchNorm1d after each hidden linear layer
            (never after the output layer).
        dropout: a single number applied to every hidden layer, a
            list/tuple with one value per hidden layer, or ``None`` for
            no dropout. Values that are zero/falsy disable the layer.
        task: 'binary' | 'multiclass' | 'multilabel' | 'regression'.
        num_classes: required for 'multiclass' (>= 2) and 'multilabel'
            (number of labels); ignored for 'binary' and 'regression',
            which both use a single output unit.
        output_bias: whether the output linear layer has a bias.
        return_logits: when True (recommended for training) ``forward``
            returns raw outputs; when False it applies the task's output
            activation. Use ``predict`` for hard predictions.
        weight_init: 'kaiming' or 'xavier' to initialise linear weights,
            'none' to keep PyTorch's defaults. Any other value leaves the
            weights untouched but still re-draws the biases.

    Raises:
        ValueError: on an unknown task, a missing/invalid ``num_classes``,
            an unsupported ``dropout`` type, or a per-layer dropout
            sequence whose length differs from ``hidden_dims``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dims=None,
        activation="relu",
        batchnorm=True,
        dropout=0.0,
        task="binary",
        num_classes=None,
        output_bias=True,
        return_logits=True,
        weight_init="kaiming",
    ):
        super().__init__()

        # A None sentinel replaces the former mutable list default so no
        # list object is shared between instances.
        hidden_dims = [64, 32] if hidden_dims is None else list(hidden_dims)

        self.task = task
        self.num_classes = num_classes
        self.return_logits = return_logits
        self.act = activation

        # Width of the output layer is dictated by the task.
        if task == "multiclass":
            if not num_classes or num_classes < 2:
                raise ValueError("For 'multiclass', num_classes >= 2 is required.")
            output_dim = num_classes
        elif task == "binary":
            output_dim = 1
        elif task == "multilabel":
            if not num_classes or num_classes < 1:
                raise ValueError("For 'multilabel', set num_classes = number of labels.")
            output_dim = num_classes
        elif task == "regression":
            output_dim = 1
        else:
            raise ValueError(f"Unknown task: {task}")

        # Normalize dropout to one probability per hidden layer.
        if dropout is None:
            dropout = [0.0] * len(hidden_dims)
        elif isinstance(dropout, (int, float)):
            dropout = [float(dropout)] * len(hidden_dims)
        elif isinstance(dropout, (list, tuple)):
            if len(dropout) != len(hidden_dims):
                raise ValueError("Length of dropout list must match hidden_dims.")
            dropout = list(dropout)
        else:
            raise ValueError("dropout must be float|list[float]|None")

        self.blocks = nn.ModuleList()
        in_dim = input_dim
        for width, p in zip(hidden_dims, dropout):
            self.blocks.append(
                nn.ModuleDict(
                    {
                        "lin": nn.Linear(in_dim, width, bias=True),
                        "bn": nn.BatchNorm1d(width) if batchnorm else nn.Identity(),
                        "act": _make_activation(activation),
                        "drop": nn.Dropout(p) if p and p > 0 else nn.Identity(),
                    }
                )
            )
            in_dim = width

        self.out = nn.Linear(in_dim, output_dim, bias=output_bias)

        # Weight init: visits every submodule, including the output layer.
        if weight_init != "none":
            self.apply(
                lambda m: self._init_weights(m, scheme=weight_init, activation=activation)
            )

    @staticmethod
    def _init_weights(m, scheme, activation):
        """Initialise one module in place; only ``nn.Linear`` is affected.

        ``scheme`` selects Kaiming-uniform or Xavier-uniform for the
        weight; any other value leaves the weight as is. The bias, when
        present, is drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        if not isinstance(m, nn.Linear):
            return

        if scheme == "kaiming":
            is_leaky = activation == "leaky_relu"
            # NOTE(review): for 'leaky_relu' the slope handed to Kaiming init
            # is sqrt(5), not the 0.01 slope the activation itself uses.
            # Presumably mirrors PyTorch's default Linear init - confirm
            # that this is intended before changing it.
            nn.init.kaiming_uniform_(
                m.weight,
                a=math.sqrt(5) if is_leaky else 0,
                nonlinearity="leaky_relu" if is_leaky else "relu",
            )
        elif scheme == "xavier":
            nn.init.xavier_uniform_(m.weight)

        if m.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x):
        """Run the MLP on ``x``.

        Returns raw outputs when ``return_logits`` is True, otherwise the
        task's output activation applied to them.
        """
        for blk in self.blocks:
            x = blk["lin"](x)
            x = blk["bn"](x)
            x = blk["act"](x)
            x = blk["drop"](x)
        logits = self.out(x)
        if self.return_logits:
            return logits
        return self._apply_output_activation(logits)

    # Inference helpers
    def _apply_output_activation(self, logits):
        """Map raw outputs to the task's natural output space.

        Sigmoid for 'binary' and 'multilabel', softmax over the last
        dimension for 'multiclass', identity for 'regression'.

        Raises:
            RuntimeError: if ``self.task`` is not one of the known tasks.
        """
        if self.task in ("binary", "multilabel"):
            return torch.sigmoid(logits)
        if self.task == "multiclass":
            return F.softmax(logits, dim=-1)
        if self.task == "regression":
            return logits  # raw regression output
        raise RuntimeError("Invalid task")

    @torch.no_grad()
    def predict(self, x):
        """Return hard predictions for ``x`` without tracking gradients.

        The forward pass runs in evaluation mode so that Dropout is
        disabled and BatchNorm uses (and does not update) its running
        statistics; this also lets single-sample batches through
        BatchNorm. The previous train/eval mode of every submodule is
        restored afterwards, even if the forward pass raises.

        Returns:
            'binary' / 'multilabel': integer 0/1 tensor obtained by
            thresholding the sigmoid output at 0.5.
            'multiclass': index of the largest probability along the
            last dimension.
            'regression': the raw model output.

        Raises:
            RuntimeError: if ``self.task`` is not one of the known tasks.
        """
        # Remember each submodule's mode individually so a partially
        # frozen model is restored exactly as it was.
        saved_modes = [(mod, mod.training) for mod in self.modules()]
        self.eval()
        try:
            out = self.forward(x)
        finally:
            for mod, was_training in saved_modes:
                mod.training = was_training

        probs = self._apply_output_activation(out) if self.return_logits else out

        if self.task in ("binary", "multilabel"):
            return (probs >= 0.5).long()
        if self.task == "multiclass":
            return probs.argmax(dim=-1)
        if self.task == "regression":
            return probs
        raise RuntimeError("Invalid task")