Heimdall.trainer.HeimdallTrainer
- class Heimdall.trainer.HeimdallTrainer(cfg, model, data, accelerator, batchsize, epochs, random_seed=0, early_stopping=False, early_stopping_patience=5, accumulate_grad_batches=1, grad_norm_clip=1.0, skip_umaps=True, fastdev=False, run_wandb=False, custom_loss_func=None, custom_metrics=None)
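The `early_stopping` and `early_stopping_patience` parameters suggest the common patience-based rule: stop once a monitored metric has failed to improve for `patience` consecutive epochs. The sketch below illustrates that rule only. The `EarlyStopper` class and its `min_delta` argument are hypothetical, and it assumes the monitored value is a validation loss where lower is better; it is not HeimdallTrainer's actual implementation.

```python
class EarlyStopper:
    """Hypothetical helper: signal a stop after `patience` epochs without improvement.

    Assumes the monitored value is a loss (lower is better).
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement resets the patience counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopper(patience=2)
    for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
        if stopper.step(loss):
            print(f"stopping after epoch {epoch}")  # prints: stopping after epoch 3
            break
```

With `patience=2`, the two non-improving epochs after the best loss of 0.8 trigger the stop, so the later 0.7 is never reached. This is the usual trade-off of a small patience value.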
Bases: object

Attributes
Methods
- fit([get_precomputed])
- fit_model([resume_from_checkpoint, ...]): Train the model with automatic checkpointing and resumption.
- get_checkpoint_directory([additional_keys, ...])
- get_outputs_and_loss(batch[, cumulative_loss])
- get_precomputed_outputs(inputs)
- initialize_checkpointing([additional_keys, ...]): Initialize the checkpoint directory.
- iterate_dataloader(dataloader[, loss, ...]): Iterate through a DataLoader, for either training or validation.
- load_checkpoint([specific_milestone]): Load a checkpoint based on milestone.txt or on a specific milestone number.
- load_trainer_state(data)
- print_r0(message)
- save_checkpoint(epoch): Save a model checkpoint at the given epoch.
- save_precomputed_outputs(inputs, outputs)
- train_epoch(epoch)
- validate_model(dataloader, dataset_type)
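`load_checkpoint` is described as loading a checkpoint "based on milestone.txt or a specific milestone number". The sketch below shows one way such a lookup can work. It is a sketch under stated assumptions, not Heimdall's code. The directory layout (a `milestone.txt` holding the latest saved epoch as an integer, plus one `epoch_<n>.pt` file per checkpoint) and the helper names `resolve_checkpoint` and `record_milestone` are all hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Optional


def resolve_checkpoint(ckpt_dir: Path, specific_milestone: Optional[int] = None) -> Optional[Path]:
    """Return the checkpoint file to load, or None if nothing has been saved yet.

    Hypothetical layout (an assumption, not Heimdall's documented format):
        ckpt_dir/milestone.txt  -> a single integer, the latest saved epoch
        ckpt_dir/epoch_<n>.pt   -> the checkpoint for epoch n
    """
    if specific_milestone is None:
        milestone_file = ckpt_dir / "milestone.txt"
        if not milestone_file.exists():
            return None  # fresh run: nothing to resume from
        specific_milestone = int(milestone_file.read_text().strip())
    ckpt_path = ckpt_dir / f"epoch_{specific_milestone}.pt"
    if not ckpt_path.exists():
        raise FileNotFoundError(f"no checkpoint for milestone {specific_milestone}: {ckpt_path}")
    return ckpt_path


def record_milestone(ckpt_dir: Path, epoch: int) -> None:
    """Hypothetical save step: write the checkpoint file first, then update milestone.txt."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    (ckpt_dir / f"epoch_{epoch}.pt").write_bytes(b"")  # placeholder for real model state
    (ckpt_dir / "milestone.txt").write_text(str(epoch))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = Path(tmp)
        print(resolve_checkpoint(ckpt_dir))  # None: fresh run
        record_milestone(ckpt_dir, 1)
        record_milestone(ckpt_dir, 3)
        print(resolve_checkpoint(ckpt_dir).name)  # epoch_3.pt, the latest per milestone.txt
        print(resolve_checkpoint(ckpt_dir, specific_milestone=1).name)  # epoch_1.pt
```

Writing the checkpoint file before updating `milestone.txt` means an interrupted save never leaves the milestone pointing at a file that does not exist yet.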