Source code for Heimdall.trainer

"""Heimdall trainer."""

import random
from collections import OrderedDict, defaultdict
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
from typing import Callable

import numpy as np
import pandas as pd
import psutil
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from omegaconf import OmegaConf
from torchmetrics.classification import (
    Accuracy,
    ConfusionMatrix,
    F1Score,
    MatthewsCorrCoef,
    Precision,
    Recall,
)
from torchmetrics.regression import MeanSquaredError, R2Score
from tqdm import tqdm
from transformers import get_scheduler

import Heimdall.datasets
import Heimdall.losses
import wandb
from Heimdall.cell_representations import setup_data
from Heimdall.models import TransformerOutput, setup_model
from Heimdall.utils import (  # get_cached_paths,
    INPUT_KEYS,
    get_fully_qualified_cache_paths,
    instantiate_from_config,
    project2simplex_,
    save_umap,
)

# from Heimdall.cell_representations import PartitionedCellRepresentation


[docs] class HeimdallTrainer: CHECKPOINT_KEYS = ("tasks", "fg", "fe", "fc", "model", "dataset.preprocess_args.data_path") def __init__( self, cfg, model, data, accelerator: Accelerator, batchsize: int, epochs: int, random_seed: int = 0, early_stopping: bool = False, early_stopping_patience: int = 5, accumulate_grad_batches: int = 1, grad_norm_clip: float = 1.0, skip_umaps: bool = True, fastdev: bool = False, # if set to true, then only train/evel/test on the first batch run_wandb=False, custom_loss_func=None, custom_metrics=None, ): self.cfg = cfg self.model = model self.data = data self.accelerator = accelerator self.has_embeddings = self.model_cfg.name != "logistic_regression" self.check_flash_attn() # TODO: since we use the label_key in the CellRepresentation setup, we shouldn't need it here. # It should all be accessible in the data.labels... Delete the block below if possible...? # Unified label key handling: support .obs or .obsm self.setup_class_names_and_num_labels(data) self.batchsize = batchsize self.epochs = epochs self.random_seed = random_seed self.early_stopping = early_stopping self.early_stopping_patience = early_stopping_patience self.accumulate_grad_batches = accumulate_grad_batches self.grad_norm_clip = grad_norm_clip self.skip_umaps = skip_umaps self.fastdev = fastdev self.run_wandb = run_wandb self.process = psutil.Process() self.custom_loss_func = custom_loss_func self.custom_metrics = custom_metrics or {} set_seed(cfg.seed) self.optimizer = self._initialize_optimizer() self.loss_functions = self.instantiate_loss_functions_from_config() self.accelerator.wait_for_everyone() self.print_r0(f"> Using Device: {self.accelerator.device}") self.print_r0(f"> Number of Devices: {self.accelerator.num_processes}") self.best_val_outputs = defaultdict(dict) self.best_test_outputs = defaultdict(dict) self.best_epoch = defaultdict(int) self._initialize_wandb() self._initialize_lr_scheduler() self.step = 0 ( self.model, self.optimizer, self.lr_scheduler, ) = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler, ) self.print_r0("> Finished Wrapping the model, optimizer, and dataloaders in accelerate") self.print_r0("> run HeimdallTrainer.train() to begin training") @property def local_cfg(self): return self.data.local_cfg @property def model_cfg(self): return self.cfg.scfm.model @property def dataset_cfg(self): return self.local_cfg.dataset @property def trainer_cfg(self): return self.local_cfg.trainer @property def optimizer_cfg(self): return self.local_cfg.optimizer @property def scheduler_cfg(self): return self.local_cfg.scheduler @property def fg_cfg(self): return self.data.fg_cfg @property def fe_cfg(self): return self.data.fe_cfg @property def fc_cfg(self): return self.data.fc_cfg
[docs] def check_flash_attn(self): if ( hasattr(self.model.encoder.cell_sentence_model, "use_flash_attn") and self.model.encoder.cell_sentence_model.use_flash_attn and self.accelerator.mixed_precision != "bf16" ): raise ValueError("If using Flash Attention, mixed precision must be bf16")
[docs] def setup_class_names_and_num_labels(self, data): self.class_names = {} for subtask_name, subtask in data.tasklist: label_key = subtask.label_col_name label_obsm_key = subtask.label_obsm_name if subtask.task_type in ("multiclass", "binary"): if label_key is not None: # Single-label classification using .obs[label_key] if not pd.api.types.is_categorical_dtype(data.adata.obs[label_key]): data.adata.obs[label_key] = data.adata.obs[label_key].astype("category") self.class_names[subtask_name] = data.adata.obs[label_key].cat.categories.tolist() elif label_obsm_key is not None: self.class_names[subtask_name] = data.adata.obsm[label_obsm_key].columns.tolist() else: self.class_names[subtask_name] = data.adata.uns["task_order"] # NOTE: first entry might be NULL self.num_labels = {} for subtask_name, subtask in data.tasklist: label_key = subtask.label_col_name label_obsm_key = subtask.label_obsm_name if subtask.task_type in ("multiclass", "binary") and (label_key or label_obsm_key): self.num_labels[subtask_name] = len(self.class_names[subtask_name]) else: self.num_labels[subtask_name] = subtask.num_tasks
@property def data(self): return self._data @data.setter def data(self, data): self._data = data for split in ["train", "val", "test", "full"]: setattr(self, f"dataloader_{split}", data.dataloaders[split]) @property def save_precomputed(self): return self.data._save_precomputed @save_precomputed.setter def save_precomputed(self, val): self.data._save_precomputed = val @property def get_precomputed(self): return self.data._get_precomputed @get_precomputed.setter def get_precomputed(self, val): self.data._get_precomputed = val
[docs] def print_r0(self, message): self.data.print_r0(message)
def _initialize_optimizer(self): optimizer_class = getattr(torch.optim, self.optimizer_cfg.name) return optimizer_class( self.model.parameters(), **OmegaConf.to_container(self.optimizer_cfg.args), ) def _initialize_wandb(self, **wandb_kwargs): if self.run_wandb and self.accelerator.is_main_process: print("==> Starting a new WANDB run") new_tags = ( self.dataset_cfg.dataset_name, self.fg_cfg.type, self.fe_cfg.type, self.fc_cfg.type, ) wandb_config = { "wandb": { "tags": new_tags, "name": self.cfg.run_name, "entity": self.cfg.entity, **wandb_kwargs, }, } self.accelerator.init_trackers( project_name=self.cfg.project_name, config=OmegaConf.to_container(self.cfg, resolve=True), init_kwargs=wandb_config, ) print("==> Initialized Run") def _initialize_lr_scheduler(self): global_batch_size = self.batchsize total_steps = len(self.dataloader_train.dataset) // global_batch_size * self.epochs warmup_ratio = self.scheduler_cfg.warmup_ratio warmup_step = int(warmup_ratio * total_steps) self.total_training_steps = total_steps self.warmup_steps = warmup_step self.lr_scheduler = get_scheduler( name=self.scheduler_cfg.name, optimizer=self.optimizer, num_warmup_steps=warmup_step, num_training_steps=total_steps, ) self.print_r0("!!! Remember that config batchsize here is GLOBAL Batchsize !!!") self.print_r0(f"> global batchsize: {global_batch_size}") self.print_r0(f"> total_samples: {len(self.dataloader_train.dataset)}") self.print_r0(f"> Warm Up Steps: {warmup_step}") self.print_r0(f"> Total Steps: {total_steps}") self.print_r0(f"> per_device_batch_size: {global_batch_size // self.accelerator.num_processes}") def _initialize_metrics(self): """Initializing the metrics based on the hydra config.""" metrics = defaultdict(dict) for subtask_name, subtask in self.data.tasklist: subtask_metrics = metrics[subtask_name] task_type = subtask.task_type # First, add custom metrics if provided, TODO this is not implemented yet assert self.custom_metrics == {}, "Custom metrics not implemented yet" subtask_metrics.update(self.custom_metrics) # Then, add built-in metrics if not overridden by custom metrics if task_type in ("mlm", "multiclass"): num_classes = self.num_labels[subtask_name] for metric_name in subtask.metrics: if metric_name not in metrics: if metric_name == "Accuracy": subtask_metrics[metric_name] = Accuracy(task="multiclass", num_classes=num_classes) if subtask.top_k is not None: for k in subtask.top_k: subtask_metrics[f"{metric_name}_top_{k}"] = Accuracy( task="multiclass", num_classes=num_classes, top_k=k, ) elif metric_name == "Precision": subtask_metrics[metric_name] = Precision( task="multiclass", num_classes=num_classes, average="macro", ) elif metric_name == "Recall": subtask_metrics[metric_name] = Recall( task="multiclass", num_classes=num_classes, average="macro", ) elif metric_name == "F1Score": subtask_metrics[metric_name] = F1Score( task="multiclass", num_classes=num_classes, average="macro", ) elif metric_name == "MatthewsCorrCoef": subtask_metrics[metric_name] = MatthewsCorrCoef(task="multiclass", num_classes=num_classes) elif metric_name == "ConfusionMatrix" and task_type != "mlm": subtask_metrics[metric_name] = ConfusionMatrix(task="multiclass", num_classes=num_classes) elif task_type == "regression": for metric_name in subtask.metrics: if metric_name not in metrics: if metric_name == "R2Score": subtask_metrics[metric_name] = R2Score() elif metric_name == "MSE": subtask_metrics[metric_name] = MeanSquaredError() elif task_type == "binary": # num_labels = self.num_labels num_labels = 2 for metric_name in subtask.metrics: if metric_name not in metrics: if metric_name == "Accuracy": subtask_metrics[metric_name] = Accuracy(task="binary", num_labels=num_labels) elif metric_name == "Precision": subtask_metrics[metric_name] = Precision( task="binary", num_labels=num_labels, average="macro", ) elif metric_name == "Recall": subtask_metrics[metric_name] = Recall(task="binary", num_labels=num_labels, average="macro") elif metric_name == "F1Score": subtask_metrics[metric_name] = F1Score( task="binary", num_labels=num_labels, average="macro", ) elif metric_name == "MatthewsCorrCoef": subtask_metrics[metric_name] = MatthewsCorrCoef(task="binary", num_labels=num_labels) metrics[subtask_name] = { k: v.to(self.accelerator.device) if hasattr(v, "to") else v for k, v in subtask_metrics.items() } return metrics
[docs] def fit(self, get_precomputed=False, **fit_kwargs): context = nullcontext() if get_precomputed: context = PrecomputationContext( self, save_precomputed=False, get_precomputed=get_precomputed, run_wandb=True, ) with context: return self.fit_model(**fit_kwargs)
[docs] def fit_model( self, resume_from_checkpoint=True, checkpoint_every_n_epochs=1, precompute_last_epoch=False, do_cleanup=True, start_epoch=0, ): """Train the model with automatic checkpointing and resumption.""" # Try to resume from checkpoint if requested if resume_from_checkpoint: start_epoch = self.load_checkpoint() if start_epoch >= self.epochs: # last_epoch = max(0, start_epoch - 1) # Run one eval pass on the loaded weights to get embeddings # _, val_outputs = self.validate_model(self.dataloader_val, "valid") # _, test_outputs = self.validate_model(self.dataloader_test, "test") if self.accelerator.is_main_process and self.model_cfg.name != "logistic_regression": # self.save_adata_umap(test_embed, val_embed) # self.print_r0(f"> Saved UMAP from checkpoint epoch {last_epoch}") pass return # If the tracked parameter is specified best_metric = defaultdict(dict) for subtask_name, subtask in self.data.tasklist: if subtask.track_metric is not None: best_metric[subtask_name] = defaultdict(lambda: float("-inf")) assert ( subtask.track_metric in subtask.metrics ), "The tracking metric is not in the list of metrics, please check your configuration task file" # Initialize early stopping parameters early_stopping = self.early_stopping early_stopping_patience = self.early_stopping_patience patience_counter = defaultdict(int) def fit_epoch(epoch: int): # Validation and test evaluation valid_log, val_outputs = self.validate_model(self.dataloader_val, dataset_type="valid") test_log, test_outputs = self.validate_model(self.dataloader_test, dataset_type="test") # Track the best metric if specified reset_patience_counter = False for subtask_name, subtask in self.data.tasklist: if subtask.track_metric is not None: val_metric = valid_log.get(f"valid_{subtask_name}_{subtask.track_metric}", float("-inf")) if ( val_metric > best_metric[subtask_name][f"best_val_{subtask_name}_{subtask.track_metric}"] ): # Change to >= if you want to debug UMAP for key in val_outputs: self.best_val_outputs[key][subtask_name] = val_outputs[key][subtask_name] self.best_test_outputs[key][subtask_name] = test_outputs[key][subtask_name] self.best_epoch[subtask_name] = epoch best_metric[subtask_name][f"best_val_{subtask_name}_{subtask.track_metric}"] = val_metric self.print_r0(f"New best validation for {subtask_name} {subtask.track_metric}: {val_metric}") best_metric[subtask_name]["reported_epoch"] = epoch # log the epoch for convenience for metric in subtask.metrics: best_metric[subtask_name][f"reported_test_{metric}"] = test_log.get( f"test_{subtask_name}_{metric}", float("-inf"), ) reset_patience_counter = True # Save checkpoint for best model self.save_checkpoint(epoch) self.print_r0(f"> Saved best model checkpoint at epoch {epoch}") else: for key in val_outputs: self.best_val_outputs[key][subtask_name] = val_outputs[key][subtask_name] self.best_test_outputs[key][subtask_name] = test_outputs[key][subtask_name] self.best_epoch[subtask_name] = epoch if reset_patience_counter: patience_counter[subtask_name] = 0 # Reset patience counter since we have a new best else: patience_counter[subtask_name] += 1 if early_stopping: self.print_r0( f"No improvement in validation {subtask.track_metric}. " f"Patience counter: {patience_counter[subtask_name]}/{early_stopping_patience}", ) # Check early stopping condition if early_stopping and patience_counter[subtask_name] >= early_stopping_patience: self.print_r0( f"Early stopping triggered. No improvement in {subtask.track_metric} for " f"{early_stopping_patience} epochs.", ) return True # Train for one epoch self.train_epoch(epoch) # Save checkpoint at regular intervals if requested if (epoch + 1) % checkpoint_every_n_epochs == 0: self.save_checkpoint(epoch) self.print_r0(f"> Saved regular checkpoint at epoch {epoch}") return False for epoch in range(start_epoch, self.epochs): precomputation_condition = precompute_last_epoch and epoch + 1 == self.epochs context = nullcontext() if precomputation_condition: self.print_r0("> Precomputation condition is met.") context = PrecomputationContext(self, save_precomputed=True, get_precomputed=True, run_wandb=True) with context: stop_training = fit_epoch(epoch) if stop_training: break if do_cleanup: if ( self.accelerator.is_main_process and self.has_embeddings and not isinstance(self.data.datasets["full"], Heimdall.datasets.PairedInstanceDataset) ): if ( self.best_test_outputs and self.best_val_outputs and hasattr(self, "results_folder") and not self.skip_umaps ): self.save_umaps() self.print_r0(f"> Saved best UMAP checkpoint at epoch {self.best_epoch}") else: self.print_r0("> Skipped saving UMAP") if self.run_wandb and self.accelerator.is_main_process: for subtask_name, subtask in self.data.tasklist: if subtask.track_metric is not None: # logging the best val score and the tracked test scores self.accelerator.log(best_metric[subtask_name]) self.accelerator.end_training() if self.accelerator.is_main_process: self.print_r0("> Model has finished Training")
[docs] def save_umaps(self): best_test_embed = self.best_test_outputs["embeddings"] save_umap( self.data, best_test_embed, split="test", savepath=self.results_folder, log_umap=self.run_wandb, ) best_val_embed = self.best_val_outputs["embeddings"] save_umap( self.data, best_val_embed, split="val", savepath=self.results_folder, log_umap=self.run_wandb, )
[docs] def instantiate_loss_functions_from_config(self): loss_functions = {} for subtask_name, subtask in self.data.tasklist: loss_kwargs = {} loss_name = subtask.loss_config.type.split(".")[-1] if loss_name.startswith("Flatten"): loss_kwargs["num_labels"] = self.num_labels[subtask_name] loss_kwargs["trainer"] = self loss_functions[subtask_name] = instantiate_from_config(subtask.loss_config, **loss_kwargs) return loss_functions
[docs] def get_precomputed_outputs(self, inputs): outputs = {} for subtask_name, _ in self.data.tasklist: cell_index = inputs["idx"][subtask_name].to(torch.int32).tolist() cls_embeddings = self.data.adata.obsm[f"{subtask_name}_cls_embeddings"][cell_index] cls_embeddings = torch.from_numpy(cls_embeddings).to(device=self.data.accelerator.device) head_output = TransformerOutput( logits=torch.zeros_like(cls_embeddings), sequence_embeddings=torch.zeros_like(cls_embeddings), cls_embeddings=cls_embeddings, ) outputs[subtask_name] = head_output return outputs
[docs] def save_precomputed_outputs(self, inputs, outputs): for subtask_name, _ in self.data.tasklist: head_output = outputs[subtask_name] cls_embeddings = head_output.cls_embeddings.detach().cpu().numpy() if f"{subtask_name}_cls_embeddings" not in self.data.adata.obsm: _, d_model = cls_embeddings.shape self.data.adata.obsm[f"{subtask_name}_cls_embeddings"] = np.zeros( (self.data.adata.n_obs, *cls_embeddings.shape[1:]), ) cell_index = inputs["idx"][subtask_name].to(torch.int32).tolist() self.data.adata.obsm[f"{subtask_name}_cls_embeddings"][cell_index] = cls_embeddings
[docs] def get_outputs_and_loss(self, batch, cumulative_loss=None): for values in batch.values(): for subtask_name, value in values.items(): if value is not None: if isinstance(value, list): value = [subvalue.to(self.accelerator.device) for subvalue in value] else: value = value.to(self.accelerator.device) values[subtask_name] = value inputs = {input_key: batch[input_key] for input_key in INPUT_KEYS if input_key in batch} # inputs = (batch["identity_inputs"], batch["expression_inputs"]) if self.get_precomputed: try: outputs = self.get_precomputed_outputs(inputs) except KeyError: outputs = self.model(inputs=inputs) else: outputs = self.model(inputs=inputs) if self.save_precomputed: self.save_precomputed_outputs(inputs, outputs) labels = batch["labels"] batch_outputs = {subtask_name: {} for subtask_name, _ in self.data.tasklist} for subtask_name, _ in self.data.tasklist: if self.has_embeddings: batch_outputs[subtask_name]["embeddings"] = outputs[subtask_name].cls_embeddings batch_outputs[subtask_name]["labels"] = labels[subtask_name] batch_loss = {} if self.get_precomputed: for subtask_name, _ in self.data.tasklist: batch_outputs[subtask_name]["preds"] = torch.zeros_like(labels[subtask_name]) return batch_outputs, batch_loss preds = {} for subtask_name, subtask in self.data.tasklist: logits = outputs[subtask_name].logits subtask_labels = labels[subtask_name] if subtask.task_type in ("multiclass", "regression"): subtask_preds = logits # logits.argmax(dim=1) elif subtask.task_type == "mlm": subtask_preds = logits.argmax(dim=2) elif subtask.task_type == "binary": # multi-label binary classification → use sigmoid + threshold probs = torch.sigmoid(logits) subtask_preds = (probs > 0.5).float() else: raise ValueError(f"Unsupported task_type: {subtask.task_type}") preds[subtask_name] = subtask_preds if (masks := batch["masks"][subtask_name]) is not None: logits, subtask_labels = logits[masks], subtask_labels[masks] # perform a .clone() so that the subtask_labels are not updated in-place # TODO: weight task-specific loss_functions somehow loss_function = self.loss_functions[subtask_name] batch_loss[subtask_name] = loss_function(logits, subtask_labels.clone()) if cumulative_loss is None: cumulative_loss = batch_loss else: for subtask_name, _ in self.data.tasklist: cumulative_loss[subtask_name] += batch_loss[subtask_name] for subtask_name, _ in self.data.tasklist: batch_outputs[subtask_name]["preds"] = preds[subtask_name] return batch_outputs, cumulative_loss
[docs] def iterate_dataloader( self, dataloader, loss=None, epoch=None, metrics=None, log_every: int = 1, ): """Iterate through `DataLoader` (either for training or for validation).""" training = epoch is not None constrained_params = [p for name, p in self.model.named_parameters() if "metafeature" in name] if training: step = len(dataloader) * epoch outputs = None else: step = 0 outputs = defaultdict(lambda: defaultdict(list)) # Dict of dicts with tqdm(dataloader, disable=not self.accelerator.is_main_process) as pbar: total_loss = 0 for batch in pbar: step += 1 is_logging = step % log_every == 0 lr = self.lr_scheduler.get_last_lr()[0] with self.accelerator.accumulate(self.model) if training else nullcontext(): batch_outputs, loss = self.get_outputs_and_loss(batch, loss) prev_total_loss = total_loss total_loss = sum(loss.values()) if training: self.accelerator.backward(total_loss) if not isinstance(total_loss, int): total_loss = total_loss.item() if self.accelerator.sync_gradients: grad_norm = self.accelerator.clip_grad_norm_( self.model.parameters(), self.grad_norm_clip, ) self.optimizer.step() self.lr_scheduler.step() self.optimizer.zero_grad() self.step += 1 pbar.set_description( f"Epoch: {epoch} " f"Step {self.step} " f"Loss: {total_loss:.4f} " f"LR: {lr:.1e} " f"grad_norm: {grad_norm:.4f} ", ) if is_logging: log = { "train_loss": total_loss, **{ f"train_{subtask_name}_loss": subtask_loss.item() for subtask_name, subtask_loss in loss.items() }, "global_step": self.step, "learning_rate": lr, "epoch": epoch, "grad_norm": grad_norm, } if self.run_wandb and self.accelerator.is_main_process: self.accelerator.log(log) with torch.no_grad(): for param in constrained_params: # TODO: make this more robust to different types of PGD project2simplex_(param, dim=0) loss = None else: batch_loss = total_loss - prev_total_loss pbar.set_description( f"Loss: {batch_loss:.4f} ", ) for subtask_name, _ in self.data.tasklist: for key, value in batch_outputs[subtask_name].items(): outputs[key][subtask_name].extend(value.detach().cpu().numpy()) if metrics is not None and not self.get_precomputed: for subtask_name, subtask in self.data.tasklist: for metric_name, metric in metrics[subtask_name].items(): # noqa: B007 # Built-in metric subtask_labels = batch_outputs[subtask_name]["labels"] subtask_preds = batch_outputs[subtask_name]["preds"] if subtask.task_type in ["multiclass", "mlm"]: subtask_labels = subtask_labels.to(torch.int) # Remove negative MLM values (TODO: fix so UCE doesn't provide these) if subtask.task_type in ["mlm"]: if torch.any(subtask_labels < 0): flattened_labels = subtask_labels.flatten() flattened_preds = subtask_preds.flatten() mask = flattened_labels >= 0 nonnegative_flattened_labels = flattened_labels[mask] nonnegative_flattened_preds = flattened_preds[mask] subtask_labels = nonnegative_flattened_labels.to(torch.int) subtask_preds = nonnegative_flattened_preds # Remove NaN values if subtask.task_type in ["binary"]: # Step 1: Flatten the tensor flattened_labels = subtask_labels.flatten() flattened_preds = subtask_preds.flatten() mask = ~torch.isnan(flattened_labels) no_nans_flattened_labels = flattened_labels[mask] no_nans_flattened_preds = flattened_preds[mask] subtask_labels = no_nans_flattened_labels.to(torch.int) subtask_preds = no_nans_flattened_preds metric.update(subtask_preds, subtask_labels) if self.fastdev: break # if not training: # for subtask_name, _ in self.data.tasklist: # for key in outputs: # # print(f'{key=}') # # print(f'{subtask_name=}') # # print(f'{outputs[key][subtask_name][0].shape=}') # # print(f'{np.unique([out.shape for out in outputs[key][subtask_name]])=}') # outputs[key][subtask_name] = np.array(outputs[key][subtask_name]) return outputs, loss
[docs] def validate_model(self, dataloader, dataset_type): self.model.eval() metrics = self._initialize_metrics() if len(dataloader) == 0: raise ValueError("`DataLoader` length cannot be zero. Check custom sampler implementation.") with torch.no_grad(): outputs, dataloader_loss = self.iterate_dataloader( dataloader, metrics=metrics, ) if dataloader_loss is None: loss = {subtask_name: 0.0 for subtask_name, _ in self.data.tasklist} # 0. is a sentinel value else: loss = { subtask_name: subtask_loss / len(dataloader) for subtask_name, subtask_loss in dataloader_loss.items() } total_loss = sum(loss.values()) if self.accelerator.num_processes > 1: loss_tensor = torch.tensor( [total_loss], device=self.accelerator.device, ) # loss is a python floating point value, for gather # operation across multiple processes needs to be # cuda tensor total_loss = self.accelerator.gather(loss_tensor) valid_processes = torch.nonzero(total_loss) num_valid = torch.count_nonzero(total_loss).item() total_loss = total_loss[valid_processes].sum().item() / num_valid log = {f"{dataset_type}_{subtask_name}_loss": subtask_loss for subtask_name, subtask_loss in loss.items()} log[f"{dataset_type}_loss"] = total_loss if self.save_precomputed: return log, outputs for subtask_name, subtask in self.data.tasklist: for metric_name, metric in metrics[subtask_name].items(): if metric_name != "ConfusionMatrix": # Built-in metric log[f"{dataset_type}_{subtask_name}_{metric_name}"] = metric.compute().item() if metric_name.startswith(("Accuracy", "Precision", "Recall", "F1Score")): log[ f"{dataset_type}_{subtask_name}_{metric_name}" ] *= 100 # Convert to percentage for these metrics if subtask.top_k is not None: if self.run_wandb and self.accelerator.is_main_process: top_k_accuracies = [] for k in subtask.top_k: top_k_accuracies.append(log[f"{dataset_type}_{subtask_name}_Accuracy_top_{k}"]) tbl = wandb.Table(data=list(zip(subtask.top_k, top_k_accuracies)), columns=["k", "topk_acc"]) chart = wandb.plot.line( tbl, "k", "topk_acc", title=f"{dataset_type}_{subtask_name}: Top-k Accuracy", ) self.accelerator.log({f"{dataset_type}_{subtask_name}_topk_acc_curve": chart}) if "ConfusionMatrix" in metrics[subtask_name]: # 1. Gather counts from all processes and sum cm_local = metrics[subtask_name]["ConfusionMatrix"].compute() # (C, C) tensor cm_counts = self.accelerator.reduce(cm_local, reduction="sum") # global counts # 3) If binary and flat, reshape to (2, 2) if cm_counts.dim() == 1: c = int(cm_counts.numel() ** 0.5) # should be 2 cm_counts = cm_counts.view(c, c) # 2. Row-wise normalisation → per-class accuracy matrix cm_norm = cm_counts.float() cm_norm = cm_norm / (cm_norm.sum(dim=1, keepdim=True) + 1e-8) # 3. Per-class accuracy vector (for dashboard scalars) per_class_acc = cm_norm.diag().cpu().numpy() * 100 log[f"{dataset_type}_{subtask_name}_per_class_accuracy"] = { name: float(acc) for name, acc in zip(self.class_names[subtask_name], per_class_acc) } # 4. Log interactive confusion matrix to WandB (main process only) if self.run_wandb and self.accelerator.is_main_process: y_true_np = outputs["labels"][subtask_name] y_pred_np = outputs["preds"][subtask_name] # Convert logits/probs to hard labels if needed if y_pred_np.ndim > 1: # shape (N, C) y_pred_np = y_pred_np.argmax(axis=1) # Flatten & convert to Python lists y_true_list = y_true_np.reshape(-1).tolist() y_pred_list = y_pred_np.reshape(-1).tolist() wandb_cm = wandb.plot.confusion_matrix( y_true=y_true_list, preds=y_pred_list, class_names=self.class_names[subtask_name], # same order as metric ) self.accelerator.log( {f"{dataset_type}_{subtask_name}_confusion_matrix": wandb_cm}, ) rss = self.process.memory_info().rss / (1024**3) log["Process_mem_rss"] = rss if self.run_wandb and self.accelerator.is_main_process: self.accelerator.log(log) if not self.run_wandb and self.accelerator.is_main_process: print(f"{dataset_type}_log = {pformat(log)}") return log, outputs
[docs] def train_epoch(self, epoch): self.model.train() self.iterate_dataloader( self.dataloader_train, epoch=epoch, )
[docs] def get_checkpoint_directory(self, additional_keys: tuple = (), hash_vars: tuple = ()): cache_dir = Path(self.cfg.cache_preprocessed_dataset_dir) keys = self.CHECKPOINT_KEYS + additional_keys checkpoint_directory, _, minimal_cfg = get_fully_qualified_cache_paths( self.cfg, cache_dir / "checkpoints", keys=keys, hash_vars=hash_vars, mkdir=False, ) return checkpoint_directory
[docs] def initialize_checkpointing(self, additional_keys: tuple = (), hash_vars: tuple = ()): """Initialize checkpoint directory.""" if getattr(self.cfg, "work_dir", None) is not None: self.results_folder = Path(self.cfg.work_dir) else: self.results_folder = self.get_checkpoint_directory(additional_keys=additional_keys, hash_vars=hash_vars) if self.accelerator.is_main_process: try: self.results_folder.mkdir(parents=True, exist_ok=False) self.print_r0(f"> Checkpoint directory initialized at {self.results_folder}") except FileExistsError: self.print_r0(f"> Checkpoint directory already exists at {self.results_folder}")
[docs] def get_latest_checkpoint_path(self): self.initialize_checkpointing() milestone_path = self.results_folder / "milestone.txt" if not milestone_path.exists(): return None milestone = milestone_path.read_text().strip() if not milestone: return None checkpoint_path = self.results_folder / f"model-{milestone}.pt" if checkpoint_path.exists(): return checkpoint_path return None
[docs] def save_checkpoint(self, epoch): """Save model checkpoint at the given epoch.""" # Only save on the main process if not self.accelerator.is_main_process: return self.initialize_checkpointing() # Calculate current step based on epoch # step = len(self.dataloader_train) * epoch # Prepare the data to save data = { "epoch": epoch, "step": self.step, "model": self.accelerator.get_state_dict(self.model), "optimizer": self.optimizer.state_dict(), "scaler": self.accelerator.scaler.state_dict() if (self.accelerator.scaler is not None) else None, "lr_scheduler": self.lr_scheduler.state_dict(), "python_rng_state": random.getstate(), "numpy_rng_state": np.random.get_state(), "torch_rng_state": torch.random.get_rng_state(), "cuda_rng_state_all": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None, "version": 1.0, } # Save checkpoint checkpoint_path = self.results_folder / f"model-{epoch}.pt" torch.save(data, str(checkpoint_path)) self.print_r0(f"> Saved checkpoint to {checkpoint_path}") # Overwrite 'milestone.txt' with the new milestone milestone_file = self.results_folder / "milestone.txt" with open(milestone_file, "w") as f: f.write(str(epoch)) self.print_r0(f"> Updated milestone.txt to milestone {epoch}") config_path = self.results_folder / "config.txt" with open(config_path, "w") as f: f.write(OmegaConf.to_yaml(self.cfg))
[docs] def load_checkpoint(self, specific_milestone=None): """Load a checkpoint based on milestone.txt or a specific milestone number.""" # Ensure results folder is initialized self.initialize_checkpointing() # Determine which milestone to load if specific_milestone is not None: milestone = specific_milestone else: milestone_file = self.results_folder / "milestone.txt" if not milestone_file.exists(): self.print_r0("> No milestone.txt found. Starting from scratch.") return 0 # Read the milestone number with open(milestone_file) as f: milestone_str = f.read().strip() if not milestone_str.isdigit(): self.print_r0("milestone.txt is invalid. Starting from scratch.") return 0 milestone = int(milestone_str) # Load the checkpoint load_path = self.results_folder / f"model-{milestone}.pt" if not load_path.exists(): self.print_r0(f"> Checkpoint file {load_path} does not exist. Starting from scratch.") return 0 self.print_r0(f"> Loading checkpoint from {load_path}") # Load the data device = self.accelerator.device data = torch.load(str(load_path), map_location=device, weights_only=False) # Unwrap model and restore parameters model = self.accelerator.unwrap_model(self.model) model.load_state_dict(data["model"]) epoch = self.load_trainer_state(data) if "version" in data: self.print_r0(f"> Checkpoint version: {data['version']}") self.print_r0(f"> Resumed from epoch {epoch + 1}, step {self.step}") return epoch + 1
[docs] def load_trainer_state(self, data): # Restore optimizer and scheduler states # clean_opt_sd = clean_optimizer_state_for_current_model(data["optimizer"], self.optimizer, verbose=True) # self.optimizer.load_state_dict(clean_opt_sd) if (data["scaler"] is not None) and (self.accelerator.scaler is not None): self.accelerator.scaler.load_state_dict(data["scaler"]) self.lr_scheduler.load_state_dict(data["lr_scheduler"]) # Restore random states random.setstate(data["python_rng_state"]) np.random.set_state(data["numpy_rng_state"]) # Handle torch RNG state torch_rng_state = data["torch_rng_state"] if isinstance(torch_rng_state, torch.Tensor) and torch_rng_state.device.type != "cpu": torch_rng_state = torch_rng_state.cpu() torch.random.set_rng_state(torch_rng_state) # Handle CUDA RNG states if data["cuda_rng_state_all"] is not None and torch.cuda.is_available(): num_visible_devices = torch.cuda.device_count() if len(data["cuda_rng_state_all"]) != num_visible_devices: self.print_r0( "Warning: Number of visible CUDA devices does not match the number of saved CUDA RNG states. " "Skipping CUDA RNG state restoration.", ) else: new_cuda_states = [] for state in data["cuda_rng_state_all"]: if isinstance(state, torch.Tensor) and state.device.type != "cpu": state = state.cpu() new_cuda_states.append(state) torch.cuda.set_rng_state_all(new_cuda_states) epoch = data["epoch"] self.step = data["step"] return epoch
[docs] def get_pretrained_load_path(self): config = self.local_cfg load_path = None if "pretrained_milestone" in config and config.pretrained_milestone is not None: self.initialize_checkpointing() load_path = self.results_folder / f"model-{config.pretrained_milestone}.pt" if not load_path.exists(): self.print_r0( f"> Checkpoint file {load_path} does not exist. `{config.pretrained_milestone=}` is invalid.", ) elif "pretrained_ckpt_path" in config and config.pretrained_ckpt_path is not None: load_path = Path(config.pretrained_ckpt_path) if not load_path.exists(): self.print_r0( f"> Checkpoint file {load_path} does not exist. Check the value of " f"`{config.pretrained_ckpt_path=}` for correctness.", ) return load_path
[docs] def load_pretrained(self): load_path = self.get_pretrained_load_path() if load_path is None: return self.print_r0(f"> Loading pretrained model state from {load_path}") # Load the data device = self.accelerator.device data = torch.load(str(load_path), map_location=device, weights_only=False) pretrained_state_dict = data["model"] filtered_pretrained_params = OrderedDict( filter(lambda param_tuple: "decoder" not in param_tuple[0], pretrained_state_dict.items()), ) # we drop the pretrained head and load all other params # Unwrap model and restore parameters model = self.accelerator.unwrap_model(self.model) model.load_state_dict(filtered_pretrained_params, strict=False) # model.load_state_dict(data["model"]) self.load_trainer_state(data) if self.accelerator.is_main_process: print(f"> Finished loading pretrained params loaded from {load_path}")
def setup_trainer_generic(config, setup_model: Callable, cpu=True): accelerator, cr, run_wandb, only_preprocess_data = setup_data(config, cpu=cpu) if only_preprocess_data: return model = setup_model(config, cr, is_main_process=accelerator.is_main_process) trainer = instantiate_from_config( config.scfm.trainer, cfg=config, model=model, data=cr, accelerator=accelerator, run_wandb=run_wandb, ) trainer.load_pretrained() return trainer def setup_trainer(config, cpu=True): return setup_trainer_generic(config, setup_model=setup_model, cpu=cpu) class PrecomputationContext: ATTRIBUTES = ("save_precomputed", "get_precomputed", "run_wandb") def __init__( self, trainer: HeimdallTrainer, save_precomputed: bool, get_precomputed: bool, run_wandb: bool = False, ): self.trainer = trainer self.save_precomputed = save_precomputed self.get_precomputed = get_precomputed self.run_wandb = run_wandb def swap(self): for attribute in self.ATTRIBUTES: context_attr = getattr(self, attribute) trainer_attr = getattr(self.trainer, attribute) setattr(self.trainer, attribute, context_attr) setattr(self, attribute, trainer_attr) def __enter__(self): self.swap() def __exit__(self, exc_type, exc_val, exc_tb): # if self.trainer.save_precomputed and isinstance(trainer, PartitionedCellRepresentation): # self.trainer.data.partition = None self.swap() return False def clean_optimizer_state_for_current_model(saved_opt_sd: dict, optimizer: torch.optim.Optimizer, verbose: bool = True): """Return a cleaned optimizer state_dict compatible with the given `optimizer`. - saved_opt_sd: the checkpoint["optimizer"] you loaded from disk - optimizer: the optimizer instance that was created for the current model (and whose .param_groups define grouping) """ saved_state = deepcopy(saved_opt_sd.get("state", {})) saved_param_group_list = saved_opt_sd.get("param_groups", []) if verbose: print( f"[opt_reload] saved state entries: {len(saved_state)}, saved param_groups: {len(saved_param_group_list)}", ) # Build a list of the current parameters in the order of optimizer.param_groups current_param_groups = optimizer.param_groups # these have 'params' as param objects current_params_flat = [] for g in current_param_groups: for p in g["params"]: current_params_flat.append(p) # Prepare containers for new state and param_groups new_state = {} new_param_groups = [] # We'll mark which saved pids have been used (so we don't reuse them) available_saved_pids = set(saved_state.keys()) # Helper to get a representative tensor from saved-state entry for shape check def representative_tensor_from_state_entry(state_entry): # typical keys: 'exp_avg', 'exp_avg_sq' (or for some optimizers 'momentum_buffer') for v in state_entry.values(): if isinstance(v, torch.Tensor): return v return None matched = 0 dropped = 0 # For each current param group, create a new group copying hyperparams but setting 'params' to ids for group in current_param_groups: # copy group hyperparams except params new_group = {k: deepcopy(v) for k, v in group.items() if k != "params"} new_group_param_ids = [] for param in group["params"]: p_shape = param.shape match_pid = None match_state = None # Find a saved PID with a representative tensor that matches the shape for saved_pid in list(available_saved_pids): state_entry = saved_state[saved_pid] rep = representative_tensor_from_state_entry(state_entry) if rep is None: # if no tensor in state entry, we can't compare shapes; skip continue if tuple(rep.shape) == tuple(p_shape): match_pid = saved_pid match_state = state_entry break if match_pid is not None: # Assign saved state to this current parameter id new_pid = id(param) new_state[new_pid] = match_state available_saved_pids.remove(match_pid) new_group_param_ids.append(new_pid) matched += 1 else: # No matching saved state for this param (likely a newly initialized param) new_group_param_ids.append(id(param)) dropped += 1 new_group["params"] = new_group_param_ids new_param_groups.append(new_group) if verbose: print(f"[opt_reload] matched saved states -> {matched}") print(f"[opt_reload] params without saved state (new) -> {dropped}") print(f"[opt_reload] leftover saved states not matched -> {len(available_saved_pids)}") cleaned_sd = {"state": new_state, "param_groups": new_param_groups} # Copy over (safe) additional keys if present (like 'defaults') from the saved dict, # but param_groups is the critical part that must match the current optimizer. if "defaults" in saved_opt_sd: cleaned_sd["defaults"] = deepcopy(saved_opt_sd["defaults"]) return cleaned_sd