Implementing a novel scFM tokenizer

Beyond benchmarking existing tokenizers, Heimdall also enables users to introduce, evaluate and share novel tokenizer designs. Here, we demonstrate several examples of this. As in the first demo, we use a subset of the scTab dataset for evaluation.

import hydra
import Heimdall
from matplotlib import pyplot as plt
import matplotlib
import seaborn as sns
import scanpy as sc
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'

sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme(style="white")  

We are interested in implementing various popular tokenizers for single-cell foundation models within a standardized framework, which enables us to isolate the impact of the tokenizer itself on downstream performance.

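Recap: Heimdall modularizes scFM tokenizers

The key to Heimdall’s modularity is the compartmentalization of the tokenizer into \(F_\textbf{G}\), \(F_\textbf{E}\) and \(F_\textbf{C}\) components. Thus, to implement a novel tokenizer, we only need to implement a new version of one or more of these components and combine it with existing implementations of the others.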

Fig 1C (bottom). Application of Heimdall for systematic design and evaluation of scFMs. The highlighted yellow rectangle indicates the introduction of a novel \(F_\textbf{E}\), thereby constituting a new tokenizer implementation.
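As a rough illustration of this decomposition, the toy sketch below writes each component as a plain function so that each one can be swapped independently. It is not Heimdall's API, and the function names are hypothetical. It also rests on three assumptions. First, \(F_\textbf{E}\) encodes expression values. Second, \(F_\textbf{C}\) chooses which genes enter the token sequence and in what order; here it keeps expressed genes, ranked by decreasing expression. Third, gene and expression encodings are summed into one token per gene.

```python
import numpy as np

# Hypothetical toy components; NOT Heimdall's classes. They only illustrate
# that F_G, F_E and F_C can be replaced independently of one another.

def fg_lookup(gene_names, table):
    """F_G: map each gene name to a gene-identity vector."""
    return np.stack([table[g] for g in gene_names])

def fe_zero(expression, d):
    """F_E (assumed role): a trivial expression encoder that ignores the values."""
    return np.zeros((len(expression), d))

def fc_rank(gene_names, expression):
    """F_C (assumed role): keep expressed genes, ordered by decreasing expression."""
    order = np.argsort(-expression, kind="stable")
    keep = [i for i in order if expression[i] > 0]
    return [gene_names[i] for i in keep], expression[keep]

def tokenize(gene_names, expression, fg, fe, fc, d):
    """Compose the three components; summing F_G and F_E outputs is an assumption."""
    genes, expr = fc(gene_names, expression)
    return fg(genes) + fe(expr, d)  # one d-dimensional token per selected gene

# Usage: swapping any single component leaves the other two untouched.
d = 4
rng = np.random.default_rng(0)
genes = ["CD3E", "MS4A1", "LYZ"]
table = {g: rng.normal(size=d) for g in genes}
expr = np.array([0.0, 5.0, 2.0])
tokens = tokenize(genes, expr, lambda g: fg_lookup(g, table), fe_zero, fc_rank, d)
print(tokens.shape)  # (2, 4): CD3E is dropped because it is not expressed
```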

Implementing a new \(F_\textbf{G}\)

In our experiments, we use HyenaDNA to create a novel gene-identity encoding module. Specifically, we feed the DNA sequence of each gene to the HyenaDNA model and save the resulting embeddings to a PyTorch .pt file as a dictionary that maps gene names to tensors. Heimdall already provides a class, TorchTensorFg, that loads gene embeddings from such a file, so the implementation is straightforward once the embeddings have been extracted.
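The extraction step itself is not shown in this notebook. The sketch below only illustrates the expected file layout: a dict from gene name to a 1-D tensor, written with torch.save and read back with the same torch.load call that TorchTensorFg uses. The function embed_dna_placeholder is a hypothetical stand-in for HyenaDNA, implemented as a fixed random projection of one-hot nucleotides. Mean-pooling over positions to get one vector per gene is also our assumption.

```python
import os
import tempfile

import torch

def embed_dna_placeholder(sequence: str, d: int = 8) -> torch.Tensor:
    """Hypothetical stand-in for HyenaDNA: one-hot encode nucleotides, apply a
    fixed random projection, then mean-pool over positions (assumed pooling)."""
    alphabet = "ACGT"
    idx = torch.tensor(
        [alphabet.index(c) for c in sequence.upper() if c in alphabet], dtype=torch.long
    )
    one_hot = torch.nn.functional.one_hot(idx, num_classes=4).float()  # (L, 4)
    gen = torch.Generator().manual_seed(0)
    proj = torch.randn(4, d, generator=gen)  # fixed "encoder" weights
    return (one_hot @ proj).mean(dim=0)  # (d,)

# Toy sequences; real inputs would be each gene's DNA sequence.
gene_sequences = {"GENE_A": "ACGTACGGT", "GENE_B": "TTGACCA"}
gene_embeddings = {name: embed_dna_placeholder(seq) for name, seq in gene_sequences.items()}

# Save as {gene_name: tensor}, the layout TorchTensorFg expects.
path = os.path.join(tempfile.mkdtemp(), "hyenaDNA.pt")
torch.save(gene_embeddings, path)

# Reload the same way TorchTensorFg.load_embeddings does.
loaded = torch.load(path, weights_only=True)
print({k: tuple(v.shape) for k, v in loaded.items()})  # {'GENE_A': (8,), 'GENE_B': (8,)}
```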

Python code

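import torch

from Heimdall.fg import PretrainedFg


# Shown for illustration only: this class is already provided in `Heimdall.fg`.
class TorchTensorFg(PretrainedFg):
    """Mapping of gene names to pretrained embeddings stored as PyTorch
    tensors."""

    def load_embeddings(self):
        # The .pt file holds a dict {gene_name: tensor}. `weights_only=True`
        # restricts unpickling to tensors and plain containers, and
        # `map_location="cpu"` lets GPU-saved files load on CPU-only machines.
        raw_gene_embedding_map = torch.load(
            self.embedding_filepath, map_location="cpu", weights_only=True
        )

        # Convert every embedding tensor to a NumPy array.
        raw_gene_embedding_map = {
            gene_name: embedding.detach().cpu().numpy()
            for gene_name, embedding in raw_gene_embedding_map.items()
        }

        return raw_gene_embedding_map


# The config below refers to `Heimdall.fg.HyenaDNAFg`, so this class presumably
# needs to be importable from `Heimdall.fg`.
class HyenaDNAFg(TorchTensorFg):
    """Mapping of gene names to pretrained HyenaDNA embeddings."""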

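load_embeddings returns a plain dict from gene name to NumPy array. A model usually needs this as a dense matrix whose rows follow the dataset's gene order, and some dataset genes may have no HyenaDNA embedding. The hypothetical helper build_embedding_matrix sketches that alignment step. How Heimdall performs it internally is not shown here, and zero-filling the rows of missing genes is our assumption.

```python
import numpy as np

def build_embedding_matrix(gene_names, embedding_map):
    """Hypothetical helper: stack per-gene vectors into an (n_genes, d) matrix
    ordered like `gene_names`, and report genes absent from the map."""
    d = len(next(iter(embedding_map.values())))
    matrix = np.zeros((len(gene_names), d), dtype=np.float32)
    missing = []
    for i, gene in enumerate(gene_names):
        vec = embedding_map.get(gene)
        if vec is None:
            missing.append(gene)  # assumption: leave this row all-zero
        else:
            matrix[i] = vec
    return matrix, missing

# Usage with a toy map such as load_embeddings might return.
embedding_map = {"GENE_A": np.ones(3), "GENE_B": np.arange(3.0)}
matrix, missing = build_embedding_matrix(["GENE_B", "GENE_C", "GENE_A"], embedding_map)
print(matrix.shape, missing)  # (3, 3) ['GENE_C']
```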
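fg/hyenadna.yaml config

The final step is to write a config file for this \(F_\textbf{G}\). Its type field names the class to use, and its args field holds the settings, including the location of the saved HyenaDNA embeddings: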

type: Heimdall.fg.HyenaDNAFg

args:
  embedding_parameters:
    type: Heimdall.embedding.FlexibleTypeEmbedding
    constructor: from_pretrained
    args:
      embeddings: gene_embeddings
  embedding_filepath: ${data_path}/pretrained_embeddings/hyenaDNA.pt
  d_embedding: ${model.args.d_model}
  frozen: true
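The config builds the embedding through a from_pretrained constructor, and the \(F_\textbf{G}\) args set frozen: true. We have not inspected Heimdall.embedding.FlexibleTypeEmbedding. Assuming it behaves like PyTorch's torch.nn.Embedding.from_pretrained, the sketch below shows what freezing means: lookups return rows of the pretrained matrix, and the weights receive no gradient updates. Mapping frozen: true to freeze=True is our assumption, and the random matrix stands in for the HyenaDNA embeddings.

```python
import torch

torch.manual_seed(0)

# Stand-in for the pretrained gene-embedding matrix (3 genes, d = 4).
pretrained = torch.randn(3, 4)

# Assumption: `frozen: true` corresponds to freeze=True here.
emb = torch.nn.Embedding.from_pretrained(pretrained, freeze=True)

gene_ids = torch.tensor([2, 0, 2])  # integer gene indices
out = emb(gene_ids)                 # each row is copied from `pretrained`
print(out.shape, emb.weight.requires_grad)  # torch.Size([3, 4]) False
```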

Putting it all together

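from omegaconf import OmegaConf

# Compose the full experiment config from Heimdall's config directory.
with hydra.initialize(version_base=None, config_path="../Heimdall/config"):
    config = hydra.compose(
        config_name="config",
        overrides=[
            "+experiments=sctab_split1_all",  # `+` adds a group not in the defaults list
            "fg=hyenadna",      # the new F_G defined in fg/hyenadna.yaml
            "fe=zero",          # existing F_E config
            "fc=geneformer",    # existing F_C config
        ],
    )

    # Resolve ${...} interpolations such as ${data_path} in place.
    OmegaConf.resolve(config)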

Training the model

from accelerate import notebook_launcher

from Heimdall.trainer import setup_trainer


def training_loop(config):
    # Build the trainer from the composed config; train only if setup succeeds.
    trainer = setup_trainer(config, cpu=config.trainer.cpu)
    if trainer is not None:
        trainer.fit()


# Launch training from within the notebook using a single process.
notebook_launcher(training_loop, args=(config,), num_processes=1)