from abc import ABC, abstractmethod
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Sequence
import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pandas.api.typing import NAType
from Heimdall.utils import check_states, conditional_print, pca_reduction
if TYPE_CHECKING:
from Heimdall.cell_representations import CellRepresentation
[docs]
class Fg(ABC):
    """Abstraction of the gene embedding mapping paradigm.

    Args:
        data: cell representation whose `.adata` holds the genes to be mapped.
        embedding_parameters: configuration that is converted to a plain
            container (interpolations resolved). Entries under `"args"` may hold
            the placeholder strings `"vocab_size"` or `"gene_embeddings"`, which
            `prepare_embedding_parameters` replaces with concrete values.
        d_embedding: dimensionality of embedding for each gene entity
        vocab_size: initial vocabulary size; it is reduced by the number of
            unmapped genes when `identity_valid_mask` is assigned.
        pad_value: index used for padding; defaults to `vocab_size - 2`.
        mask_value: index used for masking; defaults to `vocab_size - 1`.
        frozen: stored as `self.frozen`; not otherwise used by this base class.
        rng: seed or generator handed to `numpy.random.default_rng`.
        do_pca_reduction: stored as `self.do_pca_reduction`; read by
            `PretrainedFg.preprocess_embeddings`.

    """

    def __init__(
        self,
        data: "CellRepresentation",
        embedding_parameters: DictConfig,
        d_embedding: int,
        vocab_size: int,
        pad_value: Optional[int] = None,
        mask_value: Optional[int] = None,
        frozen: bool = False,
        rng: int | np.random.Generator = 0,
        do_pca_reduction: bool = True,
    ):
        self.data = data
        self.d_embedding = d_embedding
        self.embedding_parameters = OmegaConf.to_container(embedding_parameters, resolve=True)
        self.vocab_size = vocab_size
        self.pad_value = vocab_size - 2 if pad_value is None else pad_value
        self.mask_value = vocab_size - 1 if mask_value is None else mask_value
        self.frozen = frozen
        self.rng = np.random.default_rng(rng)
        self.do_pca_reduction = do_pca_reduction
        # Number of unmapped genes already subtracted from `vocab_size`; lets the
        # `identity_valid_mask` setter be re-applied without shrinking twice.
        self._num_unmapped_genes = 0

    @abstractmethod
    @check_states(adata=True)
    def preprocess_embeddings(self, float_dtype: str = "float32"):
        """Preprocess gene embeddings and store them for use during model
        inference.

        Preprocessing may include anything from downloading gene embeddings from
        a URL to generating embeddings from scratch.

        Args:
            float_dtype: dtype to be used for identity embedding state.

        Returns:
            Sets `self.gene_embeddings`.
            Sets the following fields of `self.adata`:
            `.var['identity_embedding_index']` : :class:`~numpy.ndarray` (shape `(self.adata.n_vars,)`)
                Index of gene in embeddings.
            `.var['identity_valid_mask']` : :class:`~numpy.ndarray` (shape `(self.adata.n_vars,)`)
                Boolean mask indicating whether or not gene is mapped by this `Fg`.

        """

    def __getitem__(self, gene_names: Sequence[str], return_mask: bool = False) -> Sequence[int | NAType]:
        """Get the indices of genes in the embedding array.

        Must run `self.preprocess_embeddings()` before using this function.

        Args:
            gene_names: names of the genes as stored in `self.adata`.
            return_mask: if `True`, unmapped genes are tolerated and the validity
                mask is returned alongside the indices.

        Returns:
            Index of each gene in the embedding, or `pd.NA` if the gene has no
            mapping. When `return_mask` is `True`, a tuple
            `(embedding_indices, valid_mask)` is returned instead.

        Raises:
            KeyError: if at least one requested gene is unmapped and
                `return_mask` is `False`.

        """
        embedding_indices = self.adata.var.loc[gene_names, "identity_embedding_index"]
        valid_mask = self.adata.var.loc[gene_names, "identity_valid_mask"]

        if return_mask:
            return embedding_indices, valid_mask

        if valid_mask.sum() != len(gene_names):
            raise KeyError(
                "At least one gene is not mapped in this `Fg`. "
                "Please remove such genes from consideration in the `Fc`.",
            )
        return embedding_indices

    @property
    def identity_valid_mask(self):
        """Boolean mask over `self.adata.var` marking genes mapped by this `Fg`."""
        return self.adata.var["identity_valid_mask"]

    @identity_valid_mask.setter
    def identity_valid_mask(self, val):
        """Store the mask and shrink `vocab_size` by the number of unmapped genes.

        The amount removed by a previous assignment is restored first, so that
        assigning a mask more than once (e.g. preprocessing followed by
        `load_from_cache`) does not shrink `vocab_size` cumulatively.
        """
        num_unmapped = self.adata.n_vars - np.sum(val)
        previously_removed = getattr(self, "_num_unmapped_genes", 0)
        self.vocab_size += previously_removed - num_unmapped
        self._num_unmapped_genes = num_unmapped
        self.adata.var["identity_valid_mask"] = val

    def prepare_embedding_parameters(self):
        """Replace config placeholders with values after preprocessing.

        For every entry of `self.embedding_parameters["args"]`:

        * the string `"vocab_size"` is replaced with `self.vocab_size`;
        * the string `"gene_embeddings"` is replaced with a tensor built from
          `self.gene_embeddings` plus two trailing all-zero rows, and
          `self.pad_value` / `self.mask_value` are re-pointed at those rows.

        Any other value is left untouched.
        """
        args = self.embedding_parameters.get("args", {})
        for key, value in args.items():
            if value == "vocab_size":
                value = self.vocab_size  # <PAD> and <MASK> TODO: data.vocab_size
            elif value == "gene_embeddings":
                gene_embeddings = torch.tensor(self.gene_embeddings)  # TODO: type is inherited from NDArray
                pad_vector = torch.zeros(1, self.d_embedding)
                mask_vector = torch.zeros(1, self.d_embedding)
                value = torch.cat((gene_embeddings, pad_vector, mask_vector), dim=0)
                # The two appended rows are the last two rows of the table.
                self.pad_value = value.shape[0] - 2
                self.mask_value = value.shape[0] - 1
            else:
                continue

            # Overwriting an existing key does not resize the dict, so this is
            # safe while iterating over `args.items()`.
            self.embedding_parameters["args"][key] = value

    def load_from_cache(
        self,
        identity_embedding_index: NDArray,
        identity_valid_mask: NDArray,
        gene_embeddings: NDArray | None,
    ):
        """Load processed values from cache.

        Args:
            identity_embedding_index: stored in `.var['identity_embedding_index']`.
            identity_valid_mask: assigned through the `identity_valid_mask`
                setter (which also adjusts `vocab_size`).
            gene_embeddings: stored as `self.gene_embeddings`; may be `None`.

        """
        # TODO: add tests
        self.adata.var["identity_embedding_index"] = identity_embedding_index
        self.identity_valid_mask = identity_valid_mask
        self.gene_embeddings = gene_embeddings
        self.prepare_embedding_parameters()

    @property
    def adata(self):
        """The `adata` object held by `self.data`."""
        return self.data.adata
[docs]
class PretrainedFg(Fg, ABC):
    """Abstraction for pretrained `Fg`s that can be loaded from disk.

    Args:
        data: forwarded to `Fg.__init__`.
        embedding_parameters: forwarded to `Fg.__init__`.
        embedding_filepath: filepath from which to load pretrained embeddings
        **fg_kwargs: remaining keyword arguments forwarded to `Fg.__init__`.

    Raises:
        TypeError: if `embedding_filepath` is `None`.

    """

    def __init__(
        self,
        data: "CellRepresentation",
        embedding_parameters: OmegaConf,
        embedding_filepath: Optional[str | PathLike] = None,
        **fg_kwargs,
    ):
        super().__init__(data, embedding_parameters, **fg_kwargs)
        if embedding_filepath is None:
            # `Path(None)` would raise the same exception type with an opaque message.
            raise TypeError("`embedding_filepath` must be a path to pretrained embeddings, not `None`.")
        self.embedding_filepath = Path(embedding_filepath)

    @abstractmethod
    def load_embeddings(self) -> Dict[str, NDArray]:
        """Load the embeddings from disk and process into map.

        Returns:
            A mapping from gene names to embedding vectors.

        """

    def _pca_reduced_embeddings(self, embedding_map: Dict[str, NDArray]) -> Dict[str, NDArray]:
        """Return `embedding_map` reduced to `self.d_embedding` dims, using an on-disk cache."""
        cache_filepath = (
            self.embedding_filepath.parent / f"{self.embedding_filepath.stem}_reduced_{self.d_embedding}.pt"
        )

        if cache_filepath.is_file():
            # The cache is always written with `torch.save` below, so it is read
            # back with `torch.load` regardless of the format that the subclass'
            # `load_embeddings` understands.
            cached_tensors = torch.load(cache_filepath, weights_only=True)
            reduced_map = {
                gene_name: embedding.detach().cpu().numpy() for gene_name, embedding in cached_tensors.items()
            }
            conditional_print(
                "> Loaded PCA-reduced `Fg` embeddings from cache.",
                condition=self.data.verbose,
            )
            return reduced_map

        reduced_map = pca_reduction(embedding_map, n_components=self.d_embedding)
        torch.save(
            {gene_name: torch.from_numpy(embedding) for gene_name, embedding in reduced_map.items()},
            cache_filepath,
        )
        conditional_print(
            "> Used PCA to reduce `Fg` embeddings and cached for future use.",
            condition=self.data.verbose,
        )
        return reduced_map

    @check_states(adata=True)
    def preprocess_embeddings(self, float_dtype: str = "float32"):
        """Load pretrained embeddings and align them with the genes in `self.adata`.

        Sets `self.gene_embeddings` (one row per mapped gene, in `var` order),
        `.var['identity_embedding_index']` and `.var['identity_valid_mask']`, then
        calls `prepare_embedding_parameters`.

        Args:
            float_dtype: dtype of the `self.gene_embeddings` array.

        Raises:
            ValueError: if the pretrained embeddings have fewer dimensions than
                `self.d_embedding`, or if fewer than half of the genes in
                `self.adata` are mapped.

        """
        embedding_map = self.load_embeddings()

        first_embedding = next(iter(embedding_map.values()))
        pretrained_dim = len(first_embedding)
        if pretrained_dim < self.d_embedding:
            raise ValueError(
                f"Dimensionality of pretrained embeddings ({pretrained_dim}) is less than the embedding "
                f"dimensionality specified in the config ({self.d_embedding}). Please decrease the embedding "
                "dimensionality to be compatible with the pretrained embeddings.",
            )

        if pretrained_dim > self.d_embedding:
            conditional_print(
                f"> Warning, the `Fg` embedding dim {first_embedding.shape} is larger than the model "
                f"dim {self.d_embedding}, truncation may occur.",
                condition=self.data.verbose,
            )
            # PCA is only attempted when there is something to reduce.
            if self.do_pca_reduction:
                embedding_map = self._pca_reduced_embeddings(embedding_map)

        valid_gene_names = list(embedding_map.keys())
        valid_mask = pd.array(self.adata.var_names.isin(valid_gene_names))
        num_mapped_genes = valid_mask.sum()
        (valid_indices,) = np.nonzero(valid_mask)

        # Nullable integer column: mapped genes get consecutive row numbers in
        # `var` order, unmapped genes get a missing value.
        index_map = valid_mask.astype(pd.Int64Dtype())
        index_map[~valid_mask] = None
        index_map[valid_indices] = np.arange(num_mapped_genes)

        self.adata.var["identity_embedding_index"] = index_map
        self.identity_valid_mask = valid_mask.to_numpy()

        self.gene_embeddings = np.zeros((num_mapped_genes, self.d_embedding), dtype=float_dtype)

        # Row `i` belongs to the `i`-th mapped gene, matching `index_map` above.
        # Slicing keeps only the leading `d_embedding` components.
        mapped_gene_names = self.adata.var_names[valid_indices]
        for embedding_index, gene_name in enumerate(mapped_gene_names):
            self.gene_embeddings[embedding_index] = embedding_map[gene_name][: self.d_embedding]

        self.prepare_embedding_parameters()

        conditional_print(
            f"Found {len(valid_indices)} genes with mappings out of {len(self.adata.var_names)} genes.",
            condition=self.data.verbose,
        )
        map_ratio = len(valid_indices) / len(self.adata.var_names)
        if map_ratio < 0.5:
            raise ValueError(
                "Very few genes in the dataset are mapped by the `Fg`. "
                "Please check if the species is set correctly in the config.",
            )
[docs]
class IdentityFg(Fg):
    """Identity mapping of gene names to embeddings.

    The simplest `Fg`: no pretrained vectors are loaded, every gene in
    `self.adata` is treated as mapped, and each gene's embedding index is its
    position in `var`. It implies the use of learnable gene embeddings that are
    initialized randomly, as opposed to the use of pretrained embeddings.
    """

    @check_states(adata=True)
    def preprocess_embeddings(self, float_dtype: str = "float32"):
        """Mark every gene as mapped to its own position; `float_dtype` is unused here."""
        n_genes = self.adata.n_vars

        # No pretrained table exists for the identity mapping.
        self.gene_embeddings = None
        self.adata.var["identity_embedding_index"] = np.arange(n_genes)
        self.identity_valid_mask = np.ones(n_genes, dtype=bool)

        self.prepare_embedding_parameters()
[docs]
class TorchTensorFg(PretrainedFg):
    """Mapping of gene names to pretrained embeddings stored as PyTorch
    tensors."""

    def load_embeddings(self):
        """Read the saved gene-name-to-tensor mapping and convert each tensor to NumPy.

        Returns:
            A mapping from gene names to NumPy embedding vectors.

        """
        tensors_by_gene = torch.load(self.embedding_filepath, weights_only=True)

        arrays_by_gene = {}
        for gene_name, tensor in tensors_by_gene.items():
            # Detach and move to CPU before converting to NumPy.
            arrays_by_gene[gene_name] = tensor.detach().cpu().numpy()
        return arrays_by_gene
[docs]
class CSVFg(PretrainedFg):
    """Mapping of gene names to pretrained embeddings stored as delimited text.

    The file is read as whitespace-separated values without a header row; the
    first column holds the gene name and the remaining columns hold the
    embedding components.
    """

    def load_embeddings(self):
        """Read the text file and return one embedding vector per gene name.

        Rows are paired with their gene name by position, so every value is a
        1-D vector even when a gene name occurs more than once in the file (the
        last occurrence wins).

        Returns:
            A mapping from gene names to NumPy embedding vectors.

        """
        raw_gene_embedding_dataframe = pd.read_csv(self.embedding_filepath, sep=r"\s+", header=None, index_col=0)
        return dict(zip(raw_gene_embedding_dataframe.index, raw_gene_embedding_dataframe.to_numpy()))
class Gene2VecFg(TorchTensorFg):
    """Mapping of gene names to pretrained Gene2Vec embeddings.

    Adds no behaviour of its own; everything is inherited from `TorchTensorFg`.
    """
class GenePTFg(TorchTensorFg):
    """Mapping of gene names to pretrained GenePT embeddings.

    Adds no behaviour of its own; everything is inherited from `TorchTensorFg`.
    """
class ESM2Fg(TorchTensorFg):
    """Mapping of gene names to pretrained ESM2 embeddings.

    Adds no behaviour of its own; everything is inherited from `TorchTensorFg`.
    """
class HyenaDNAFg(TorchTensorFg):
    """Mapping of gene names to pretrained HyenaDNA embeddings.

    Adds no behaviour of its own; everything is inherited from `TorchTensorFg`.
    """