Source code for Heimdall.order
from abc import ABC, abstractmethod
import numpy as np
from numpy.typing import NDArray
from Heimdall.fc import Fc
[docs]
class Order(ABC):
def __init__(
self,
fc: Fc,
):
self.fc = fc
@abstractmethod
def __call__(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
) -> NDArray:
"""Order cell tokens using metadata.
Gene tokens can be reordered based on e.g. expression level, chromosome position, etc.
Args:
cell_tokenization: the stacked gene identity- and gene expression-based tokenization
of a cell.
"""
[docs]
class ExpressionOrder(Order):
def __call__(self, identity_inputs: NDArray, expression_inputs: NDArray) -> NDArray:
"""Order cell tokens using metadata.
Gene tokens are reordered based on expression level.
Args:
cell_tokenization: the stacked gene identity- and gene expression-based tokenization
of a cell.
"""
if "medians" in self.fc.adata.var:
expression_inputs = expression_inputs - self.fc.adata.var["medians"].iloc[identity_inputs].values
# Sort non-zero values in descending order
gene_order = np.argsort(expression_inputs)[::-1] # Indices for sorting descending
return gene_order
[docs]
class RandomOrder(Order):
def __call__(self, identity_inputs: NDArray, expression_inputs: NDArray) -> NDArray:
# TODO: consider cleaning up sampling (just sample all nonzero and all zero, then concat
(nonzero_indices,) = np.where(expression_inputs != 0)
(zero_indices,) = np.where(expression_inputs == 0)
# First: sample/reorder nonzero expression tokens
# num_nonzero_to_sample = min(len(nonzero_indices), self.fc.max_input_length)
num_nonzero = len(nonzero_indices)
num_zero = len(zero_indices)
# selected_nonzero = self.fc.rng.choice(nonzero_indices, num_nonzero_to_sample, replace=False)
selected_nonzero = self.fc.rng.choice(nonzero_indices, num_nonzero, replace=False)
# If needed: sample zero-expression tokens to fill up
# num_remaining = self.fc.max_input_length - num_nonzero_to_sample
# if num_remaining > 0:
if num_zero > 0:
selected_zero = self.fc.rng.choice(zero_indices, num_zero, replace=False)
gene_order = np.concatenate([selected_nonzero, selected_zero])
else:
gene_order = selected_nonzero
# Optionally shuffle to avoid position bias, but we dont need to because the gene ids are the position
# self.rng.shuffle(final_indices)
return gene_order
[docs]
class ChromosomeOrder(Order):
def __call__(self, identity_inputs: NDArray, expression_inputs: NDArray) -> NDArray:
"""Order cell tokens using metadata.
Gene tokens are reordered based on chromosome location.
Args:
cell_tokenization: the stacked gene identity- and gene expression-based tokenization
of a cell.
"""
choosen_chrom = self.fc.chroms.iloc[identity_inputs]
unique_chromosomes = np.unique(choosen_chrom)
self.fc.shuffled_chromosomes = self.fc.rng.permutation(unique_chromosomes)
gene_order = np.zeros(len(identity_inputs), dtype=np.int32)
for chromosome in self.fc.shuffled_chromosomes:
(chromosome_index,) = np.where(choosen_chrom == chromosome)
sort_by_start = np.argsort(
self.fc.starts[chromosome_index],
) # start chromosome_indexations for this chromsome
gene_order[chromosome_index] = chromosome_index[sort_by_start]
return gene_order