Source code for phasegen.state_space

"""
State space classes and utilities. The two main state space classes are
:class:`LineageCountingStateSpace` and :class:`BlockCountingStateSpace`.
"""

import logging
import time
from abc import ABC, abstractmethod
from functools import cached_property
from itertools import product
from typing import List, Tuple, Dict, Callable, cast

import numpy as np
from tqdm import tqdm

from .coalescent_models import CoalescentModel, StandardCoalescent, BetaCoalescent, DiracCoalescent
from .settings import Settings
from .demography import Epoch
from .lineage import LineageConfig
from .locus import LocusConfig
from .state_space_numba import HAS_NUMBA, build_rate_matrix
from .state_space_old import StateSpace as OldStateSpace, LineageCountingStateSpace as OldLineageCountingStateSpace, \
    BlockCountingStateSpace as OldBlockCountingStateSpace

#: Package-level logger; each state space derives a child logger named after its class from it.
logger = logging.getLogger(name='phasegen')


def _numba_model_params(model: CoalescentModel) -> Tuple[int, float, float, float]:
    """
    Encode a coalescent model as the flat parameter tuple consumed by the numba kernels.

    :param model: Coalescent model.
    :return: ``(model_id, alpha, psi, c)``. ``model_id`` is 1 for the Beta coalescent, 2 for the Dirac
        coalescent and 0 for anything else (treated as the standard coalescent). Parameters that do not
        apply to the selected model are ``0.0``.
    """
    if isinstance(model, BetaCoalescent):
        packed = (1, model.alpha, 0.0, 0.0)
    elif isinstance(model, DiracCoalescent):
        packed = (2, 0.0, model.psi, model.c)
    else:
        # fallback: no model-specific parameters
        packed = (0, 0.0, 0.0, 0.0)

    return packed


class StateSpace(ABC):
    """
    State space.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create a rate matrix.

        :param lineage_config: Population configuration.
        :param locus_config: Locus configuration. One locus is used by default.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch.
        """
        if locus_config is None:
            locus_config = LocusConfig()

        if model is None:
            model = StandardCoalescent()

        if epoch is None:
            epoch = Epoch()

        #: Logger
        self._logger = logger.getChild(self.__class__.__name__)

        #: Coalescent model
        self.model: CoalescentModel = model

        #: Population configuration
        self.lineage_config: LineageConfig = lineage_config

        #: Locus configuration
        self.locus_config: LocusConfig = locus_config

        #: Epoch
        self.epoch: Epoch = epoch

        #: Cached rate matrices (transition graph and state list per epoch)
        self._cache: Dict[Epoch, Tuple[Dict[Tuple['State', 'State'], Tuple[float, str]], List['State']]] = {}

        # time in seconds to compute original rate matrix
        self.time: float | None = None

    @cached_property
    def states(self) -> List['State']:
        """
        The states.
        """
        start = time.time()

        if self._use_numba():
            states, S = self._construct_numba()

            # the states are epoch-independent, but prime the current epoch's rate matrix to avoid rebuilding it
            self.__dict__.setdefault('S', S)
        else:
            # get all possible transitions
            transitions, states = self.get_transitions()

            # cache rate matrix if specified
            if Settings.cache_epochs:
                self._cache[self.epoch] = (transitions, states)

        # record time to compute rate matrix
        self.time = time.time() - start

        return states

    @cached_property
    def lineages(self) -> np.ndarray:
        """
        The lineage configurations. Each configuration describes the lineages per block, deme and locus, i.e.,
        ``[[[a_ijk]]]`` for block ``i``, deme ``j`` and locus ``k``.
        """
        return np.array([s.lineages for s in self.states])

    @cached_property
    def linked(self) -> np.ndarray:
        """
        The linked lineages per block, deme and locus.
        """
        return np.array([s.linked for s in self.states])

    @cached_property
    def unlinked(self) -> np.ndarray:
        """
        Unlinked lineages.
        """
        return self.lineages - self.linked

    @abstractmethod
    def _get_old(self) -> OldStateSpace:
        """
        Get the old state space.
        """
        pass

    def _get_old_ordering(self) -> List[int]:
        """
        Get the ordering of the states in the old state space relative to the new state space.

        :return: Ordering of the states in the old state space.
        """
        old = self._get_old()

        # reorder the states of s2 to match s1
        return cast(List[int], [
            np.where(((old.states == self.lineages[i]) & (old.linked == self.linked[i])).all(axis=(1, 2, 3)))[0][0]
            for i in range(self.k)
        ])

    @staticmethod
    def _get_partitions(n: int, k: int) -> List[List[int]]:
        """
        Find all vectors of length `k` with non-negative integers that sum to `n`.

        :param n: The sum.
        :param k: The length of the vectors.
        :return: All vectors of length `k` with non-negative integers that sum to `n`.
        """
        if k == 0:
            return [[]]

        if k == 1:
            return [[n]]

        vectors = []
        for i in range(n + 1):
            for vector in StateSpace._get_partitions(n - i, k - 1):
                vectors.append(vector + [i])

        return vectors

    def get_transitions(self) -> Tuple[Dict[Tuple['State', 'State'], Tuple[float, str]], List['State']]:
        """
        Explore the state space breadth-first from the initial state and collect all transitions.

        :return: Mapping from ``(source, target)`` to ``(rate, kind)``, and the list of visited states in
            discovery order (which defines the state ordering of the rate matrix).
        """
        sources = [self._get_initial()]
        transitions = {}

        # visited states in discovery order, plus a set for O(1) membership tests
        # (states are hashable, as they are used as dictionary keys)
        visited = []
        seen = set()

        # number of states and transitions
        i, j = 0, 0

        # state counts at which to warn about slow computations
        levels = {1000: 'slow', 5000: 'very slow', 25000: 'extremely slow'}

        # backward compatibility
        pbar = tqdm(desc=f'{self.__class__.__name__}: transitions', disable=not Settings.use_pbar)

        try:
            while True:
                targets_new = {}

                for source in sources:

                    # skip if source has been visited already
                    if source in seen:
                        continue

                    # get all possible transitions from source
                    targets = self.transition.transit(source)

                    # add visited source state
                    visited.append(source)
                    seen.add(source)

                    # add transitions to dictionary
                    for target, transition in targets.items():
                        transitions[(source, target)] = transition
                        j += 1
                        pbar.update(1)

                    # add targets to new targets
                    targets_new |= targets

                    # increment state counter
                    i += 1

                    if i in levels:
                        self._logger.warning(
                            f'State space size exceeds {i} states. Computations may be {levels[i]}.'
                        )

                # break if no more targets
                if len(targets_new) == 0:
                    break

                # take new targets as source states
                sources = tuple(targets_new.keys())

                pbar.set_description_str(f'{self.__class__.__name__}: ({i} states, {j} transitions)')
        finally:
            # release the progress bar even if a transition raises
            pbar.close()

        # warn if state space is large
        if (k := len(visited)) > 400:
            self._logger.warning(f'State space is large ({k} states, {len(transitions)} transitions). '
                                 f'Note that the computation time increases '
                                 f'exponentially with the number of states.')

        return transitions, visited

    @cached_property
    def e(self) -> np.ndarray:
        """
        Vector with ones of size ``k``.
        """
        return np.ones(self.k)

    @cached_property
    def S(self) -> np.ndarray:
        """
        Intensity matrix.
        """
        return self._get_rate_matrix()

    @cached_property
    def alpha(self) -> np.ndarray:
        """
        Initial state vector.
        """
        pops = self.lineage_config._get_initial_states(self)
        loci = self.locus_config._get_initial_states(self)

        # combine initial states
        alpha = pops * loci

        # return normalized vector
        # normalization ensures that the initial state vector is a probability distribution
        # as we may have multiple initial states
        return alpha / alpha.sum()

    @cached_property
    def k(self) -> int:
        """
        Number of states.
        """
        return len(self.states)

    @cached_property
    def transition(self) -> 'Transition':
        """
        Transition.
        """
        return Transition(self)

    def update_epoch(self, epoch: Epoch):
        """
        Update the epoch in place. A cached rate matrix is rescaled when there is a single population and a
        single locus, and dropped otherwise.

        :param epoch: Epoch.
        """
        # only remove cached properties if epoch has changed
        if self.epoch != epoch:

            # update S by rescaling if already cached, provided there is only one population and one locus
            if (
                    self.lineage_config.n_pops == 1 and
                    self.locus_config.n == 1 and
                    'S' in self.__dict__
            ):
                self.S *= self._get_scaling_factor(self.epoch, epoch)
            else:
                self.drop_S()

            self.epoch = epoch

    def _get_scaling_factor(self, epoch_prev: Epoch, epoch_next: Epoch) -> float:
        """
        Get the scaling factor for the rate matrix when changing epochs.

        :param epoch_prev: Previous epoch.
        :param epoch_next: Next epoch.
        :return: Scaling factor.
        """
        pop_prev = epoch_prev.pop_sizes[epoch_prev.pop_names[0]]
        pop_next = epoch_next.pop_sizes[epoch_next.pop_names[0]]

        return self.model._get_timescale(pop_prev) / self.model._get_timescale(pop_next)

    def __eq__(self, other):
        """
        Check if two state spaces are equal. We do not check for equivalence of the epochs
        as we can update the epoch of a state space dynamically.

        :param other: Other state space
        :return: Whether the two state spaces are equal
        """
        return (
                self.__class__ == other.__class__ and
                self.lineage_config == other.lineage_config and
                self.locus_config == other.locus_config and
                self.model == other.model
        )

    def drop_S(self):
        """
        Drop the current rate matrix.
        """
        try:
            # noinspection all
            del self.S
        except AttributeError:
            # the rate matrix has not been computed yet
            pass

    def drop_cache(self):
        """
        Drop the rate matrix cache and current rate matrix.
        """
        self.drop_S()
        self._cache = {}

    @abstractmethod
    def _get_initial(self):
        """
        Get the initial state.
        """
        pass

    def _is_absorbing(self, state: 'State') -> bool:
        """
        Whether the given state is absorbing. By default this is the single-process absorbing condition
        (a single remaining lineage); state spaces with a different notion of absorption (e.g. two loci,
        which are absorbed once both have reached their MRCA) override this.

        :param state: State.
        :return: Whether the state is absorbing.
        """
        return state.is_absorbing()

    def _use_numba(self) -> bool:
        """
        Whether numba-accelerated construction applies: numba is available and enabled, there is a single locus,
        and the state space is one of the supported types (the 2-locus recombination path stays on the Python
        construction).
        """
        if not (HAS_NUMBA and Settings.use_numba):
            return False

        # the single-locus state spaces (numbered 0/1); the two-locus space is handled separately below
        if self.locus_config.n == 1 and type(self) in (
                LineageCountingStateSpace, BlockCountingStateSpace, JointBlockCountingStateSpace):
            return True

        # the two-locus block-counting state space (recombination)
        return type(self) is TwoLocusBlockCountingStateSpace

    def _numba_kind(self) -> int:
        """
        Kernel selector: 0 lineage-counting, 1 block-/joint-counting, 2 two-locus block-counting.
        """
        return 1

    def _numba_recombination(self) -> Tuple[float, np.ndarray | None, np.ndarray | None]:
        """
        Recombination parameters for the kernel: ``(recombination_rate, recomb0, recomb1)`` where ``recomb_l[b]``
        is the block index that block ``b`` contributes to locus ``l`` when it recombines. Only the two-locus
        state space uses these; by default there is no recombination.
        """
        return 0.0, None, None

    def _numba_block_vectors(self, n_blocks: int) -> np.ndarray:
        """
        Block labels passed to the kernel; for block-counting, block ``i`` represents lineages subtending
        ``i + 1`` samples.
        """
        return np.arange(1, n_blocks + 1, dtype=np.int64).reshape(-1, 1)

    def _construct_numba(self) -> Tuple[List['State'], np.ndarray]:
        """
        Build the states and rate matrix for the current epoch via the numba kernel.

        :return: The states (in kernel discovery order) and the dense intensity matrix.
        """
        init = self._get_initial()
        n_demes = init.lineages.shape[1]
        n_blocks = init.lineages.shape[2]
        pops = self.lineage_config.pop_names

        # dense migration matrix with zero diagonal
        mig = np.zeros((n_demes, n_demes))
        for i, a in enumerate(pops):
            for j, b in enumerate(pops):
                if i != j:
                    mig[i, j] = self.epoch.migration_rates[(a, b)]

        timescales = np.array([self.model._get_timescale(self.epoch.pop_sizes[p]) for p in pops])
        model_id, alpha, psi, c = _numba_model_params(self.model)
        recomb_rate, recomb0, recomb1 = self._numba_recombination()

        rows, S = build_rate_matrix(
            initial=init.lineages.reshape(-1),
            kind=self._numba_kind(),
            n_demes=n_demes,
            n_blocks=n_blocks,
            mig=mig,
            timescales=timescales,
            model_id=model_id,
            alpha=alpha,
            psi=psi,
            c=c,
            block_vectors=self._numba_block_vectors(n_blocks),
            recomb_rate=recomb_rate,
            recomb0=recomb0,
            recomb1=recomb1,
        )

        lin_shape = init.lineages.shape
        lin_dtype = init.lineages.dtype
        linked = init.linked  # all-zero for a single locus
        states = [State((row.reshape(lin_shape).astype(lin_dtype), linked.copy())) for row in rows]

        return states, S

    def _get_rate_matrix(self) -> np.ndarray:
        """
        Get the rate matrix.

        TODO don’t compute transitions twice for disabled caching

        :return: The rate matrix.
        """
        if self._use_numba():
            states, S = self._construct_numba()

            # keep an already computed state list; the kernel order is used otherwise
            self.__dict__.setdefault('states', states)

            return S

        # check if epoch is in cache
        if Settings.cache_epochs and self.epoch in self._cache:
            transitions, states = self._cache[self.epoch]
        else:
            # get all possible transitions
            transitions, states = self.get_transitions()

            # cache rate matrix if specified
            if Settings.cache_epochs:
                self._cache[self.epoch] = (transitions, states)

        return self._graph_to_matrix(transitions, states)

    @staticmethod
    def _graph_to_matrix(
            transitions: Dict[Tuple['State', 'State'], Tuple[float, str]],
            states: List['State']
    ) -> np.ndarray:
        """
        Convert transition graph to rate matrix.

        :param transitions: Transitions.
        :param states: States.
        :return: Rate matrix.
        """
        S = np.zeros((len(states), len(states)))

        # order of original states
        ordering = {s: i for i, s in enumerate(states)}

        # fill rate matrix
        for (source, target), transition in transitions.items():
            S[ordering[source], ordering[target]] = transition[0]

        # fill diagonal with negative sum of row
        S[np.diag_indices_from(S)] = -np.sum(S, axis=1)

        return S

    def get_sparsity(self) -> float:
        """
        Get the sparsity of the rate matrix.

        :return: The sparsity.
        """
        return 1 - np.count_nonzero(self.S) / self.S.size

    def _get_color_state(self, i: int) -> str:
        """
        Get the graphviz fill color of the state indexed by ``i``: red for absorbing states, green for states
        with positive initial probability and blue otherwise.

        Absorption is decided by :meth:`_is_absorbing` so that state spaces overriding it (e.g. two loci)
        are colored consistently with how transitions are generated.
        """
        if self._is_absorbing(self.states[i]):
            return '#f1807e'

        if self.alpha[i] > 0:
            return 'lightgreen'

        return 'lightblue'

    def plot_rates(
            self,
            file: str,
            view: bool = True,
            cleanup: bool = False,
            dpi: int = 400,
            ratio: float = 0.6,
            background_color: str = 'white',
            extension: str = 'png',
            format_state: Callable[[Tuple[np.ndarray, np.ndarray]], str] = None,
            format_transition: Callable[[float, str], str] = None
    ):
        """
        Plot the rate matrix using graphviz.
        Note that graphviz must be installed which is an external dependency.

        :param file: File to save plot to.
        :param view: Whether to view the plot.
        :param cleanup: Whether to remove the source file.
        :param dpi: Dots per inch.
        :param ratio: Aspect ratio.
        :param background_color: Background color.
        :param extension: File format.
        :param format_state: Function to format a state, called with the state's data tuple.
        :param format_transition: Function to format a transition, called with its rate and kind.
        """
        import graphviz

        if format_state is None:
            def format_state(s: Tuple[np.ndarray, np.ndarray]) -> str:
                """
                Format state.

                :param s: State.
                :return: Formatted state.
                """
                return str(s[0]).replace('\n', '') + '\n' + str(s[1]).replace('\n', '')

        if format_transition is None:
            def format_transition(rate: float, kind: str) -> str:
                """
                Format transition.

                :param rate: Rate.
                :param kind: Kind.
                :return: Formatted transition.
                """
                return f' {kind}: ' + '{:.2f}'.format(rate).rstrip('0').rstrip('.')

        graph = graphviz.Digraph()

        # add nodes
        for i, state in enumerate(self.states):
            graph.node(
                name=format_state(state.data),
                fillcolor=self._get_color_state(i),
                style='filled'
            )

        transitions, _ = self.get_transitions()

        # add non-zero edges, skipping those leaving absorbing states (as defined by this state space)
        for (source, target), transition in transitions.items():
            if not self._is_absorbing(source):
                graph.edge(
                    tail_name=format_state(source.data),
                    head_name=format_state(target.data),
                    label=format_transition(*transition),
                    color=Transition._colors[transition[1]],
                    fontcolor=Transition._colors[transition[1]]
                )

        graph.graph_attr['dpi'] = str(dpi)
        graph.graph_attr['ratio'] = str(ratio)
        graph.graph_attr['bgcolor'] = background_color

        graph.render(
            filename=file,
            view=view,
            cleanup=cleanup,
            format=extension
        )
class LineageCountingStateSpace(StateSpace):
    """
    Default rate matrix with one state per number of lineages in each deme and locus.
    """

    def _get_initial(self) -> 'State':
        """
        Build the initial state: all ``n`` sampled lineages are placed in the first deme of every locus,
        using a single block, with an all-zero second array.
        """
        shape = (self.locus_config.n, self.lineage_config.n_pops, 1)

        lineages = np.zeros(shape, dtype=int)
        linked = np.zeros(shape, dtype=int)
        lineages[:, 0, 0] = self.lineage_config.n

        return State((lineages, linked))

    def _numba_kind(self) -> int:
        """
        Select the lineage-counting kernel.
        """
        return 0

    def _numba_block_vectors(self, n_blocks: int) -> np.ndarray:
        """
        Single dummy block label; lineage counting has only one block.
        """
        return np.array([[1]], dtype=np.int64)

    def _get_old(self) -> OldLineageCountingStateSpace:
        """
        Build the equivalent legacy state space from this one's configuration.
        """
        return OldLineageCountingStateSpace(
            epoch=self.epoch,
            model=self.model,
            locus_config=self.locus_config,
            lineage_config=self.lineage_config,
        )
class BlockCountingStateSpace(StateSpace):
    r"""
    Rate matrix for the block-counting state space, with one state per sample configuration.
    A block-counting state is a vector of length ``n`` whose ``i``-th element is the number of lineages
    subtending ``i`` sampled lineages in the coalescent tree:

    .. math::
        (a_1,...,a_n) \in \mathbb{Z}_+^n : \sum_{i=1}^{n} i a_i = n.

    There is one such vector per deme and per locus. Because it distinguishes tree topologies, this state
    space is used when computing statistics based on the SFS.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create a rate matrix.

        :param lineage_config: Population configuration.
        :param locus_config: Locus configuration. One locus is used by default.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch
        :raises NotImplementedError: If more than one locus is requested.
        """
        # multiple loci would make the block-counting state space far too complex
        multi_locus = locus_config is not None and locus_config.n > 1
        if multi_locus:
            raise NotImplementedError('Block-counting state space only supports one locus.')

        super().__init__(
            lineage_config=lineage_config,
            locus_config=locus_config,
            model=model,
            epoch=epoch
        )

    def _get_initial(self) -> 'State':
        """
        Build the initial state: ``n`` singleton lineages (block 0) in the first deme of every locus, with an
        all-zero second array.
        """
        shape = (self.locus_config.n, self.lineage_config.n_pops, self.lineage_config.n)

        lineages = np.zeros(shape, dtype=int)
        linked = np.zeros(shape, dtype=int)
        lineages[:, 0, 0] = self.lineage_config.n

        return State((lineages, linked))

    @staticmethod
    def _traverse(
            start_state: int,
            start_p: float,
            S: np.ndarray,
            absorbing_states: np.ndarray,
            lineage_counts: np.ndarray,
    ) -> np.ndarray:
        """
        Propagate probability mass from a starting state through the embedded jump chain, visiting states in
        order of decreasing lineage count.

        :param start_state: Index of the starting state.
        :param start_p: Probability of starting in the starting state.
        :param S: Transition matrix.
        :param absorbing_states: Indices of absorbing states.
        :param lineage_counts: Number of lineages in each state.
        :return: State probabilities conditioned on the number of lineages.
        """
        mass = np.zeros(S.shape[0])
        mass[start_state] = start_p

        indices = np.arange(lineage_counts.shape[0])

        for count in sorted(set(lineage_counts), reverse=True):
            for src in indices[lineage_counts == count]:

                # no outgoing jumps from absorbing states or states without an exit rate
                if src in absorbing_states or S[src, src] == 0:
                    continue

                jump_probs = S[src] / -S[src, src]

                for dst in np.flatnonzero(jump_probs > 0):
                    mass[dst] += mass[src] * jump_probs[dst]

        return mass

    @cached_property
    def _state_probs(self) -> np.ndarray:
        """
        State probabilities conditioned on the number of lineages.

        These can be used to flatten the block-counting state space to a lineage-counting one by weighting
        lineages by the probability of being in each corresponding block-counting state. This is only valid for
        one-population, one-locus state spaces under the standard coalescent, or for MMCs provided there is only
        one epoch and we accumulate until absorption.

        :return: State probabilities conditioned on the number of lineages.
        """
        self._logger.debug('Calculating state probabilities conditioned on the number of lineages.')

        absorbing = np.where([state.is_absorbing() for state in self.states])[0]
        total = np.zeros(self.k)
        counts = np.array([state.lineages.sum() for state in self.states])
        indices = np.arange(counts.shape[0])

        # accumulate contributions from every state with positive initial probability
        for start, p in zip(indices[self.alpha > 0], self.alpha[self.alpha > 0]):
            total += self._traverse(start, p, self.S, absorbing, counts)

        return total

    def _get_old(self) -> OldBlockCountingStateSpace:
        """
        Build the equivalent legacy state space from this one's configuration.
        """
        return OldBlockCountingStateSpace(
            epoch=self.epoch,
            model=self.model,
            locus_config=self.locus_config,
            lineage_config=self.lineage_config,
        )
class JointBlockCountingStateSpace(StateSpace):
    r"""
    Rate matrix for the joint (multi-population) site-frequency spectrum.

    This generalizes :class:`BlockCountingStateSpace`. There, a block is a single *size class* ``i`` (the
    number of sampled lineages a lineage subtends), which loses the information of *which* population those
    descendants came from. The joint SFS bins mutations by their allele frequency in every population at once,
    so each block is instead the **descendant vector**

    .. math::
        v = (v_0, \dots, v_{P-1}), \quad 0 \le v_p \le n_p, \quad 1 \le \sum_p v_p \le n,

    i.e. the number of descendants a lineage subtends from each population ``p`` (its "deme of origin"
    composition). A state counts, per locus and per *current* deme of residence, how many lineages carry each
    descendant vector.

    .. note::
        Splitting each size class by deme of origin enlarges the block axis
        (``n_blocks = prod(n_p + 1) - 1``) and makes the number of reachable states grow combinatorially,
        much faster than in the single-population block-counting space. This state space is therefore only
        practical for small per-population sample sizes. Only one locus is supported.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create a rate matrix.

        :param lineage_config: Population configuration.
        :param locus_config: Locus configuration. One locus is used by default.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch.
        :raises NotImplementedError: If more than one locus is requested.
        """
        # descendant vectors do not extend to multiple loci
        multi_locus = locus_config is not None and locus_config.n > 1
        if multi_locus:
            raise NotImplementedError('Joint block-counting state space only supports one locus.')

        super().__init__(
            lineage_config=lineage_config,
            locus_config=locus_config,
            model=model,
            epoch=epoch
        )

    @cached_property
    def block_configs(self) -> List[Tuple[int, ...]]:
        """
        Ordered list of all descendant vectors (block types) ``(v_0,...,v_{P-1})`` with ``0 <= v_p <= n_p``
        and at least one non-zero entry.
        """
        per_pop_ranges = [range(int(n_p) + 1) for n_p in self.lineage_config.lineages]

        # entries are non-negative, so "some entry is non-zero" is the same as "entries sum to at least 1"
        return [config for config in product(*per_pop_ranges) if any(config)]

    @cached_property
    def block_index(self) -> Dict[Tuple[int, ...], int]:
        """
        Map each descendant vector to its block index.
        """
        return {config: idx for idx, config in enumerate(self.block_configs)}

    @cached_property
    def block_vectors(self) -> np.ndarray:
        """
        Descendant vectors as an integer array of shape ``(n_blocks, n_pops)``.
        """
        return np.array(self.block_configs, dtype=int)

    def _numba_block_vectors(self, n_blocks: int) -> np.ndarray:
        """
        Pass the descendant vectors as block labels; a coalescence yields the block whose descendant vector is
        the sum of the merging ones.
        """
        return np.asarray(self.block_vectors, dtype=np.int64)

    @cached_property
    def n_blocks(self) -> int:
        """
        Number of block types (descendant vectors).
        """
        return len(self.block_configs)

    def _get_initial(self) -> 'State':
        """
        Build the initial state: the ``n_p`` lineages sampled from population ``p`` start in deme ``p`` with
        descendant vector ``e_p`` (the unit vector subtending a single sample from ``p``).
        """
        n_pops = self.lineage_config.n_pops
        shape = (self.locus_config.n, n_pops, self.n_blocks)

        lineages = np.zeros(shape, dtype=int)
        linked = np.zeros(shape, dtype=int)

        for p in range(n_pops):
            n_p = self.lineage_config.lineages[p]

            if not n_p > 0:
                continue

            unit = tuple(int(q == p) for q in range(n_pops))
            lineages[0, p, self.block_index[unit]] = n_p

        return State((lineages, linked))

    @cached_property
    def alpha(self) -> np.ndarray:
        """
        Initial state vector: the indicator of the single, origin-aware initial state.
        """
        start = self._get_initial()

        return np.array([state == start for state in self.states], dtype=float)

    def _get_old(self) -> OldStateSpace:
        """
        Not available: the joint block-counting state space has no legacy equivalent.
        """
        raise NotImplementedError('The joint block-counting state space has no legacy equivalent.')
class TwoLocusBlockCountingStateSpace(JointBlockCountingStateSpace):
    r"""
    Block-counting state space for two loci separated by recombination, used to compute the two-locus SFS.

    Each physical ancestral lineage is described by a vector ``(a_0, a_1)`` giving the number of sampled
    lineages it subtends at locus 0 and locus 1 (``a_l = 0`` meaning it is not ancestral at locus ``l``). This
    is the same representation as the joint (multi-population) state space, with **locus** playing the role of
    **population**, so coalescence is again vector addition (and reuses the model-agnostic merger rates,
    supporting Beta/Dirac too). The new ingredient is **recombination**: a linked lineage ``(a_0, a_1)`` with
    ``a_0 > 0`` and ``a_1 > 0`` splits into ``(a_0, 0)`` and ``(0, a_1)`` at rate ``r`` per linked lineage.
    The process is absorbed once both loci have reached their MRCA.

    A single population is currently supported (no migration), so the state shape collapses to
    ``(1, 1, n_blocks)``; linkage is encoded in the block vector itself, so the ``linked`` array is unused.

    .. note::
        ``n_blocks = (n + 1)^2 - 1`` and the number of reachable states grows quickly, so this is only
        practical for small sample sizes.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create the two-locus block-counting state space.

        :param lineage_config: Population configuration (a single population is currently supported).
        :param locus_config: Locus configuration; must specify exactly two loci.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch.
        :raises ValueError: If the locus configuration does not specify exactly two loci, or if more lineages
            are initially unlinked than were sampled.
        :raises NotImplementedError: If more than one population is given.
        """
        if locus_config is None or locus_config.n != 2:
            raise ValueError('The two-locus block-counting state space requires exactly two loci.')

        if lineage_config.n_pops != 1:
            raise NotImplementedError('The two-locus block-counting state space currently supports a single '
                                      'population (no migration).')

        # each unlinked sample contributes one (1, 0) and one (0, 1) lineage; with more unlinked samples than
        # sampled lineages a locus would carry more than n lineages, and merging them would produce descendant
        # vectors (a_0 > n) that do not exist in block_configs
        if int(locus_config.n_unlinked) > int(lineage_config.n):
            raise ValueError(
                f'The number of initially unlinked lineages ({locus_config.n_unlinked}) '
                f'cannot exceed the number of sampled lineages ({lineage_config.n}).'
            )

        # bypass JointBlockCountingStateSpace.__init__ (which rejects more than one locus)
        StateSpace.__init__(self, lineage_config=lineage_config, locus_config=locus_config, model=model, epoch=epoch)

    @cached_property
    def block_configs(self) -> List[Tuple[int, ...]]:
        """
        Ordered list of all two-locus descendant vectors ``(a_0, a_1)`` with ``0 <= a_l <= n`` and at least one
        non-zero entry.
        """
        n = int(self.lineage_config.n)

        return [c for c in product(range(n + 1), range(n + 1)) if sum(c) >= 1]

    def _get_initial(self) -> 'State':
        """
        Get the initial state: ``n - n_unlinked`` lineages of type ``(1, 1)`` (linked across both loci) plus,
        for each of the ``n_unlinked`` initially unlinked samples, a ``(1, 0)`` and a ``(0, 1)`` lineage.
        """
        n = int(self.lineage_config.n)
        n_unlinked = int(self.locus_config.n_unlinked)
        n_linked = max(n - n_unlinked, 0)

        data = tuple(np.zeros((1, 1, self.n_blocks), dtype=int) for _ in range(2))

        if n_linked > 0:
            data[0][0, 0, self.block_index[(1, 1)]] = n_linked

        if n_unlinked > 0:
            data[0][0, 0, self.block_index[(1, 0)]] += n_unlinked
            data[0][0, 0, self.block_index[(0, 1)]] += n_unlinked

        return State(data)

    def _is_absorbing(self, state: 'State') -> bool:
        """
        A two-locus state is absorbing once both loci have reached their MRCA, i.e. exactly one lineage carries
        ancestral material at locus 0 and exactly one carries it at locus 1 (covering both the single linked
        grand-MRCA ``(n, n)`` and the unlinked pair ``(n, 0) + (0, n)``).
        """
        lineages = state.lineages[0, 0]

        return all(int(lineages[self.block_vectors[:, locus] > 0].sum()) == 1 for locus in range(2))

    def _numba_kind(self) -> int:
        """
        Two-locus block-counting kernel.
        """
        return 2

    def _numba_recombination(self) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Recombination split maps: for each linked block ``(a_0, a_1)`` the indices of ``(a_0, 0)`` and
        ``(0, a_1)``; entries for blocks that cannot recombine are left at 0.
        """
        recomb0 = np.zeros(self.n_blocks, dtype=np.int64)
        recomb1 = np.zeros(self.n_blocks, dtype=np.int64)

        for b in range(self.n_blocks):
            a0, a1 = int(self.block_vectors[b, 0]), int(self.block_vectors[b, 1])

            if a0 > 0 and a1 > 0:
                recomb0[b] = self.block_index[(a0, 0)]
                recomb1[b] = self.block_index[(0, a1)]

        return self.locus_config.recombination_rate, recomb0, recomb1
[docs] class Transition: """ Class representing a transition between two states. """ #: Colors for different types of transitions _colors: Dict[str, str] = { 'recombination': 'orange', 'coalescence': 'darkgreen', 'locus_coalescence': 'darkgreen', 'linked_coalescence': 'darkgreen', 'unlinked_coalescence': 'darkgreen', 'mixed_coalescence': 'darkgreen', 'unlinked_coalescence+mixed_coalescence': 'darkgreen', 'mixed_coalescence+unlinked_coalescence': 'darkgreen', 'linked_migration': 'blue', 'unlinked_migration': 'blue', 'migration': 'blue', 'invalid': 'red' }
[docs] def __init__( self, state_space: StateSpace ): """ Initialize a transition. :param state_space: State space. """ #: State space. self.state_space: StateSpace = state_space
[docs] def transit(self, source: 'State') -> Dict['State', Tuple[float, str]]: """ Get all possible target states from the given source state. :param source: Source state. :return: All possible target states. """ targets: Dict['State', Tuple[float, str]] = {} targets |= self.migrate(source) if self.state_space._is_absorbing(source): return targets targets |= self.coalesce(source) targets |= self.recombine(source) return targets
[docs] @staticmethod def add_target(targets: Dict['State', Tuple[float, str]], target: 'State', rate: float, kind: str): """ Add a target state to the list of targets. :param targets: Dictionary of target states. :param target: New target state. :param rate: Rate of the transition. :param kind: Kind of the transition. """ if target in targets: targets[target] = (targets[target][0] + rate, targets[target][1] + '+' + kind) else: targets[target] = (rate, kind)
[docs] def coalesce(self, source: 'State') -> Dict['State', Tuple[float, str]]: """ Get all possible coalescent transitions from the given state. :param source: Source state. :return: All possible coalescent transitions from the given state. """ if isinstance(self.state_space, JointBlockCountingStateSpace): return self._coalesce_joint(source) targets: Dict['State', Tuple[float, str]] = {} pop_sizes = [self.state_space.epoch.pop_sizes[pop] for pop in self.state_space.lineage_config.pop_names] if source.n_loci == 1: locus = 0 for deme in range(source.n_demes): blocks = self.state_space.model.coalesce( self.state_space.lineage_config.n, source.lineages[locus, deme] ) for block, rate in blocks: target = source.copy() target.lineages[locus, deme] = block time_scale = self.state_space.model._get_timescale(pop_sizes[deme]) self.add_target(targets, target, rate / time_scale, 'coalescence') return targets if source.n_loci == 2: if not isinstance(self.state_space, LineageCountingStateSpace): raise NotImplementedError( 'Coalescence with recombination is only implemented for LineageCountingStateSpace.' 
) if not isinstance(self.state_space.model, StandardCoalescent): raise NotImplementedError('Coalescence with recombination is only implemented for StandardCoalescent.') bins = dict( linked=source.linked[0], unlinked1=source.unlinked[0], unlinked2=source.unlinked[1] ) for deme in range(source.n_demes): time_scale = self.state_space.model._get_timescale(pop_sizes[deme]) for ((class1, counts1), (class2, counts2)) in product(bins.items(), repeat=2): target = source.copy() # linked or unlinked coalescence if class1 == class2: # we need at least 2 lineages to coalesce if np.any(counts1[deme] < 2): continue # if the classes are the same, the counts are the same rate = self.state_space.model._get_rate(b=counts1[deme, 0], k=2) # unlinked coalescence in locus 1 if 'unlinked1' in class1: target.lineages[0, deme] -= 1 self.add_target(targets, target, rate / time_scale, 'unlinked_coalescence') # unlinked coalescence in locus 2 elif 'unlinked2' in class1: target.lineages[1, deme] -= 1 self.add_target(targets, target, rate / time_scale, 'unlinked_coalescence') # linked coalescence in both loci elif np.all(source.linked[:, deme] > 0): target.lineages[:, deme] -= 1 target.linked[:, deme] -= 1 self.add_target(targets, target, rate / time_scale, 'linked_coalescence') # mixed or locus coalescence # use lower than operator to ensure we only consider each case once elif class1 < class2: if counts1[deme] < 1 or counts2[deme] < 1: continue rate = counts1[deme, 0] * counts2[deme, 0] # mixed coalescence of linked and unlinked lineages if 'linked' in (class1, class2) and ('unlinked' in class1 or 'unlinked' in class2): locus = 0 if '1' in class1 or '1' in class2 else 1 # condition already checked above if target.lineages[locus, deme, 0] > 1: target.lineages[locus, deme, 0] -= 1 self.add_target(targets, target, rate / time_scale, 'mixed_coalescence') # locus coalescence of unlinked lineages else: # make sure we have unlinked lineages in both loci if np.all(source.unlinked[:, deme, 0] > 0): 
target.linked[:, deme, 0] += 1 self.add_target(targets, target, rate / time_scale, 'locus_coalescence') return targets raise NotImplementedError('Coalescence is not implemented for more than 2 loci.')
def _coalesce_joint(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Enumerate the coalescence events of the joint block-counting state space.

    Only lineages that sit in the same deme can merge. The merged lineage carries the element-wise
    sum of the descendant vectors of every lineage taking part in the merger. Event rates come from
    the coalescent model's ``_get_rate_block_counting``, as in the block-counting state space, so
    multiple-merger models are covered as well. Any combination to which the model assigns a zero
    rate is dropped.

    :param source: Source state.
    :return: Mapping from each reachable target state to its ``(rate, 'coalescence')`` pair.
    """
    targets: Dict['State', Tuple[float, str]] = {}

    space = self.state_space
    model = space.model
    sizes = [space.epoch.pop_sizes[name] for name in space.lineage_config.pop_names]
    vectors = space.block_vectors
    index_of = space.block_index

    # only the first locus slot of the lineage array is used here
    locus = 0

    for deme in range(source.n_demes):
        counts = source.lineages[locus, deme]
        n_total = int(counts.sum())

        # a merger needs at least two lineages in the deme
        if n_total < 2:
            continue

        occupied = np.where(counts > 0)[0]
        # loop-invariant slices, taken once per deme
        occupied_counts = counts[occupied]
        occupied_vectors = vectors[occupied]
        scale = model._get_timescale(sizes[deme])

        # pick how many lineages each occupied block contributes to the merger
        for picks in product(*[range(int(c) + 1) for c in occupied_counts]):
            k = np.array(picks)

            # fewer than two participants is not a merger
            if k.sum() < 2:
                continue

            involved = k > 0
            rate = model._get_rate_block_counting(n=n_total, b=occupied_counts[involved], k=k[involved])

            # the model does not allow this merger (e.g. a multiple merger under the standard coalescent)
            if rate == 0:
                continue

            merged = (k[:, None] * occupied_vectors).sum(axis=0)
            merged_block = index_of[tuple(int(v) for v in merged)]

            target = source.copy()
            target.lineages[locus, deme, occupied] -= k
            target.lineages[locus, deme, merged_block] += 1

            self.add_target(targets, target, rate / scale, 'coalescence')

    return targets
def migrate(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Collect every migration event that can leave the given state.

    Linked migrations are gathered first and unlinked migrations second. If both produce the same
    target, the unlinked entry takes precedence.

    :param source: Source state.
    :return: Mapping from target state to its ``(rate, kind)`` pair.
    """
    transitions = dict(self.migrate_linked(source))
    transitions.update(self.migrate_unlinked(source))

    return transitions
def migrate_unlinked(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Collect the migration events of lineages that are not linked across loci.

    Each unlinked lineage at a locus may move on its own from one deme to another. The rate is the
    epoch's migration rate between the two populations multiplied by the number of unlinked lineages
    of that block in the source deme. Single-locus migration is handled here too, and its transitions
    are tagged ``'migration'`` instead of ``'unlinked_migration'``.

    :param source: Source state.
    :return: Mapping from target state to its ``(rate, kind)`` pair.
    """
    targets: Dict['State', Tuple[float, str]] = {}
    names = self.state_space.lineage_config.pop_names
    kind = 'unlinked_migration' if source.n_loci != 1 else 'migration'

    for locus in range(source.n_loci):
        for src in range(source.n_demes):
            for dst in range(source.n_demes):
                # a lineage cannot migrate into its own deme
                if src == dst:
                    continue

                for block in range(source.n_blocks):
                    # there must be an unlinked lineage available to move
                    if not (source.lineages[locus, src, block] > 0 and source.unlinked[locus, src, block] > 0):
                        continue

                    target = source.copy()
                    target.lineages[locus, src, block] -= 1
                    target.lineages[locus, dst, block] += 1

                    base = self.state_space.epoch.migration_rates[(names[src], names[dst])]

                    # the per-lineage rate is scaled by the number of unlinked lineages that may move
                    self.add_target(targets, target, base * source.unlinked[locus, src, block], kind)

    return targets
def migrate_linked(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Get all possible linked migration transitions from the given state.

    A linked lineage is present at every locus, so it moves between demes as a unit. Its lineage
    count and its linked count are updated at all loci at once. A single locus has no linkage, so
    no linked migration is produced in that case (see :meth:`migrate_unlinked`).

    :param source: Source state.
    :return: Mapping from each target state to its ``(rate, 'linked_migration')`` pair.
    """
    targets: Dict['State', Tuple[float, str]] = {}

    # no linked migration if there is only one locus
    if source.n_loci == 1:
        return targets

    pop_names = self.state_space.lineage_config.pop_names

    for d1, d2 in filter(lambda x: x[0] != x[1], product(range(source.n_demes), repeat=2)):
        for block in range(source.n_blocks):
            # A linked lineage must be present in the source deme at every locus.
            # The comparison sits inside np.all so that each locus count is tested individually.
            if np.all(source.lineages[:, d1, block] > 0) and np.all(source.linked[:, d1, block] > 0):
                target = source.copy()
                target.lineages[:, d1, block] -= 1
                target.lineages[:, d2, block] += 1
                target.linked[:, d1, block] -= 1
                target.linked[:, d2, block] += 1

                base_rate = self.state_space.epoch.migration_rates[(pop_names[d1], pop_names[d2])]

                # scale migration rate by number of linked lineages in source deme;
                # both loci are assumed to have the same number of linked lineages here
                rate = base_rate * cast(int, source.linked[0, d1, block])

                self.add_target(targets, target, rate, 'linked_migration')

    return targets
def recombine(self, state: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Enumerate the recombination events leaving ``state``.

    Recombination breaks a linked lineage apart:

    * In the lineage-counting state space, the number of linked lineages in a deme drops by one at
      every locus. The rate is the recombination rate times the linked count of the first locus.
    * In the two-locus block-counting state space, a lineage with descendant vector ``(a_0, a_1)``,
      where both entries are positive, is replaced by ``(a_0, 0)`` and ``(0, a_1)``. The rate is the
      recombination rate times the number of such lineages.

    :param state: State.
    :return: Mapping from target state to its ``(rate, 'recombination')`` pair. The mapping is empty
        when there is a single locus.
    :raises NotImplementedError: For any other state space type.
    """
    targets: Dict['State', Tuple[float, str]] = {}
    space = self.state_space
    rho = space.locus_config.recombination_rate

    # a single locus has nothing to recombine
    if space.locus_config.n == 1:
        return targets

    if isinstance(space, LineageCountingStateSpace):
        for deme in range(state.n_demes):
            # every locus needs a linked lineage in this deme
            if not np.all(state.linked[:, deme] > 0):
                continue

            target = state.copy()
            target.linked[:, deme] -= 1

            self.add_target(targets, target, rho * state.linked[0, deme, 0], 'recombination')

        return targets

    if isinstance(space, TwoLocusBlockCountingStateSpace):
        vectors = space.block_vectors
        index_of = space.block_index

        # NOTE(review): only locus slot 0 and deme 0 are inspected; this assumes a single-deme joint space
        counts = state.lineages[0, 0]

        for block in range(space.n_blocks):
            left, right = int(vectors[block, 0]), int(vectors[block, 1])

            # only occupied blocks ancestral to both loci can split
            if not (counts[block] > 0 and left > 0 and right > 0):
                continue

            target = state.copy()
            target.lineages[0, 0, block] -= 1
            target.lineages[0, 0, index_of[(left, 0)]] += 1
            target.lineages[0, 0, index_of[(0, right)]] += 1

            self.add_target(targets, target, rho * int(counts[block]), 'recombination')

        return targets

    raise NotImplementedError(f'Recombination is not yet implemented for {self.state_space.__class__.__name__}.')
class State:
    """
    State utility class.

    Wraps a pair of count arrays, the total lineage counts and the linked lineage counts, and
    exposes helpers to inspect and copy them.
    """

    #: Axis index of the linkage dimension.
    LINKAGE = 0

    #: Axis index of the locus dimension.
    LOCUS = 1

    #: Axis index of the deme dimension.
    DEME = 2

    #: Axis index of the lineage-block dimension.
    BLOCK = 3
def __init__(self, data: Tuple[np.ndarray, np.ndarray]):
    """
    Initialize a state.

    :param data: Pair ``(lineages, linked)`` of count arrays. ``lineages`` holds the number of
        lineages and ``linked`` the number of linked lineages, each indexed by locus, deme and
        lineage block. The arrays are stored as given, without copying.
    """
    # The annotation was previously the tuple expression ``(np.ndarray, np.ndarray)``,
    # which is not a valid type hint; ``Tuple[...]`` expresses the intended pair.

    #: State data
    self.data: Tuple[np.ndarray, np.ndarray] = data
def __hash__(self) -> int:
    """
    Hash function.

    Only the raw bytes of the two arrays are hashed. States that differ only in shape or dtype may
    therefore collide; ``__eq__`` tells them apart.

    :return: Hash of the state.
    """
    return hash((self.data[0].tobytes(), self.data[1].tobytes()))

def __eq__(self, other: object) -> bool:
    """
    Check if two states are equal.

    Two states are equal when both of their arrays have the same shape, dtype and raw bytes.
    Equal states therefore always have equal hashes. Earlier versions compared hashes only, which
    wrongly equated states whose hashes collided, including same-bytes, different-shape states.

    :param other: Other state.
    :return: Whether the two states are equal, or ``NotImplemented`` if ``other`` is not a state.
    """
    if not isinstance(other, type(self)):
        return NotImplemented

    return all(
        mine.shape == theirs.shape
        and mine.dtype == theirs.dtype
        and mine.tobytes() == theirs.tobytes()
        for mine, theirs in zip(self.data, other.data)
    )
def copy(self) -> 'State':
    """
    Return an independent duplicate of this state.

    Both underlying arrays are copied, so modifying the duplicate leaves this state untouched.

    :return: A new state holding copies of the lineage and linked arrays.
    """
    lineages_copy = self.data[0].copy()
    linked_copy = self.data[1].copy()

    return State((lineages_copy, linked_copy))
def is_absorbing(self) -> bool:
    """
    Tell whether the state is absorbing.

    The state is absorbing once every locus is down to exactly one lineage, counted over all demes
    and lineage blocks.

    :return: Whether every locus carries exactly one lineage.
    """
    per_locus = self.lineages.sum(axis=(1, 2))

    return np.all(per_locus == 1)
@property
def n_demes(self) -> int:
    """
    Number of demes, i.e. the size of the second axis of the lineage array.

    :return: The number of demes.
    """
    shape = self.lineages.shape
    return shape[1]

@property
def n_loci(self) -> int:
    """
    Number of loci, i.e. the size of the first axis of the lineage array.

    :return: The number of loci.
    """
    shape = self.lineages.shape
    return shape[0]

@property
def n_blocks(self) -> int:
    """
    Number of lineage blocks, i.e. the size of the third axis of the lineage array.

    :return: The number of lineage blocks.
    """
    shape = self.lineages.shape
    return shape[2]

@property
def lineages(self) -> np.ndarray:
    """
    Array of lineage counts, the first entry of the state data.

    :return: The number of lineages.
    """
    lineage_counts = self.data[0]
    return lineage_counts

@property
def linked(self) -> np.ndarray:
    """
    Array of linked lineage counts, the second entry of the state data.

    :return: The number of linked lineages.
    """
    linked_counts = self.data[1]
    return linked_counts

@property
def unlinked(self) -> np.ndarray:
    """
    Array of unlinked lineage counts, derived as total lineages minus linked lineages.

    :return: The number of unlinked lineages.
    """
    total, bound = self.lineages, self.linked
    return total - bound