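"""Cell-embedding reduction strategies for Heimdall (module ``Heimdall.reduce``).

Each strategy combines the outputs of the gene-identity, expression and metadata
embedding layers into embeddings for a batch of cells.
"""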

from abc import ABC, abstractmethod

import torch
from torch import Tensor
from torch.nn import Module

from Heimdall.fc import Fc


class Reduce(ABC):
    """Abstract strategy that turns embedding-layer outputs into cell embeddings.

    Concrete subclasses implement ``__call__``. The ``fc`` object supplied at
    construction is stored on ``self.fc`` so subclasses can read settings from
    it (the chromosome-aware subclasses read ``self.fc.chrom_token_offset``).
    """

    def __init__(self, fc: Fc):
        self.fc = fc

    @abstractmethod
    def __call__(
        self,
        identity_inputs: Tensor,
        gene_embedding_layer: Module | None,
        expression_inputs: Tensor,
        expression_embedding_layer: Module | None,
        metadata_embedding_layer: Module | None,
    ) -> Tensor:
        """Embed a batch of cells with the given embedding layers.

        Callers guarantee that identity and expression inputs have already been
        padded/truncated, so both arrive as regular (rectangular) tensors.

        Args:
            identity_inputs: Batched gene identity inputs.
            gene_embedding_layer: Torch module that embeds gene identities.
            expression_inputs: Batched gene expression inputs.
            expression_embedding_layer: Torch module that embeds expression values.
            metadata_embedding_layer: Torch module that embeds metadata.

        Returns:
            The cell embeddings.
        """
class IdentityReduce(Reduce):
    """Geneformer-style reduction: cell embeddings come from gene identities only."""

    def __call__(
        self,
        identity_inputs: Tensor,
        gene_embedding_layer: Module | None,
        expression_inputs: Tensor,
        expression_embedding_layer: Module | None,
        metadata_embedding_layer: Module | None,
    ) -> Tensor:
        """Geneformer cell embedding function.

        Only the gene identity embedding is used; the expression inputs and the
        expression and metadata embedding layers are accepted for interface
        compatibility but ignored.

        Args:
            identity_inputs: Batched gene identity inputs.
            gene_embedding_layer: Torch module applied to ``identity_inputs``;
                must not be ``None`` for this strategy.
            expression_inputs: Unused.
            expression_embedding_layer: Unused.
            metadata_embedding_layer: Unused.

        Returns:
            ``gene_embedding_layer(identity_inputs)``.
        """
        return gene_embedding_layer(identity_inputs)
class SumReduce(Reduce):
    """scGPT-style reduction: gene identity embeddings plus expression embeddings."""

    def __call__(
        self,
        identity_inputs: Tensor,
        gene_embedding_layer: Module | None,
        expression_inputs: Tensor,
        expression_embedding_layer: Module | None,
        metadata_embedding_layer: Module | None,
    ) -> Tensor:
        """ScGPT cell embedding callback.

        TODO: add "conditional tokens" (see Methods of
        https://www.nature.com/articles/s41592-024-02201-0#Sec14).

        Args:
            identity_inputs: Batched gene identity inputs.
            gene_embedding_layer: Torch module applied to ``identity_inputs``;
                must not be ``None`` for this strategy.
            expression_inputs: Batched gene expression inputs; cast to
                ``torch.float32`` before embedding.
            expression_embedding_layer: Torch module applied to the float32
                expression inputs; must not be ``None`` for this strategy.
            metadata_embedding_layer: Unused.

        Returns:
            Element-wise sum of the gene and expression embeddings.
        """
        # NOTE(review): the dtype is hard-coded to float32; a configurable float
        # dtype appears to have been planned at some point — confirm before changing.
        float_expression = expression_inputs.to(torch.float32)
        # Left-to-right evaluation keeps the original call order (gene layer first).
        return gene_embedding_layer(identity_inputs) + expression_embedding_layer(
            float_expression
        )
class ChromosomeReduce(Reduce):
    """Gene identity embeddings with chromosome tokens replaced by metadata embeddings."""

    def __call__(
        self,
        identity_inputs: Tensor,
        gene_embedding_layer: Module | None,
        expression_inputs: Tensor,
        expression_embedding_layer: Module | None,
        metadata_embedding_layer: Module | None,
    ) -> Tensor:
        """Embed cells using chromosome-aware sequences.

        Negative entries of ``identity_inputs`` mark chromosome tokens. Each such
        token ``t`` is mapped to the metadata index ``-t - self.fc.chrom_token_offset``
        and embedded with ``metadata_embedding_layer``. All other positions are
        embedded with ``gene_embedding_layer``. Expression inputs are ignored.

        Args:
            identity_inputs: Batched gene identity inputs, with negative values at
                chromosome-token positions. Not modified by this call.
            gene_embedding_layer: Torch module embedding (non-negative) gene identities.
            expression_inputs: Unused.
            expression_embedding_layer: Unused.
            metadata_embedding_layer: Torch module embedding chromosome indices.

        Returns:
            Gene embeddings whose chromosome-token positions hold the metadata
            embeddings instead.
        """
        chrom_token_mask = identity_inputs < 0
        chrom_token_indices = -identity_inputs[chrom_token_mask] - self.fc.chrom_token_offset

        # Zero the chromosome tokens on a copy so the caller's batch keeps them.
        # The 0 is only a placeholder lookup; those rows are overwritten below.
        gene_ids = identity_inputs.clone()
        gene_ids[chrom_token_mask] = 0
        gene_embeddings = gene_embedding_layer(gene_ids)

        meta_emb = metadata_embedding_layer(chrom_token_indices)
        # Match dtype/device so the index assignment also works under mixed precision.
        gene_embeddings[chrom_token_mask] = meta_emb.to(
            dtype=gene_embeddings.dtype, device=gene_embeddings.device
        )
        return gene_embeddings
class ChromosomeSumReduce(Reduce):
    """Sum of gene and expression embeddings, with chromosome tokens taken from metadata."""

    def __call__(
        self,
        identity_inputs: Tensor,
        gene_embedding_layer: Module | None,
        expression_inputs: Tensor,
        expression_embedding_layer: Module | None,
        metadata_embedding_layer: Module | None,
    ) -> Tensor:
        """Embed cells using chromosome-aware sequences.

        Negative entries of ``identity_inputs`` mark chromosome tokens. Each such
        token ``t`` is mapped to the metadata index ``-t - self.fc.chrom_token_offset``
        and embedded with ``metadata_embedding_layer``. That embedding replaces
        both the gene and the expression embedding at the token's position before
        the two are summed.

        Args:
            identity_inputs: Batched gene identity inputs, with negative values at
                chromosome-token positions. Not modified by this call.
            gene_embedding_layer: Torch module embedding (non-negative) gene identities.
            expression_inputs: Batched gene expression inputs, indexed by the same
                mask as ``identity_inputs``. Not modified by this call.
            expression_embedding_layer: Torch module embedding expression values.
            metadata_embedding_layer: Torch module embedding chromosome indices.

        Returns:
            Element-wise sum of the (chromosome-patched) gene and expression embeddings.
        """
        chrom_token_mask = identity_inputs < 0
        chrom_token_indices = -identity_inputs[chrom_token_mask] - self.fc.chrom_token_offset

        # Zero chromosome positions on copies so the caller's tensors are left intact.
        # The zeros are placeholder inputs whose embeddings are overwritten below.
        gene_ids = identity_inputs.clone()
        gene_ids[chrom_token_mask] = 0
        expression_values = expression_inputs.clone()
        expression_values[chrom_token_mask] = 0

        # Same call order as before: gene, expression, then metadata layer.
        gene_embeddings = gene_embedding_layer(gene_ids)
        expression_embeddings = expression_embedding_layer(expression_values)
        meta_emb = metadata_embedding_layer(chrom_token_indices)

        gene_embeddings[chrom_token_mask] = meta_emb.to(
            dtype=gene_embeddings.dtype, device=gene_embeddings.device
        )
        # NOTE(review): meta_emb is written into *both* summands, so chromosome
        # positions end up as twice the metadata embedding (up to dtype casting),
        # whereas ChromosomeReduce yields it once. Kept as-is to avoid changing
        # trained behavior — confirm this is intended.
        expression_embeddings[chrom_token_mask] = meta_emb.to(
            dtype=expression_embeddings.dtype, device=expression_embeddings.device
        )
        return gene_embeddings + expression_embeddings