from abc import ABC, abstractmethod
import numpy as np
from numpy.typing import NDArray
from Heimdall.fc import Fc
[docs]
class Tailor(ABC):
def __init__(
self,
fc: Fc,
):
self.fc = fc
[docs]
@abstractmethod
def pad(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
"""Pad tokenization that is smaller than desired input length.
Args:
cell_tokenization: the stacked gene identity- and gene expression-based tokenization
dof a cell.
"""
(input_length,) = identity_inputs.shape
padding_args = {
"pad_width": ((0, self.fc.max_input_length - input_length)),
"mode": "constant",
"constant_values": (0, np.nan),
}
padded_identity_inputs = np.pad(
identity_inputs.astype(self.fc.float_dtype),
**padding_args,
)
padded_expression_inputs = np.pad(
expression_inputs.astype(self.fc.float_dtype),
**padding_args,
)
padded_identity_inputs[np.isnan(padded_identity_inputs).nonzero()] = self.fc.fg.pad_value
padded_expression_inputs[np.isnan(padded_expression_inputs).nonzero()] = self.fc.fe.pad_value
return padded_identity_inputs, padded_expression_inputs
[docs]
@abstractmethod
def limit(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
"""Limit tokenization that exceeds the desired input length.
Args:
cell_tokenization: the stacked gene identity- and gene expression-based tokenization
of a cell.
"""
identity_inputs = identity_inputs[: self.fc.max_input_length].astype(self.fc.float_dtype)
expression_inputs = expression_inputs[: self.fc.max_input_length].astype(self.fc.float_dtype)
return identity_inputs, expression_inputs
def __call__(self, identity_inputs: NDArray, expression_inputs: NDArray, gene_order: NDArray) -> NDArray:
(input_length,) = identity_inputs.shape
if input_length >= self.fc.max_input_length:
identity_inputs, expression_inputs = self.limit(identity_inputs, expression_inputs, gene_order)
# print(f"{identity_inputs=}")
# print(f"{expression_inputs=}")
(input_length,) = identity_inputs.shape
if input_length < self.fc.max_input_length:
identity_inputs, expression_inputs = self.pad(identity_inputs, expression_inputs, gene_order)
return identity_inputs, expression_inputs
[docs]
class ReorderTailor(Tailor):
[docs]
def limit(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
identity_inputs = identity_inputs[gene_order]
expression_inputs = expression_inputs[gene_order]
return super().limit(identity_inputs, expression_inputs, gene_order)
[docs]
def pad(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
identity_inputs = identity_inputs[gene_order]
expression_inputs = expression_inputs[gene_order]
return super().pad(identity_inputs, expression_inputs, gene_order)
[docs]
class WeightedResampleTailor(Tailor):
def __init__(self, fc: Fc, sample_size: int):
self.sample_size = sample_size
super().__init__(fc=fc)
[docs]
def weighted_resampling(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
"""Weighted sampling."""
weights = np.log1p(expression_inputs)
weights /= np.sum(weights)
resampled_indices = self.fc.rng.choice(
len(identity_inputs),
size=self.sample_size,
p=weights,
replace=True,
)
resampled_identity_inputs = identity_inputs[resampled_indices]
resampled_expression_inputs = expression_inputs[resampled_indices]
return resampled_identity_inputs, resampled_expression_inputs
[docs]
def limit(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
return super().limit(identity_inputs, expression_inputs, gene_order)
[docs]
def pad(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
return super().pad(identity_inputs, expression_inputs, gene_order)
def __call__(self, identity_inputs: NDArray, expression_inputs: NDArray, gene_order: NDArray) -> NDArray:
identity_inputs, expression_inputs = self.weighted_resampling(identity_inputs, expression_inputs, gene_order)
return super().__call__(identity_inputs, expression_inputs, gene_order)
[docs]
class ChromosomeBlockTailor(Tailor):
"""
Chromosome grouping without any resampling.
- Groups genes into blocks per chromosome with open/close tokens.
- Within each chromosome, genes are ordered according to `gene_order`.
- Uses Tailor.limit (default truncation) and Tailor.pad.
"""
def __init__(self, fc: Fc):
super().__init__(fc=fc)
def _group_by_chromosomes(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
# Number of original tokens (no sampling)
(input_length,) = identity_inputs.shape
# Pre-compute ranks (so lower rank = earlier)
# Matches your pattern: np.argsort(gene_order) -> ranks used to sort within chrom
gene_ranks = np.argsort(gene_order)
# Map each gene id to its chromosome
# Assumes `self.fc.chroms` is a pandas Series indexed by gene id
chrom_of_gene = self.fc.chroms.iloc[identity_inputs]
num_chromosomes = len(self.fc.shuffled_chromosomes)
raw_sequence_length = input_length + 2 * num_chromosomes # open + close for each chrom
grouped_gene_tokenization = np.full(
raw_sequence_length, self.fc.fg.pad_value, dtype=self.fc.float_dtype,
)
grouped_expression_tokenization = np.full(
raw_sequence_length, self.fc.fe.pad_value, dtype=self.fc.float_dtype,
)
sequence_index = 0
for chromosome in self.fc.shuffled_chromosomes:
# indices of genes belonging to this chromosome
chrom_idx = np.where(chrom_of_gene == chromosome)[0]
# open token for this chromosome
placeholder_id = -(chromosome + self.fc.chrom_token_offset + 1)
grouped_gene_tokenization[sequence_index] = placeholder_id
grouped_expression_tokenization[sequence_index] = placeholder_id
sequence_index += 1
# order genes within the chromosome using ranks derived from gene_order
if chrom_idx.size > 0:
chrom_gene_ranks = gene_ranks[chrom_idx]
chrom_order = np.argsort(chrom_gene_ranks)
chrom_genes = identity_inputs[chrom_idx][chrom_order]
chrom_exprs = expression_inputs[chrom_idx][chrom_order]
n = chrom_genes.shape[0]
grouped_gene_tokenization[sequence_index : sequence_index + n] = chrom_genes
grouped_expression_tokenization[sequence_index : sequence_index + n] = chrom_exprs
sequence_index += n
# close token
grouped_gene_tokenization[sequence_index] = -self.fc.chrom_token_offset
grouped_expression_tokenization[sequence_index] = -self.fc.chrom_token_offset
sequence_index += 1
# (sequence_index should equal raw_sequence_length)
return grouped_gene_tokenization, grouped_expression_tokenization
# Use default truncation (like ReorderTailor)
[docs]
def limit(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
return super().limit(identity_inputs, expression_inputs, gene_order)
# Use default padding (like ReorderTailor)
[docs]
def pad(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
return super().pad(identity_inputs, expression_inputs, gene_order)
def __call__(
self,
identity_inputs: NDArray,
expression_inputs: NDArray,
gene_order: NDArray,
) -> tuple[NDArray, NDArray]:
# Build chromosome-blocked sequence (no resampling)
identity_inputs, expression_inputs = self._group_by_chromosomes(
identity_inputs, expression_inputs, gene_order,
)
# Then apply the standard Tailor length control (truncate/pad)
return super().__call__(identity_inputs, expression_inputs, gene_order)