"""Heimdall model."""
from collections import defaultdict
import torch
import torch.nn as nn
from omegaconf import DictConfig
from Heimdall.cell_representations import CellRepresentation
from Heimdall.datasets import PairedInstanceDataset
from Heimdall.embedding import PositionalEncoding
from Heimdall.utils import get_dtype, instantiate_from_config
[docs]
class HeimdallModel(nn.Module):
    """Heimdall model: a shared cell encoder followed by one head per subtask.

    When ``data.datasets["full"]`` is a ``PairedInstanceDataset``, every
    subtask also gets a reducer that combines the encodings of the two cells
    of a pair before the subtask head is applied.
    """

    def __init__(
        self,
        data: CellRepresentation,
        model_config: DictConfig,
    ):
        """Build the encoder, the optional reducers and the subtask heads.

        Args:
            data: Cell representation data object.
            model_config: The language model config.

        """
        super().__init__()

        self.num_subtasks = data.num_subtasks
        self.tasklist = data.tasklist

        self.encoder = instantiate_from_config(
            model_config,
            data,
        )
        dim_in = self.encoder.d_encoded

        self.reducers = nn.ModuleDict()
        self.heads = nn.ModuleDict()
        for subtask_name, subtask in data.tasklist:
            # Reducers only exist for paired datasets; `forward` uses the
            # emptiness of `self.reducers` to pick the single-/paired-cell path.
            if isinstance(data.datasets["full"], PairedInstanceDataset):
                self.reducers[subtask_name] = instantiate_from_config(
                    subtask.reducer_config,
                    dim_in=dim_in,
                )

            num_labels = subtask.num_tasks
            head = instantiate_from_config(subtask.head_config, dim_in=dim_in, dim_out=num_labels)
            self.heads[subtask_name] = head

    @property
    def dtype(self):
        """Return the dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def encode_cell(self, cell_inputs):
        """Encode one cell (or one batch of cells) for every subtask.

        Args:
            cell_inputs: Mapping ``input name -> {subtask name -> value}``.
                The optional ``"masks"`` entry is never forwarded to the
                encoder, and the optional ``"expression_padding"`` entry is
                passed to the encoder as ``attention_mask``. The mapping
                itself is left unmodified.

        Returns:
            dict: ``subtask name -> encoder output``.

        """
        outputs = {}
        cached_encoding = None

        # NOTE: this is the mask used for MLM, which is different from the
        # attention mask. It is read without popping so that the caller's
        # batch is not mutated by a forward pass.
        masks = cell_inputs.get("masks")

        for subtask_name, _ in self.tasklist:
            subtask_inputs = {
                key: value[subtask_name] for key, value in cell_inputs.items() if key != "masks"
            }
            attention_mask = subtask_inputs.pop("expression_padding", None)

            if masks is None and cached_encoding is not None:
                # TODO: only reuses encoding if all are unmasked
                # NOTE(review): reuse presumes every subtask shares the same
                # inputs when there is no MLM mask -- confirm.
                outputs[subtask_name] = cached_encoding
                continue

            encoding = self.encoder(subtask_inputs, attention_mask=attention_mask)
            outputs[subtask_name] = encoding
            if masks is None:
                cached_encoding = encoding

        return outputs

    def forward(self, inputs):
        """Run the encoder, the reducers (paired datasets only) and the heads.

        Args:
            inputs: Mapping ``input name -> {subtask name -> value}``. On the
                paired path every non-``None`` value is indexed with ``0`` and
                ``1`` to select the first and second cell of the pair.

        Returns:
            dict: ``subtask name -> head output``.

        """
        if self.reducers:
            encoded_cells = []
            for index in range(2):  # Two cells (can be generalized to more)
                cell_inputs = defaultdict(dict)
                for key, value in inputs.items():
                    for subtask_name in self.reducers:
                        cell_value = value[subtask_name]
                        if cell_value is not None:
                            cell_value = cell_value[index]
                        cell_inputs[key][subtask_name] = cell_value

                encoded_cells.append(self.encode_cell(cell_inputs))

            # Apply reducers
            outputs = {
                subtask_name: reducer([encoded_cell[subtask_name] for encoded_cell in encoded_cells])
                for subtask_name, reducer in self.reducers.items()
            }
        else:
            outputs = self.encode_cell(inputs)

        # Apply heads
        return {subtask_name: self.heads[subtask_name](output) for subtask_name, output in outputs.items()}
class ExpressionOnly(nn.Module):
    """Pass-through encoder that returns the raw expression inputs as floats.

    It holds no learnable parameters of its own; it only casts the
    ``"expression_inputs"`` entry to the configured float dtype.
    """

    def __init__(
        self,
        data: CellRepresentation,
    ):
        """Record the sizes and dtype needed by downstream modules.

        Args:
            data: Cell representation data object. ``data.adata`` and
                ``data.float_dtype`` are read.

        """
        super().__init__()

        # +2 mirrors the <PAD>/<MASK> offset used by `CellSentenceModel`.
        self.vocab_size = data.adata.n_vars + 2
        self.float_dtype = data.float_dtype
        # The encoded width is the second dimension of `adata`.
        _, self.d_encoded = data.adata.shape

    def forward(self, inputs, labels=None, attention_mask=None):
        """Return ``inputs["expression_inputs"]`` cast to the float dtype.

        Args:
            inputs: Mapping that contains an ``"expression_inputs"`` tensor.
            labels: Unused; accepted for interface compatibility.
            attention_mask: Unused; accepted for interface compatibility.

        Returns:
            The expression tensor converted with ``get_dtype(self.float_dtype)``.

        """
        expression = inputs["expression_inputs"]  # extract expression only
        return expression.to(get_dtype(self.float_dtype))
class CellSentenceModel(nn.Module):
    """Cell sentence encoder abstraction.

    Must set the following properties:
        self.d_encoded: the dimensionality of the embedding output
        self.cell_sentence_model: nn.Module for encoding sequences

    This base class leaves ``self.cell_sentence_model`` as ``None``; child
    classes are expected to assign it after calling ``super().__init__``.
    """

    def __init__(
        self,
        data: CellRepresentation,
        d_model: int,
        pos_enc: str,
        pooling: str,
    ):
        """Set up the embedding layers shared by all cell sentence encoders.

        Args:
            data: Cell representation data object.
            d_model: dimensionality of embedding output
            pos_enc: positional encoding strategy to use
            pooling: type of pooling to use for combining gene tokens into cell
                embedding. Not read by this base class.

        Raises:
            ValueError: If ``pos_enc`` is not a supported strategy.

        """
        super().__init__()

        self.d_encoded = d_model
        self.fc = data.fc
        self.vocab_size = data.adata.n_vars + 2  # <PAD> and <MASK> TODO: data.vocab_size

        # Setting up embedding layers
        if data.fg.d_embedding is not None:
            self.gene_embeddings = instantiate_from_config(data.fg.embedding_parameters)
            if data.fg.frozen:
                print("> Freezing all params in F_g")
                for param in self.gene_embeddings.parameters():
                    param.requires_grad = False
        else:
            self.gene_embeddings = None

        if data.fe.d_embedding is not None:
            self.expression_embeddings = instantiate_from_config(data.fe.embedding_parameters)
        else:
            self.expression_embeddings = None

        # Setting up explicit Positional Encodings
        max_input_length = self.fc.max_input_length
        if max_input_length is None or pos_enc in (None, "none", "NONE"):
            self.position_embeddings = None
        elif pos_enc == "BERT":
            self.position_embeddings = nn.Embedding(max_input_length + 1, d_model)  # +1 cuz of CLS
        elif pos_enc == "sincos":
            self.position_embeddings = PositionalEncoding(d_model, max_len=max_input_length + 1)
        else:
            raise ValueError("pos_enc can only be one of: `BERT`, `sincos`, `None`")

        self.metadata_embeddings = instantiate_from_config(self.fc.embedding_parameters)

        # Initialize the [CLS] token as a learnable parameter
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))

        # Encoder - must be subsequently instantiated by child class
        self.cell_sentence_model = None

    def embed_inputs(self, inputs):
        """Embed inputs using the `Fc` reduce function.

        Args:
            inputs: Mapping with ``"identity_inputs"`` and
                ``"expression_inputs"`` entries.

        Returns:
            The embeddings returned by ``self.fc.reduce``, with positional
            information added when a positional encoding is configured.

        """
        identity_inputs, expression_inputs = inputs["identity_inputs"], inputs["expression_inputs"]

        input_embeds = self.fc.reduce(
            identity_inputs,
            self.gene_embeddings,
            expression_inputs,
            self.expression_embeddings,
            self.metadata_embeddings,
        )

        # Positional Encoding
        if self.position_embeddings is None:
            return input_embeds

        if isinstance(self.position_embeddings, nn.Embedding):
            # BERT-style learned embeddings
            batch_size = identity_inputs.size(0)
            seq_length = input_embeds.size(1)
            position_ids = torch.arange(
                seq_length,
                dtype=torch.long,
                device=input_embeds.device,
            ).expand((batch_size, -1))
            # Out-of-place add: an in-place `+=` would overwrite the tensor
            # returned by `fc.reduce`, which may alias caller data or be
            # needed by autograd.
            return input_embeds + self.position_embeddings(position_ids)

        # Sinusoidal encoding
        return self.position_embeddings(input_embeds)

    def forward(self, inputs, attention_mask=None):
        """Embed the inputs and run them through ``self.cell_sentence_model``.

        Args:
            inputs: This is either integers if IDs or bf16/fp32
                floats for predefined embeddings
            attention_mask: A tensor of shape [batchsize, seqlen] where 1/True
                represents no attention and 0/False represents that attention should be used

        Returns:
            The output of ``self.cell_sentence_model``.

        """
        input_embeds = self.embed_inputs(inputs)

        # Encoder
        return self.cell_sentence_model(input_embeds, attention_mask)
class Average(CellSentenceModel):
    """Cell sentence model whose encoder is an ``AverageEncoder``."""

    def __init__(
        self,
        data: CellRepresentation,
        d_model: int,
        pos_enc: str,
        pooling: str,
    ):
        """Validate the configuration, then build the shared embedding layers.

        Args:
            data: Cell representation data object.
            d_model: dimensionality of embedding output
            pos_enc: must be ``None`` for this model
            pooling: must be ``"mean_pooling"`` for this model

        Raises:
            ValueError: If ``pooling`` or ``pos_enc`` has an unsupported value.

        """
        # Reject unsupported configurations before any module state exists.
        uses_mean_pooling = pooling == "mean_pooling"
        if not uses_mean_pooling:
            raise ValueError("Please ensure that `pooling == 'mean_pooling'`")

        has_positional_encoding = pos_enc is not None
        if has_positional_encoding:
            raise ValueError("Please ensure that `pos_enc is None`")

        super().__init__(data, d_model=d_model, pos_enc=pos_enc, pooling=pooling)

        encoder = AverageEncoder()
        self.cell_sentence_model = encoder
class ExpressionWeightedSum(CellSentenceModel):
    """Cell sentence model whose encoder is an ``ExpressionWeightedSumEncoder``."""

    def __init__(
        self,
        data: CellRepresentation,
        d_model: int,
        pos_enc: str,
        pooling: str,
    ):
        """Validate the configuration, then build the shared embedding layers.

        Args:
            data: Cell representation data object.
            d_model: dimensionality of embedding output
            pos_enc: must be ``None`` for this model
            pooling: must be ``"mean_pooling"`` for this model

        Raises:
            ValueError: If ``pooling`` or ``pos_enc`` has an unsupported value.

        """
        # Reject unsupported configurations before any module state exists.
        uses_mean_pooling = pooling == "mean_pooling"
        if not uses_mean_pooling:
            raise ValueError("Please ensure that `pooling == 'mean_pooling'`")

        has_positional_encoding = pos_enc is not None
        if has_positional_encoding:
            raise ValueError("Please ensure that `pos_enc is None`")

        super().__init__(data, d_model=d_model, pos_enc=pos_enc, pooling=pooling)

        encoder = ExpressionWeightedSumEncoder()
        self.cell_sentence_model = encoder

    def forward(self, inputs, attention_mask=None):
        """Embed the inputs and weight the embeddings by expression.

        Args:
            inputs: Mapping with ``"identity_inputs"`` and
                ``"expression_inputs"`` entries.
            attention_mask: A tensor of shape [batchsize, seqlen] where 1/True
                represents no attention and 0/False represents that attention should be used

        Returns:
            The output of ``self.cell_sentence_model``.

        """
        # The raw expression values are needed again by the encoder as weights.
        expression = inputs["expression_inputs"]
        embedded = self.embed_inputs(inputs)

        # Encoder
        return self.cell_sentence_model(expression, embedded, attention_mask)
class Transformer(CellSentenceModel):
    """Cell sentence model that encodes a [CLS]-prefixed sequence with a transformer."""

    def __init__(
        self,
        data: CellRepresentation,
        d_model: int,
        pos_enc: str,
        pooling: str,
        nhead: int,
        hidden_dropout_prob: float,
        hidden_act: str,
        use_flash_attn: bool,
        num_encoder_layers: int,
    ):
        """Heimdall transformer model.

        Args:
            data: Cell representation data object.
            d_model: dimensionality of embedding output
            pos_enc: positional encoding strategy to use
            pooling: type of pooling to use for combining gene tokens into cell embedding
            nhead: number of attention heads, forwarded to ``TransformerEncoder``
            hidden_dropout_prob: dropout probability, forwarded to ``TransformerEncoder``
            hidden_act: activation name, forwarded to ``TransformerEncoder``
            use_flash_attn: whether ``TransformerEncoder`` should use its flash-attention backend
            num_encoder_layers: number of encoder layers, forwarded to ``TransformerEncoder``

        """
        super().__init__(data, d_model=d_model, pos_enc=pos_enc, pooling=pooling)

        self.use_flash_attn = use_flash_attn

        # Encoder
        self.cell_sentence_model = TransformerEncoder(
            d_model=d_model,
            nhead=nhead,
            hidden_dropout_prob=hidden_dropout_prob,
            use_flash_attn=use_flash_attn,
            hidden_act=hidden_act,
            num_encoder_layers=num_encoder_layers,
        )

    def forward(self, inputs, attention_mask=None):
        """Embed the inputs, prepend the [CLS] token and run the encoder.

        Args:
            inputs: This is either integers if IDs or bf16/fp32
                floats for predefined embeddings
            attention_mask: A tensor of shape [batchsize, seqlen] where 1/True
                represents no attention and 0/False represents that attention should be used

        Returns:
            The output of ``self.cell_sentence_model`` for the sequence with
            the [CLS] token at position 0.

        """
        input_embeds = self.embed_inputs(inputs)
        batch_size = input_embeds.size(0)

        # Concatenate the CLS Token to both the attention mask and the input
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # Expand to match batch size
        input_embeds = torch.cat([cls_tokens, input_embeds], dim=1)

        if attention_mask is not None:
            # A zero/False column keeps the [CLS] position attended. `new_zeros`
            # inherits the mask's dtype and device, so no dtype is hard-coded.
            cls_attention = attention_mask.new_zeros((batch_size, 1))  # Shape: (batch_size, 1)
            attention_mask = torch.cat([cls_attention, attention_mask], dim=1)  # Shape: (batch_size, seq_len + 1)

        # Encoder
        return self.cell_sentence_model(input_embeds, attention_mask)
class AverageEncoder(nn.Module):
    """Mean-pool embeddings over dimension 1, ignoring masked positions."""

    def forward(self, input_embeds, attention_mask=None):
        """Average the unmasked embeddings of every sequence.

        Args:
            input_embeds: Embeddings indexed as ``[batch, seq, dim]``.
            attention_mask: Optional ``[batch, seq]`` mask where 1/True marks
                a position to ignore. Non-boolean masks are converted to
                boolean first. ``None`` means every position is valid.

        Returns:
            A ``[batch, dim]`` tensor holding the mean of the valid positions.
            Sequences with no valid position yield zeros.

        """
        if attention_mask is None:
            # No padding information: plain mean over the sequence dimension.
            return input_embeds.mean(dim=1)

        # Convert before inverting: `~` on an integer 0/1 mask is a bitwise
        # NOT (giving -1/-2), not a logical one.
        valid_mask = ~attention_mask.to(torch.bool)
        expanded_mask = valid_mask.unsqueeze(-1)  # Add an extra dimension for broadcasting

        # Zero out masked positions, then sum along the sequence dimension
        sum_embeds = (input_embeds * expanded_mask).sum(dim=1)

        # Number of valid positions per sequence; clamp avoids division by zero
        valid_counts = expanded_mask.sum(dim=1).clamp(min=1)  # Shape: [batch, 1]

        return sum_embeds / valid_counts
class ExpressionWeightedSumEncoder(nn.Module):
    """Implementation of expression-weighted sum encoder used by GenePT-w."""

    def forward(self, expression_inputs, input_embeds, attention_mask=None):
        """Scale every token embedding by its expression value.

        Args:
            expression_inputs: Expression values indexed as ``[batch, seq]``.
            input_embeds: Embeddings indexed as ``[batch, seq, dim]``.
            attention_mask: Optional ``[batch, seq]`` mask where 1/True marks
                a position to ignore. Non-boolean masks are converted to
                boolean first. ``None`` means every position is valid.

        Returns:
            A ``[batch, seq, dim]`` tensor of expression-weighted embeddings,
            with masked positions set to zero.

        """
        # NOTE(review): despite the class name, no reduction over the sequence
        # dimension happens here; the per-token products are returned as-is.
        # Confirm that a downstream module performs the sum/pooling.
        expanded_expression_inputs = torch.unsqueeze(expression_inputs, dim=2)

        if attention_mask is None:
            return input_embeds.mul(expanded_expression_inputs)

        # Convert before inverting: `~` on an integer 0/1 mask is a bitwise
        # NOT (giving -1/-2), not a logical one.
        valid_mask = ~attention_mask.to(torch.bool)
        expanded_mask = valid_mask.unsqueeze(-1)  # Add an extra dimension for broadcasting

        # Zero out masked positions in both the embeddings and the weights
        masked_embeds = input_embeds * expanded_mask
        masked_expression_inputs = expanded_expression_inputs * expanded_mask

        # Weight each valid embedding by its expression value
        return masked_embeds.mul(masked_expression_inputs)
class TransformerEncoder(nn.Module):
    """Transformer encoder backed by either flash attention or ``nn.TransformerEncoder``."""

    def __init__(
        self,
        d_model: int,
        nhead: int,
        hidden_dropout_prob: float,
        use_flash_attn: bool,
        num_encoder_layers: int,
        hidden_act: str = "gelu",
    ):
        """Build the encoder stack.

        Args:
            d_model: width of the token embeddings
            nhead: number of attention heads
            hidden_dropout_prob: dropout probability used inside the layers
            use_flash_attn: use ``FlashTransformerEncoder`` instead of the
                built-in PyTorch encoder
            num_encoder_layers: number of stacked encoder layers
            hidden_act: name of the feed-forward activation

        """
        super().__init__()
        self.use_flash_attn = use_flash_attn

        if not self.use_flash_attn:
            # Built-in PyTorch encoder: pre-norm, batch-first layers with a
            # feed-forward width of four times the model width.
            layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=hidden_dropout_prob,
                activation=hidden_act,
                batch_first=True,
                norm_first=True,
            )
            self.encoder = nn.TransformerEncoder(layer, num_layers=num_encoder_layers)
            return

        # Imported lazily so the flash-attention module is only required
        # when it is actually requested.
        from Heimdall.models._flash_attn import FlashTransformerEncoder

        self.encoder = FlashTransformerEncoder(
            d_model,
            nhead,
            num_encoder_layers,
            dropout=hidden_dropout_prob,
            activation=hidden_act,
        )

    def forward(self, input_embeds, attention_mask=None):
        """Encode ``input_embeds``, passing ``attention_mask`` as the key padding mask."""
        encoded = self.encoder(input_embeds, src_key_padding_mask=attention_mask)
        return encoded