Benchmarking tokenizers for cross-tissue generalization

Here, we demonstrate how to use Heimdall to benchmark the impact of tokenizer choice on cell-type annotation performance in a challenging cross-tissue setting, where the model is trained on cells from one tissue and evaluated on cells from another. For this evaluation, we use a subset of the scTab dataset.

import hydra
import matplotlib
import scanpy as sc
import seaborn as sns
from matplotlib import pyplot as plt

import Heimdall

# Embed TrueType fonts in PDFs and keep SVG text as editable text
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'

sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme(style="white")

Recap: the Heimdall framework

We want to implement various popular tokenizers for single-cell foundation models (scFMs) within a standardized framework. This lets us isolate the impact of the tokenizer itself on downstream performance.

Fig 1A. Modular conceptualization of scFMs in Heimdall. Gene identities and expression values from each single-cell input are processed by a cell tokenization scheme (tokenizer) that generates a sequence-based “cell sentence”. The tokenizer is decomposed into three modules: a gene identity encoder ($F_\textbf{G}$), an expression encoder ($F_\textbf{E}$), and a cell constructor ($F_\textbf{C}$). The tokenizer output is then passed to a sequence-based model (e.g., a transformer).

Experiment setup (via Hydra)

We use Hydra to configure each Heimdall run. First, we describe the configuration files shared by all tokenizers in this cross-tissue generalization experiment.

Let’s take a closer look at some of the most important configs…

experiments

Top-level experiment config. Its defaults list selects a config for each of the other groups: dataset, task, model, scheduler, trainer, optimizer, and the three tokenizer modules (fg, fe, fc).

defaults:
  - override /dataset: new_sctab  # dataset and preprocessing arguments
  - override /tasks: new_sctab_split  # task, including the cross-tissue splits
  - override /model: transformer  # the sequence model
  - override /scheduler: cosine  # learning-rate scheduler
  - override /trainer: default  # trainer settings
  - override /optimizer: AdamW  # optimizer
  - override /fg: random  # gene identity encoder
  - override /fe: noop  # expression encoder
  - override /fc: geneformer  # cell constructor
  
seed: 55 # random seed for reproducibility

project_name: new_sctab_split1  # project name for WandB
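To benchmark several tokenizers under this shared setup, only the fg, fe, and fc groups need to change between runs. The sketch below uses a hypothetical helper, `tokenizer_overrides`, which is not part of Heimdall. It builds the Hydra override list of the kind passed to `hydra.compose` when the config is put together. In Hydra's override grammar, `group=option` replaces the option already chosen in the defaults list, while `+group=option` adds a group that the defaults list does not contain.

```python
# Hypothetical helper (not part of Heimdall): build Hydra overrides that swap
# only the tokenizer modules while keeping the shared experiment config fixed.

EXPERIMENT = "sctab_split1_all"

# (fg, fe, fc) option names per tokenizer. Only the Geneformer-style triple
# comes from this tutorial; further triples would be added here.
TOKENIZERS = {
    "geneformer": ("random", "noop", "geneformer"),
}


def tokenizer_overrides(fg: str, fe: str, fc: str, experiment: str = EXPERIMENT) -> list:
    """Return Hydra overrides selecting one experiment and one tokenizer."""
    return [
        f"+experiments={experiment}",  # '+' adds the experiment group
        f"fg={fg}",  # plain 'group=option' replaces the selected option
        f"fe={fe}",
        f"fc={fc}",
    ]


for name, (fg, fe, fc) in TOKENIZERS.items():
    print(name, tokenizer_overrides(fg, fe, fc))
```

Each returned list can be passed as the `overrides` argument of `hydra.compose`, so a benchmark loop only iterates over the tokenizer triples.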

dataset

Specifies the path to the dataset, as well as preprocessing arguments.

dataset_name: new_sctab

preprocess_args:
  data_path: ${data_path}/sctab/tissue_splits_spencer/scTab_GItract_train.h5ad
  top_n_genes: false
  normalize: true
  log_1p: true
  scale_data: false
  species: human
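The `normalize` and `log_1p` flags suggest the usual single-cell preprocessing: scale each cell's counts to a common total, then apply log(1 + x). The pure-Python sketch below makes that concrete. It assumes normalization to a fixed target sum of 10,000, as with scanpy's `sc.pp.normalize_total(adata, target_sum=1e4)` followed by `sc.pp.log1p`. The target Heimdall actually uses is not stated here, so treat it as an assumption. With `scale_data: false`, no per-gene z-scoring follows.

```python
import math


def normalize_log1p(counts: list, target_sum: float = 1e4) -> list:
    """Library-size normalize one cell's counts, then apply log(1 + x).

    target_sum=1e4 is an assumed value, not read from Heimdall.
    """
    total = sum(counts)
    if total == 0:
        return [0.0 for _ in counts]  # an empty cell stays all zeros
    scaled = [c / total * target_sum for c in counts]  # normalize: true
    return [math.log1p(x) for x in scaled]  # log_1p: true


cell = [1, 3, 0, 6]  # raw counts for four genes (toy data)
print(normalize_log1p(cell))
```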

task

Specifies the task used to train the model: the training metrics, the dataset splits, the architecture of the task “head” (which predicts the task outputs), and the loss function. Here, the predefined dataset splits stored in the adata.obs["split1"] column reserve gastrointestinal (GI) tract cells exclusively for training, while brain cells are used exclusively for validation and testing.

type: Heimdall.task.SingleInstanceTask

args:
  task_type: multiclass
  label_col_name: cell_type
  metrics: [Accuracy, MatthewsCorrCoef, ConfusionMatrix]
  track_metric: MatthewsCorrCoef
  splits:
    type: predefined
    col: split1
    keys_:
      train: train
      val: val
      test: test
  early_stopping_patience: 5
  early_stopping: true
  shuffle: true
  batchsize: 32
  epochs: 50
  dataset_config:
    type: Heimdall.datasets.SingleInstanceDataset
  head_config:
    type: Heimdall.models.LinearCellPredHead
    args:
  loss_config:
    type: Heimdall.losses.FlattenCrossEntropyLoss

cell_rep_config:
  type: Heimdall.cell_representations.CellRepresentation
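The task tracks `MatthewsCorrCoef` and enables early stopping with `early_stopping_patience: 5`. The sketch below covers both pieces. It computes the multiclass Matthews correlation coefficient (Gorodkin's R_K statistic, the same quantity as scikit-learn's `matthews_corrcoef`) from label lists. It then applies a simple patience rule. That rule assumes "patience" means the number of consecutive epochs without improvement in the tracked metric; Heimdall's trainer may differ in detail. Both function names are hypothetical.

```python
import math
from collections import Counter


def multiclass_mcc(y_true: list, y_pred: list) -> float:
    """Multiclass Matthews correlation coefficient (Gorodkin's R_K)."""
    s = len(y_true)                                   # number of samples
    c = sum(t == p for t, p in zip(y_true, y_pred))   # correctly classified
    t_k = Counter(y_true)                             # true count per class
    p_k = Counter(y_pred)                             # predicted count per class
    classes = set(t_k) | set(p_k)
    cov = c * s - sum(p_k[k] * t_k[k] for k in classes)
    denom = math.sqrt((s**2 - sum(v**2 for v in p_k.values()))
                      * (s**2 - sum(v**2 for v in t_k.values())))
    return cov / denom if denom else 0.0  # undefined case reported as 0


def should_stop(history: list, patience: int = 5) -> bool:
    """Assumed rule: stop once the best score is `patience` epochs old."""
    if len(history) <= patience:
        return False
    best_epoch = max(range(len(history)), key=history.__getitem__)  # first best
    return len(history) - 1 - best_epoch >= patience


print(multiclass_mcc(["T", "B", "B", "NK"], ["T", "B", "NK", "NK"]))  # 0.7
print(should_stop([0.41, 0.55, 0.54, 0.53, 0.52, 0.51, 0.50]))       # True
```

A later score equal to the best one does not count as an improvement here, because `max` returns the first index of the maximum.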

model

Specifies the model architecture used for the scFM.

type: Heimdall.models.Transformer
name: transformer

args:
  d_model: 128
  pos_enc: BERT
  num_encoder_layers: 2
  nhead: 4
  hidden_act: gelu
  hidden_dropout_prob: 0.1
  use_flash_attn: false
  pooling: cls_pooling # or "mean_pooling"
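The `pooling` option determines how the transformer's per-token outputs are reduced to one cell embedding for the prediction head. The minimal sketch below makes two assumptions. First, `cls_pooling` takes the output at a [CLS] token prepended at position 0. Second, `mean_pooling` averages over all token positions; padding masks are ignored for brevity. It also shows that with `d_model: 128` and `nhead: 4`, each attention head works in a 128 / 4 = 32-dimensional subspace.

```python
def cls_pooling(token_outputs: list) -> list:
    """Use the output at position 0, assumed to be a prepended [CLS] token."""
    return list(token_outputs[0])


def mean_pooling(token_outputs: list) -> list:
    """Average the outputs over all token positions (no padding mask here)."""
    n = len(token_outputs)
    return [sum(dim) / n for dim in zip(*token_outputs)]


d_model, nhead = 128, 4
head_dim = d_model // nhead  # per-head dimensionality in multi-head attention

# Toy transformer output: 3 tokens ([CLS] + 2 genes), d_model = 2 for readability
outputs = [[1.0, 0.0], [2.0, 4.0], [3.0, 8.0]]
print(head_dim, cls_pooling(outputs), mean_pooling(outputs))  # 32 [1.0, 0.0] [2.0, 4.0]
```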

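Modular reimplementation of the Geneformer tokenizer

With everything except the tokenizer configured, we now turn to the tokenizer itself. As an example, let’s reimplement the Geneformer tokenizer. Geneformer represents each cell as a sequence of gene identity tokens ranked by expression level, without encoding the expression values directly.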

fg - the gene identity encoder ($F_\textbf{G}$)

Specifies the Fg implementation for this tokenizer, as well as the torch.nn.Module that provides the gene embeddings. Here, we use the random implementation. It assigns each gene in the vocabulary a randomly initialized embedding vector of dimensionality model.args.d_model. Because frozen is false, these embeddings are updated during training.

type: Heimdall.fg.IdentityFg

args:
  embedding_parameters:
    type: Heimdall.embedding.FlexibleTypeEmbedding
    args:
      num_embeddings: vocab_size
      embedding_dim: ${fg.args.d_embedding}
  d_embedding: ${model.args.d_model}
  frozen: false
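Conceptually, the random Fg is a learnable lookup table with one row per gene in the vocabulary. The pure-Python sketch below mimics it. Each gene symbol maps to an integer row index, and each row is drawn from a standard normal distribution, which matches the default initialization of `torch.nn.Embedding`. The gene list and the helpers `build_gene_table` and `encode_genes` are hypothetical. Heimdall's `FlexibleTypeEmbedding` may initialize and index differently.

```python
import random


def build_gene_table(genes: list, d_embedding: int, seed: int = 55) -> tuple:
    """Hypothetical stand-in for a randomly initialized gene embedding table."""
    rng = random.Random(seed)
    vocab = {gene: i for i, gene in enumerate(genes)}  # gene symbol -> row index
    # One N(0, 1) vector per gene; in Heimdall these would be trainable (frozen: false)
    table = [[rng.gauss(0.0, 1.0) for _ in range(d_embedding)] for _ in genes]
    return vocab, table


def encode_genes(cell_genes: list, vocab: dict, table: list) -> list:
    """F_G: look up the identity embedding of each gene in the cell."""
    return [table[vocab[g]] for g in cell_genes]


vocab, table = build_gene_table(["CD3E", "MS4A1", "NKG7", "LYZ"], d_embedding=128)
gene_embs = encode_genes(["NKG7", "CD3E"], vocab, table)
print(len(gene_embs), len(gene_embs[0]))  # 2 128
```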

fe - the gene expression encoder ($F_\textbf{E}$)

Specifies the Fe implementation for this tokenizer, as well as the torch.nn.Module that provides the embeddings of the genes’ expression levels. Here, we use the noop implementation. It outputs a vector of zeros of dimensionality model.args.d_model for each gene in the cell, regardless of the gene’s expression level. This mirrors Geneformer, where expression enters only through the order of the genes rather than through an explicit expression embedding.

type: Heimdall.fe.IdentityFe
name: Heimdall.fe.IdentityFe

args:
  embedding_parameters:
    type: Heimdall.embedding.ZeroBroadcast
    args:
      out_features: ${fe.args.d_embedding}
  d_embedding: ${model.args.d_model}
  drop_zeros: true
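The noop Fe can be sketched the same way. In the sketch below, `ZeroBroadcast` is mimicked by returning one all-zero vector of length `d_embedding` per gene. The sketch also assumes that `drop_zeros: true` removes genes with zero expression before tokenization, which is consistent with Geneformer tokenizing only expressed genes. Both the function name `noop_fe` and this reading of `drop_zeros` are assumptions.

```python
def noop_fe(genes: list, expression: list, d_embedding: int = 128,
            drop_zeros: bool = True) -> tuple:
    """Hypothetical F_E: one zero vector per gene, regardless of expression."""
    if drop_zeros:
        # Assumed behavior of drop_zeros: keep only genes that are expressed
        genes = [g for g, x in zip(genes, expression) if x != 0]
    return genes, [[0.0] * d_embedding for _ in genes]


kept_genes, expr_embs = noop_fe(["CD3E", "MS4A1", "NKG7"], [2.3, 0.0, 5.1])
print(kept_genes, len(expr_embs))  # ['CD3E', 'NKG7'] 2
```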

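fc - the cell constructor ($F_\textbf{C}$)

Specifies the Fc implementation for this tokenizer, which assembles the outputs of Fg and Fe into the cell sentence passed to the model. Following Geneformer, genes are ordered by expression (order_config), and each cell sentence is limited to max_input_length = 2048 tokens. The gene identity and expression embeddings of each gene are combined by summation (reduce_config). Fc has no embedding of its own, so embedding_parameters is a bare torch.nn.Module placeholder.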

type: Heimdall.fc.Fc

args:
  max_input_length: 2048
  embedding_parameters:
    type: torch.nn.Module  # Should throw an error if called
  tailor_config:
    type: Heimdall.tailor.ReorderTailor
  order_config:
    type: Heimdall.order.ExpressionOrder
  reduce_config:
    type: Heimdall.reduce.SumReduce
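The sketch below puts Fg, Fe, and Fc together to build a Geneformer-style cell sentence in three steps:

1. Order genes by decreasing expression, standing in for `ExpressionOrder`.
2. Truncate to `max_input_length`, which is our guess at the role of the tailor step.
3. Sum each gene's identity and expression embeddings, standing in for `SumReduce`.

Since the noop Fe contributes only zeros, every token ends up equal to its gene identity embedding, so expression enters only through the ordering. The function `build_cell_sentence` and its toy inputs are hypothetical and do not reproduce Heimdall's implementation.

```python
def build_cell_sentence(genes, expression, gene_embs, expr_embs, max_input_length=2048):
    """Hypothetical F_C: order by expression, truncate, then sum F_G and F_E outputs."""
    # Order gene positions by decreasing expression (stable sort keeps ties in input order)
    order = sorted(range(len(genes)), key=lambda i: -expression[i])
    order = order[:max_input_length]  # cap the sequence length
    tokens = [genes[i] for i in order]
    # Sum-reduce: token embedding = identity embedding + expression embedding
    embeddings = [[a + b for a, b in zip(gene_embs[i], expr_embs[i])] for i in order]
    return tokens, embeddings


genes = ["CD3E", "NKG7", "LYZ"]
expression = [2.3, 5.1, 0.7]
gene_embs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # toy F_G outputs, d = 2
expr_embs = [[0.0, 0.0]] * 3                        # noop F_E outputs
tokens, embeddings = build_cell_sentence(genes, expression, gene_embs, expr_embs,
                                         max_input_length=2)
print(tokens, embeddings)  # ['NKG7', 'CD3E'] [[0.0, 1.0], [1.0, 0.0]]
```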

Putting it all together

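from omegaconf import OmegaConf

# Compose the shared experiment config with the Geneformer-style tokenizer modules
with hydra.initialize(version_base=None, config_path="../Heimdall/config"):
    config = hydra.compose(
        config_name="config",
        overrides=[
            "+experiments=sctab_split1_all",
            "fg=random",
            "fe=noop",
            "fc=geneformer",
        ],
    )

    # Resolve all ${...} interpolations in place
    OmegaConf.resolve(config)

print(OmegaConf.to_yaml(config))

Output (the fully resolved config):
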
project_name: new_sctab_split1
run_name: Heimdall.fg.IdentityFg_Heimdall.fe.IdentityFe_Heimdall.fc.Fc_Heimdall.models.Transformer_lr0.002_bz32
work_dir: new_sctab_split1_results/Heimdall.fg.IdentityFg_Heimdall.fe.IdentityFe_Heimdall.fc.Fc_new_sctab_lr0.002_bz32_seed55_agTrue
run_wandb: true
float_dtype: float32
seed: 55
data_path: /work/magroup/shared/Heimdall/data
ensembl_dir: /work/magroup/shared/Heimdall/data
cache_preprocessed_dataset_dir: /scratch/heimdall/shared/cache
entity: Heimdall
only_preprocess_data: false
model:
  type: Heimdall.models.Transformer
  name: transformer
  args:
    d_model: 128
    pos_enc: BERT
    num_encoder_layers: 2
    nhead: 4
    hidden_act: gelu
    hidden_dropout_prob: 0.1
    use_flash_attn: false
    pooling: cls_pooling
dataset:
  dataset_name: new_sctab
  preprocess_args:
    data_path: /work/magroup/shared/Heimdall/data/sctab/tissue_splits_spencer/scTab_GItract_train.h5ad
    top_n_genes: false
    normalize: true
    log_1p: true
    scale_data: false
    species: human
tasks:
  type: Heimdall.task.SingleInstanceTask
  args:
    task_type: multiclass
    label_col_name: cell_type
    metrics:
    - Accuracy
    - MatthewsCorrCoef
    - ConfusionMatrix
    track_metric: MatthewsCorrCoef
    splits:
      type: predefined
      col: split1
      keys_:
        train: train
        val: val
        test: test
    early_stopping_patience: 5
    early_stopping: true
    shuffle: true
    batchsize: 32
    epochs: 50
    dataset_config:
      type: Heimdall.datasets.SingleInstanceDataset
    head_config:
      type: Heimdall.models.LinearCellPredHead
      args: null
    loss_config:
      type: Heimdall.losses.FlattenCrossEntropyLoss
  cell_rep_config:
    type: Heimdall.cell_representations.CellRepresentation
scheduler:
  name: cosine
  lr_schedule_type: cosine
  warmup_ratio: 0.1
  num_epochs: 20
trainer:
  type: Heimdall.trainer.HeimdallTrainer
  args:
    random_seed: 55
    accumulate_grad_batches: 1
    grad_norm_clip: 1.0
    fastdev: false
    skip_umaps: false
  cpu: false
optimizer:
  name: AdamW
  args:
    lr: 0.002
    weight_decay: 0.1
    betas:
    - 0.9
    - 0.95
    foreach: false
fc:
  type: Heimdall.fc.Fc
  args:
    max_input_length: 2048
    embedding_parameters:
      type: torch.nn.Module
    tailor_config:
      type: Heimdall.tailor.ReorderTailor
    order_config:
      type: Heimdall.order.ExpressionOrder
    reduce_config:
      type: Heimdall.reduce.SumReduce
fe:
  type: Heimdall.fe.IdentityFe
  name: Heimdall.fe.IdentityFe
  args:
    embedding_parameters:
      type: Heimdall.embedding.ZeroBroadcast
      args:
        out_features: 128
    d_embedding: 128
    drop_zeros: true
fg:
  type: Heimdall.fg.IdentityFg
  args:
    embedding_parameters:
      type: Heimdall.embedding.FlexibleTypeEmbedding
      args:
        num_embeddings: vocab_size
        embedding_dim: 128
    d_embedding: 128
    frozen: false
loss:
  type: Heimdall.losses.FlattenCrossEntropyLoss
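The run_name in this output appears to be assembled from the module types and two hyperparameters. The sketch below reproduces that run_name from a plain dictionary that mirrors part of the resolved config. The format string is inferred from this single example rather than taken from Heimdall's source, so treat it as an assumption. The function name `infer_run_name` is hypothetical.

```python
def infer_run_name(cfg: dict) -> str:
    """Rebuild run_name as it appears in the printed config (inferred pattern)."""
    parts = [cfg["fg"]["type"], cfg["fe"]["type"], cfg["fc"]["type"], cfg["model"]["type"]]
    lr = cfg["optimizer"]["args"]["lr"]
    batch_size = cfg["tasks"]["args"]["batchsize"]
    return "_".join(parts) + f"_lr{lr}_bz{batch_size}"


# Subset of the resolved config, as plain Python values
cfg = {
    "fg": {"type": "Heimdall.fg.IdentityFg"},
    "fe": {"type": "Heimdall.fe.IdentityFe"},
    "fc": {"type": "Heimdall.fc.Fc"},
    "model": {"type": "Heimdall.models.Transformer"},
    "optimizer": {"args": {"lr": 0.002}},
    "tasks": {"args": {"batchsize": 32}},
}
print(infer_run_name(cfg))
```

For a real OmegaConf config, `OmegaConf.to_container(config, resolve=True)` returns an equivalent nested dict that could be passed in instead.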

Training the model

from accelerate import notebook_launcher
from Heimdall.trainer import setup_trainer


def training_loop(config):
    # Build the trainer from the composed config
    trainer = setup_trainer(config, cpu=config.trainer.cpu)
    # Only fit if a trainer was actually created
    if trainer is not None:
        trainer.fit()


# Launch training from the notebook on a single process
notebook_launcher(training_loop, (config,), num_processes=1)
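The resolved scheduler config specifies a cosine schedule with `warmup_ratio: 0.1`. This sketch assumes the common form used by Hugging Face's `get_cosine_schedule_with_warmup`:

- a linear warmup from 0 to the peak learning rate (`optimizer.args.lr = 0.002`) over the first 10% of optimizer steps;
- a half-cosine decay to 0 over the remaining steps.

Heimdall's exact formula and how it counts steps are not shown in this tutorial. The function name and the step count are hypothetical.

```python
import math


def cosine_with_warmup(step: int, total_steps: int, peak_lr: float = 0.002,
                       warmup_ratio: float = 0.1) -> float:
    """Learning rate at `step`: linear warmup, then half-cosine decay to 0 (assumed form)."""
    warmup_steps = int(warmup_ratio * total_steps)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps  # linear warmup from 0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay


total = 1000  # hypothetical number of optimizer steps
for s in (0, 50, 100, 550, 1000):
    print(s, round(cosine_with_warmup(s, total), 6))
```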