Heimdall.trainer.HeimdallTrainer


class Heimdall.trainer.HeimdallTrainer(cfg, model, data, accelerator, batchsize, epochs, random_seed=0, early_stopping=False, early_stopping_patience=5, accumulate_grad_batches=1, grad_norm_clip=1.0, skip_umaps=True, fastdev=False, run_wandb=False, custom_loss_func=None, custom_metrics=None)

Bases: object
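The signature exposes `early_stopping` and `early_stopping_patience` but does not spell out their semantics. The sketch below shows the conventional meaning of a patience counter: stop once the monitored validation loss has failed to improve for `patience` consecutive epochs. The helper name `early_stop_epoch` and the choice of validation loss as the monitored quantity are assumptions made for illustration; Heimdall's actual criterion may differ.

```python
from typing import Iterable, Optional


def early_stop_epoch(val_losses: Iterable[float], patience: int = 5) -> Optional[int]:
    """Hypothetical helper, not part of Heimdall.

    Returns the 1-based epoch at which a patience-based rule would stop
    training, or None if the rule never triggers.
    """
    best = float("inf")
    stale_epochs = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best value: remember it and reset the counter.
            best = loss
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                return epoch
    return None
```

With `patience=3`, a loss history of `[1.0, 0.9, 0.95, 0.93, 0.92]` would stop at epoch 5, because epochs 3 to 5 never beat the best value of 0.9.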

Methods

check_flash_attn()

fit([get_precomputed])

fit_model([resume_from_checkpoint, ...])

Train the model with automatic checkpointing and resumption.

get_checkpoint_directory([additional_keys, ...])

get_latest_checkpoint_path()

get_outputs_and_loss(batch[, cumulative_loss])

get_precomputed_outputs(inputs)

get_pretrained_load_path()

initialize_checkpointing([additional_keys, ...])

Initialize checkpoint directory.

instantiate_loss_functions_from_config()

iterate_dataloader(dataloader[, loss, ...])

Iterate through a DataLoader (for either training or validation).

load_checkpoint([specific_milestone])

Load a checkpoint based on milestone.txt or a specific milestone number.
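As a rough illustration of this lookup, the sketch below resolves a checkpoint path from an explicit milestone number, falling back to the number stored in `milestone.txt` when none is given. The file layout (a `milestone.txt` holding a single integer, checkpoints named `model-<n>.pt`) and the helper name `resolve_checkpoint_path` are assumptions made for this example, not Heimdall's documented behavior.

```python
from pathlib import Path
from typing import Optional


def resolve_checkpoint_path(
    ckpt_dir: Path, specific_milestone: Optional[int] = None
) -> Optional[Path]:
    """Hypothetical helper, not part of Heimdall.

    Picks a checkpoint file by milestone number; returns None when no
    matching checkpoint can be found.
    """
    if specific_milestone is None:
        # Assumed layout: milestone.txt stores the latest milestone number.
        marker = ckpt_dir / "milestone.txt"
        if not marker.exists():
            return None
        specific_milestone = int(marker.read_text().strip())
    # Assumed naming scheme for checkpoint files.
    candidate = ckpt_dir / f"model-{specific_milestone}.pt"
    return candidate if candidate.exists() else None
```

Returning `None` rather than raising lets a caller fall back to training from scratch when no checkpoint exists yet, which is one plausible way to support automatic resumption.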

load_pretrained()

load_trainer_state(data)

print_r0(message)

save_checkpoint(epoch)

Save model checkpoint at the given epoch.

save_precomputed_outputs(inputs, outputs)

save_umaps()

setup_class_names_and_num_labels(data)

train_epoch(epoch)

validate_model(dataloader, dataset_type)

Parameters:
  • accelerator (Accelerator)

  • batchsize (int)

  • epochs (int)

  • random_seed (int)

  • early_stopping (bool)

  • early_stopping_patience (int)

  • accumulate_grad_batches (int)

  • grad_norm_clip (float)

  • skip_umaps (bool)

  • fastdev (bool)