"""
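State space classes and utilities. The two main state space classes are
:class:`LineageCountingStateSpace` and :class:`BlockCountingStateSpace`; :class:`JointBlockCountingStateSpace`
and :class:`TwoLocusBlockCountingStateSpace` extend block counting to the joint (multi-population) SFS and the
two-locus SFS, respectively.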
"""
import logging
import time
from abc import ABC, abstractmethod
from functools import cached_property
from itertools import product
from typing import List, Tuple, Dict, Callable, cast
import numpy as np
from tqdm import tqdm
from .coalescent_models import CoalescentModel, StandardCoalescent, BetaCoalescent, DiracCoalescent
from .settings import Settings
from .demography import Epoch
from .lineage import LineageConfig
from .locus import LocusConfig
from .state_space_numba import HAS_NUMBA, build_rate_matrix
from .state_space_old import StateSpace as OldStateSpace, LineageCountingStateSpace as OldLineageCountingStateSpace, \
BlockCountingStateSpace as OldBlockCountingStateSpace
# package-level logger; each state space derives a child logger named after its class
logger = logging.getLogger('phasegen')
def _numba_model_params(model: CoalescentModel) -> Tuple[int, float, float, float]:
    """
    Encode a coalescent model as the ``(model_id, alpha, psi, c)`` tuple expected by the numba kernels.

    The model id is 1 for a Beta coalescent, 2 for a Dirac coalescent and 0 (standard coalescent) for anything
    else. Parameters that do not apply to the model are set to zero.

    :param model: Coalescent model.
    :return: Tuple ``(model_id, alpha, psi, c)``.
    """
    if isinstance(model, BetaCoalescent):
        model_id, alpha, psi, c = 1, model.alpha, 0.0, 0.0
    elif isinstance(model, DiracCoalescent):
        model_id, alpha, psi, c = 2, 0.0, model.psi, model.c
    else:
        model_id, alpha, psi, c = 0, 0.0, 0.0, 0.0

    return model_id, alpha, psi, c
class StateSpace(ABC):
    """
    Abstract base class for state spaces. A state space enumerates the states reachable from an initial state
    (see :meth:`_get_initial`) and builds the intensity matrix ``S`` for the current epoch. It uses either a
    Python breadth-first traversal of the transition graph or, where supported, a numba kernel.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create a rate matrix.

        :param lineage_config: Population configuration.
        :param locus_config: Locus configuration. One locus is used by default.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch.
        """
        if locus_config is None:
            locus_config = LocusConfig()

        if model is None:
            model = StandardCoalescent()

        if epoch is None:
            epoch = Epoch()

        #: Logger
        self._logger = logger.getChild(self.__class__.__name__)

        #: Coalescent model
        self.model: CoalescentModel = model

        #: Population configuration
        self.lineage_config: LineageConfig = lineage_config

        #: Locus configuration
        self.locus_config: LocusConfig = locus_config

        #: Epoch
        self.epoch: Epoch = epoch

        #: Cached transition graphs, i.e. ``(transitions, states)`` per epoch
        self._cache: Dict[Epoch, Tuple[Dict[Tuple['State', 'State'], Tuple[float, str]], List['State']]] = {}

        #: Time in seconds to compute the original state space and rate matrix
        self.time: float | None = None

    @cached_property
    def states(self) -> List['State']:
        """
        The states, in discovery order. Computing them also primes the rate matrix ``S`` for the current epoch
        (if not already present) and records the construction time in :attr:`time`.
        """
        start = time.time()

        if self._use_numba():
            states, S = self._construct_numba()

            # the states are epoch-independent, but prime the current epoch's rate matrix to avoid rebuilding it
            self.__dict__.setdefault('S', S)
        else:
            # get all possible transitions
            transitions, states = self.get_transitions()

            # cache transition graph if specified
            if Settings.cache_epochs:
                self._cache[self.epoch] = (transitions, states)

            # build the rate matrix from the graph we already have, so that accessing ``S`` afterwards does not
            # traverse the state space a second time (mirrors the numba path above)
            if 'S' not in self.__dict__:
                self.__dict__['S'] = self._graph_to_matrix(transitions, states)

        # record time to compute rate matrix
        self.time = time.time() - start

        return states

    @cached_property
    def lineages(self) -> np.ndarray:
        """
        The lineage configurations of all states, as an array indexed ``[state, locus, deme, block]``.
        """
        return np.array([s.lineages for s in self.states])

    @cached_property
    def linked(self) -> np.ndarray:
        """
        The linked lineages of all states, indexed like :attr:`lineages`.
        """
        return np.array([s.linked for s in self.states])

    @cached_property
    def unlinked(self) -> np.ndarray:
        """
        The unlinked lineages of all states, i.e. :attr:`lineages` minus :attr:`linked`.
        """
        return self.lineages - self.linked

    @abstractmethod
    def _get_old(self) -> OldStateSpace:
        """
        Get the equivalent state space from the legacy implementation.
        """
        pass

    def _get_old_ordering(self) -> List[int]:
        """
        Get the ordering of the states in the old state space relative to the new state space.

        :return: For each state of this state space, the index of the legacy state with identical lineages and
            linked lineages.
        """
        old = self._get_old()

        # reorder the states of the old state space to match this one
        return cast(List[int], [
            np.where(((old.states == self.lineages[i]) & (old.linked == self.linked[i])).all(axis=(1, 2, 3)))[0][0]
            for i in range(self.k)
        ])

    @staticmethod
    def _get_partitions(n: int, k: int) -> List[List[int]]:
        """
        Find all vectors of length ``k`` with non-negative integers that sum to ``n``.

        :param n: The sum.
        :param k: The length of the vectors.
        :return: All vectors of length ``k`` with non-negative integers that sum to ``n``.
        """
        if k == 0:
            return [[]]

        if k == 1:
            return [[n]]

        vectors = []
        for i in range(n + 1):
            for vector in StateSpace._get_partitions(n - i, k - 1):
                vectors.append(vector + [i])

        return vectors

    def get_transitions(self) -> Tuple[Dict[Tuple['State', 'State'], Tuple[float, str]], List['State']]:
        """
        Get all transitions reachable from the initial state, by breadth-first traversal of the transition graph.

        :return: Mapping from ``(source, target)`` to ``(rate, kind)``, and the list of visited states in
            discovery order.
        """
        sources = [self._get_initial()]
        transitions = {}
        visited = []

        # hash set mirroring ``visited`` for O(1) membership tests (states are hashable, they are used as keys)
        seen = set()

        # warn once the number of states crosses these thresholds
        levels = {1000: 'slow', 5000: 'very slow', 25000: 'extremely slow'}

        # number of states and transitions
        i, j = 0, 0

        # the context manager closes the progress bar even if a transition raises
        with tqdm(desc=f'{self.__class__.__name__}: transitions', disable=not Settings.use_pbar) as pbar:
            while True:
                targets_new = {}

                for source in sources:

                    # skip if source has been visited already
                    if source in seen:
                        continue

                    # get all possible transitions from source
                    targets = self.transition.transit(source)

                    # add visited source state
                    visited.append(source)
                    seen.add(source)

                    # add transitions to dictionary
                    for target, transition in targets.items():
                        transitions[(source, target)] = transition
                        j += 1
                        pbar.update(1)

                    # add targets to new targets
                    targets_new |= targets

                    # increment state counter
                    i += 1

                    if i in levels:
                        self._logger.warning(
                            f'State space size exceeds {i} states. Computations may be {levels[i]}.'
                        )

                # break if no more targets
                if len(targets_new) == 0:
                    break

                # take new targets as source states
                sources = tuple(targets_new.keys())

                pbar.set_description_str(f'{self.__class__.__name__}: ({i} states, {j} transitions)')

        # warn if state space is large
        if (k := len(visited)) > 400:
            self._logger.warning(f'State space is large ({k} states, {len(transitions)} transitions). '
                                 f'Note that the computation time increases '
                                 f'exponentially with the number of states.')

        return transitions, visited

    @cached_property
    def e(self) -> np.ndarray:
        """
        Vector of ones of length ``k``.
        """
        return np.ones(self.k)

    @cached_property
    def S(self) -> np.ndarray:
        """
        Intensity matrix for the current epoch.
        """
        return self._get_rate_matrix()

    @cached_property
    def alpha(self) -> np.ndarray:
        """
        Initial state vector.
        """
        pops = self.lineage_config._get_initial_states(self)
        loci = self.locus_config._get_initial_states(self)

        # combine initial states
        alpha = pops * loci

        # normalization ensures that the initial state vector is a probability distribution
        # as we may have multiple initial states
        return alpha / alpha.sum()

    @cached_property
    def k(self) -> int:
        """
        Number of states.
        """
        return len(self.states)

    @cached_property
    def transition(self) -> 'Transition':
        """
        Transition generator bound to this state space.
        """
        return Transition(self)

    def update_epoch(self, epoch: Epoch):
        """
        Update the epoch. If the epoch changed, the cached rate matrix is rescaled in place when there is a single
        population and a single locus, and dropped otherwise (it is recomputed on next access).

        :param epoch: Epoch.
        """
        # only touch the cached rate matrix if the epoch has changed
        if self.epoch != epoch:

            # rescaling only applies to a single population and locus, as all rates then share one time scale
            if (
                    self.lineage_config.n_pops == 1 and
                    self.locus_config.n == 1 and
                    'S' in self.__dict__
            ):
                self.S *= self._get_scaling_factor(self.epoch, epoch)
            else:
                self.drop_S()

        self.epoch = epoch

    def _get_scaling_factor(self, epoch_prev: Epoch, epoch_next: Epoch) -> float:
        """
        Get the scaling factor for the rate matrix when changing epochs.

        :param epoch_prev: Previous epoch.
        :param epoch_next: Next epoch.
        :return: Ratio of the model time scales of the first population in the previous and next epoch.
        """
        pop_prev = epoch_prev.pop_sizes[epoch_prev.pop_names[0]]
        pop_next = epoch_next.pop_sizes[epoch_next.pop_names[0]]

        return self.model._get_timescale(pop_prev) / self.model._get_timescale(pop_next)

    def __eq__(self, other):
        """
        Check if two state spaces are equal. We do not check for equivalence of the epochs as we can
        update the epoch of a state space dynamically.

        :param other: Other state space
        :return: Whether the two state spaces are equal
        """
        return (
                self.__class__ == other.__class__ and
                self.lineage_config == other.lineage_config and
                self.locus_config == other.locus_config and
                self.model == other.model
        )

    def drop_S(self):
        """
        Drop the current rate matrix.
        """
        try:
            # noinspection all
            del self.S
        except AttributeError:
            # not computed yet
            pass

    def drop_cache(self):
        """
        Drop the rate matrix cache and current rate matrix.
        """
        self.drop_S()
        self._cache = {}

    @abstractmethod
    def _get_initial(self):
        """
        Get the initial state.
        """
        pass

    def _is_absorbing(self, state: 'State') -> bool:
        """
        Whether the given state is absorbing. By default this is the single-process absorbing condition (a single
        remaining lineage); state spaces with a different notion of absorption (e.g. two loci, which are absorbed
        once both have reached their MRCA) override this.

        :param state: State.
        :return: Whether the state is absorbing.
        """
        return state.is_absorbing()

    def _use_numba(self) -> bool:
        """
        Whether numba-accelerated construction applies. This requires numba to be available and enabled, and the
        coalescent model to be exactly the standard, Beta or Dirac coalescent. The state space must be a
        single-locus lineage-, block- or joint block-counting state space, or the two-locus block-counting state
        space. Everything else (e.g. the 2-locus lineage-counting state space) uses the Python construction.
        """
        if not (HAS_NUMBA and Settings.use_numba):
            return False

        # the kernels only implement these models; ``_numba_model_params`` maps any other model to the standard
        # coalescent, so other models (including subclasses) must use the Python construction to get their rates
        if type(self.model) not in (StandardCoalescent, BetaCoalescent, DiracCoalescent):
            return False

        # the single-locus state spaces (kernels 0/1); the two-locus space is handled separately below
        if self.locus_config.n == 1 and type(self) in (
                LineageCountingStateSpace, BlockCountingStateSpace, JointBlockCountingStateSpace):
            return True

        # the two-locus block-counting state space (recombination)
        return type(self) is TwoLocusBlockCountingStateSpace

    def _numba_kind(self) -> int:
        """
        Kernel selector: 0 lineage-counting, 1 block-/joint-counting, 2 two-locus block-counting. Defaults to 1.
        """
        return 1

    def _numba_recombination(self) -> Tuple[float, np.ndarray | None, np.ndarray | None]:
        """
        Recombination parameters for the kernel: ``(recombination_rate, recomb0, recomb1)`` where ``recomb_l[b]`` is
        the block index that block ``b`` contributes to locus ``l`` when it recombines. Only the two-locus state
        space uses these; by default there is no recombination.
        """
        return 0.0, None, None

    def _numba_block_vectors(self, n_blocks: int) -> np.ndarray:
        """
        Block labels passed to the kernel; for block-counting, block ``i`` represents lineages subtending ``i + 1``
        samples.

        :param n_blocks: Number of blocks.
        :return: Integer array of shape ``(n_blocks, 1)``.
        """
        return np.arange(1, n_blocks + 1, dtype=np.int64).reshape(-1, 1)

    def _construct_numba(self) -> Tuple[List['State'], np.ndarray]:
        """
        Build the states and rate matrix for the current epoch via the numba kernel.

        :return: The states (in kernel discovery order) and the dense intensity matrix.
        """
        init = self._get_initial()
        n_demes = init.lineages.shape[1]
        n_blocks = init.lineages.shape[2]
        pops = self.lineage_config.pop_names

        # migration rates between distinct demes; the diagonal stays zero
        mig = np.zeros((n_demes, n_demes))
        for i, a in enumerate(pops):
            for j, b in enumerate(pops):
                if i != j:
                    mig[i, j] = self.epoch.migration_rates[(a, b)]

        timescales = np.array([self.model._get_timescale(self.epoch.pop_sizes[p]) for p in pops])
        model_id, alpha, psi, c = _numba_model_params(self.model)
        recomb_rate, recomb0, recomb1 = self._numba_recombination()

        rows, S = build_rate_matrix(
            initial=init.lineages.reshape(-1),
            kind=self._numba_kind(),
            n_demes=n_demes,
            n_blocks=n_blocks,
            mig=mig,
            timescales=timescales,
            model_id=model_id,
            alpha=alpha,
            psi=psi,
            c=c,
            block_vectors=self._numba_block_vectors(n_blocks),
            recomb_rate=recomb_rate,
            recomb0=recomb0,
            recomb1=recomb1,
        )

        lin_shape = init.lineages.shape
        lin_dtype = init.lineages.dtype
        linked = init.linked  # all-zero for a single locus

        states = [State((row.reshape(lin_shape).astype(lin_dtype), linked.copy())) for row in rows]

        return states, S

    def _get_rate_matrix(self) -> np.ndarray:
        """
        Get the rate matrix for the current epoch. If the states are not known yet, the ones produced here are
        primed (together with :attr:`time`) so that accessing ``states`` does not rebuild the state space.

        :return: The rate matrix.
        """
        start = time.time()

        if self._use_numba():
            states, S = self._construct_numba()
        else:
            # check if epoch is in cache
            if Settings.cache_epochs and self.epoch in self._cache:
                transitions, states = self._cache[self.epoch]
            else:
                # get all possible transitions
                transitions, states = self.get_transitions()

                # cache transition graph if specified
                if Settings.cache_epochs:
                    self._cache[self.epoch] = (transitions, states)

            S = self._graph_to_matrix(transitions, states)

        # the states are epoch-independent, so prime them if not yet known
        if 'states' not in self.__dict__:
            self.__dict__['states'] = states
            self.time = time.time() - start

        return S

    @staticmethod
    def _graph_to_matrix(
            transitions: Dict[Tuple['State', 'State'], Tuple[float, str]],
            states: List['State']
    ) -> np.ndarray:
        """
        Convert transition graph to rate matrix.

        :param transitions: Mapping from ``(source, target)`` to ``(rate, kind)``.
        :param states: States; their order defines the row/column order.
        :return: Rate matrix whose rows sum to zero.
        """
        S = np.zeros((len(states), len(states)))

        # order of original states
        ordering = {s: i for i, s in enumerate(states)}

        # fill rate matrix
        for (source, target), transition in transitions.items():
            S[ordering[source], ordering[target]] = transition[0]

        # the diagonal holds minus the total exit rate, so discard any self-loop rate before summing rows
        np.fill_diagonal(S, 0)
        S[np.diag_indices_from(S)] = -np.sum(S, axis=1)

        return S

    def get_sparsity(self) -> float:
        """
        Get the sparsity of the rate matrix.

        :return: The fraction of zero entries.
        """
        return 1 - np.count_nonzero(self.S) / self.S.size

    def _get_color_state(self, i: int) -> str:
        """
        Get the fill color of the state indexed by ``i``: ``'#f1807e'`` for absorbing states (as defined by
        :meth:`_is_absorbing`), ``'lightgreen'`` for initial states (non-zero ``alpha``) and ``'lightblue'``
        otherwise.
        """
        if self._is_absorbing(self.states[i]):
            return '#f1807e'

        if self.alpha[i] > 0:
            return 'lightgreen'

        return 'lightblue'

    def plot_rates(
            self,
            file: str,
            view: bool = True,
            cleanup: bool = False,
            dpi: int = 400,
            ratio: float = 0.6,
            background_color: str = 'white',
            extension: str = 'png',
            format_state: Callable[[Tuple[np.ndarray, np.ndarray]], str] = None,
            format_transition: Callable[[float, str], str] = None
    ):
        """
        Plot the rate matrix using graphviz. Note that graphviz must be installed which is an external dependency.

        :param file: File to save plot to.
        :param view: Whether to view the plot.
        :param cleanup: Whether to remove the source file.
        :param dpi: Dots per inch.
        :param ratio: Aspect ratio.
        :param background_color: Background color.
        :param extension: File format.
        :param format_state: Function to format a state, called with the state's ``data``, i.e. its
            ``(lineages, linked)`` arrays.
        :param format_transition: Function to format a transition, called with its rate and kind.
        """
        import graphviz

        if format_state is None:
            def format_state(s: Tuple[np.ndarray, np.ndarray]) -> str:
                """
                Format state.

                :param s: State data.
                :return: Formatted state.
                """
                return str(s[0]).replace('\n', '') + '\n' + str(s[1]).replace('\n', '')

        if format_transition is None:
            def format_transition(rate: float, kind: str) -> str:
                """
                Format transition.

                :param rate: Rate.
                :param kind: Kind.
                :return: Formatted transition.
                """
                return f' {kind}: ' + '{:.2f}'.format(rate).rstrip('0').rstrip('.')

        def get_color(kind: str) -> str:
            """
            Color of a transition kind. Combined kinds (``a+b``, produced when several transitions lead to the
            same target) that are not listed explicitly fall back to the color of their first component.
            """
            return Transition._colors.get(kind, Transition._colors.get(kind.split('+')[0], 'black'))

        graph = graphviz.Digraph()

        # add nodes
        for i, state in enumerate(self.states):
            graph.node(
                name=format_state(state.data),
                fillcolor=self._get_color_state(i),
                style='filled'
            )

        transitions, _ = self.get_transitions()

        # add non-zero edges
        for (source, target), transition in transitions.items():
            if not self._is_absorbing(source):
                color = get_color(transition[1])

                graph.edge(
                    tail_name=format_state(source.data),
                    head_name=format_state(target.data),
                    label=format_transition(*transition),
                    color=color,
                    fontcolor=color
                )

        graph.graph_attr['dpi'] = str(dpi)
        graph.graph_attr['ratio'] = str(ratio)
        graph.graph_attr['bgcolor'] = background_color

        graph.render(
            filename=file,
            view=view,
            cleanup=cleanup,
            format=extension
        )
class LineageCountingStateSpace(StateSpace):
    """
    Default state space, with one state per number of lineages in each deme and locus (a single block per deme).
    """

    def _get_initial(self) -> 'State':
        """
        Get the initial state, with all ``n`` lineages in the first deme of every locus and no linked lineages.
        """
        shape = (self.locus_config.n, self.lineage_config.n_pops, 1)

        lineages = np.zeros(shape, dtype=int)
        linked = np.zeros(shape, dtype=int)

        lineages[:, 0, 0] = self.lineage_config.n

        return State((lineages, linked))

    def _numba_kind(self) -> int:
        """
        Select the lineage-counting kernel.
        """
        return 0

    def _numba_block_vectors(self, n_blocks: int) -> np.ndarray:
        """
        Lineage counting uses a single block whose label the lineage kernel does not read.
        """
        return np.array([[1]], dtype=np.int64)

    def _get_old(self) -> OldLineageCountingStateSpace:
        """
        Build the equivalent legacy lineage-counting state space.
        """
        kwargs = dict(
            lineage_config=self.lineage_config,
            locus_config=self.locus_config,
            model=self.model,
            epoch=self.epoch
        )

        return OldLineageCountingStateSpace(**kwargs)
class BlockCountingStateSpace(StateSpace):
    r"""
    Rate matrix for the block-counting state space, with one state per sample configuration per deme and locus.
    A block-counting state is a vector of length ``n`` whose ``i``-th element is the number of lineages
    subtending ``i`` sampled lineages in the coalescent tree:

    .. math::
        (a_1,...,a_n) \in \mathbb{Z}_+^n : \sum_{i=1}^{n} i a_i = n.

    This state space can distinguish between different tree topologies and is thus used when computing
    statistics based on the SFS.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create a rate matrix.

        :param lineage_config: Population configuration.
        :param locus_config: Locus configuration. One locus is used by default.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch
        :raises NotImplementedError: If more than one locus is specified.
        """
        # currently only one locus is supported, due to a very complex state space for multiple loci
        if locus_config is not None and locus_config.n > 1:
            raise NotImplementedError('Block-counting state space only supports one locus.')

        super().__init__(
            lineage_config=lineage_config,
            locus_config=locus_config,
            model=model,
            epoch=epoch
        )

    def _get_initial(self) -> 'State':
        """
        Get the initial state: all ``n`` lineages in the first deme, each subtending a single sample (block 0).
        """
        data = tuple(
            np.zeros((self.locus_config.n, self.lineage_config.n_pops, self.lineage_config.n), dtype=int)
            for _ in range(2)
        )

        data[0][:, 0, 0] = self.lineage_config.n

        return State(data)

    @staticmethod
    def _traverse(
            start_state: int,
            start_p: float,
            S: np.ndarray,
            absorbing_states: np.ndarray,
            lineage_counts: np.ndarray,
    ) -> np.ndarray:
        """
        Calculate the probabilities of being in each state conditioned on the number of lineages.

        States are processed in descending order of lineage count, so all probability mass flowing into a state
        has been accumulated before it is propagated further along the embedded jump chain.

        :param start_state: Index of the starting state.
        :param start_p: Probability of starting in the starting state.
        :param S: Transition matrix.
        :param absorbing_states: Indices of absorbing states.
        :param lineage_counts: Number of lineages in each state.
        :return: State probabilities conditioned on the number of lineages.
        """
        probs = np.zeros(S.shape[0])
        probs[start_state] = start_p

        # set for O(1) membership tests
        absorbing = set(np.asarray(absorbing_states).tolist())
        state_indices = np.arange(lineage_counts.shape[0])

        # descending lineage counts
        for k in sorted(set(lineage_counts), reverse=True):

            # iterate over states with k lineages
            for i in state_indices[lineage_counts == k]:
                if i in absorbing or S[i, i] == 0:
                    continue

                # jump probabilities; the diagonal entry becomes -1 and is excluded below
                trans_probs = S[i] / -S[i, i]
                targets = np.flatnonzero(trans_probs > 0)
                probs[targets] += probs[i] * trans_probs[targets]

        return probs

    @cached_property
    def _state_probs(self) -> np.ndarray:
        """
        Get state probabilities conditioned on the number of lineages.
        This can be used to flatten the block-counting state space to a lineage-counting state space by weighting the
        lineages by the probabilities of being in each of the corresponding block-counting states.
        This only works for one-population, one-locus state spaces under the standard coalescent, or MMCs provided
        there is only one epoch, and we accumulate until absorption.

        :return: State probabilities conditioned on the number of lineages.
        """
        self._logger.debug('Calculating state probabilities conditioned on the number of lineages.')

        # use the (overridable) state-space notion of absorption
        absorbing_states = np.where([self._is_absorbing(s) for s in self.states])[0]
        probs = np.zeros(self.k)
        lineage_counts = np.array([s.lineages.sum() for s in self.states])
        state_indices = np.arange(lineage_counts.shape[0])

        for i, p in zip(state_indices[self.alpha > 0], self.alpha[self.alpha > 0]):
            probs += self._traverse(i, p, self.S, absorbing_states, lineage_counts)

        return probs

    def _get_old(self) -> OldBlockCountingStateSpace:
        """
        Get the equivalent legacy block-counting state space.
        """
        return OldBlockCountingStateSpace(
            lineage_config=self.lineage_config,
            locus_config=self.locus_config,
            model=self.model,
            epoch=self.epoch
        )
class JointBlockCountingStateSpace(StateSpace):
    r"""
    Rate matrix for the joint (multi-population) site-frequency spectrum.

    This generalizes :class:`BlockCountingStateSpace`. There, a block is a single *size class* ``i`` (the number
    of sampled lineages a lineage subtends), which loses track of *which* population those descendants came from.
    The joint SFS bins mutations by their allele frequency in every population at once, so here each block is the
    **descendant vector**

    .. math::
        v = (v_0, \dots, v_{P-1}), \quad 0 \le v_p \le n_p, \quad 1 \le \sum_p v_p \le n,

    i.e. the number of descendants a lineage subtends from each population ``p`` (its "deme of origin"
    composition). A state counts, per locus and per *current* deme of residence, how many lineages carry each
    descendant vector.

    .. note::
        Splitting each size class by deme of origin enlarges the block axis (``n_blocks = prod(n_p + 1) - 1``)
        and makes the number of reachable states grow combinatorially, much faster than in the
        single-population block-counting space. This state space is therefore only practical for small
        per-population sample sizes. Only one locus is supported.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create a rate matrix.

        :param lineage_config: Population configuration.
        :param locus_config: Locus configuration. One locus is used by default.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch.
        :raises NotImplementedError: If more than one locus is specified.
        """
        # descendant vectors are only tracked for a single locus
        too_many_loci = locus_config is not None and locus_config.n > 1
        if too_many_loci:
            raise NotImplementedError('Joint block-counting state space only supports one locus.')

        super().__init__(
            lineage_config=lineage_config,
            locus_config=locus_config,
            model=model,
            epoch=epoch
        )

    @cached_property
    def block_configs(self) -> List[Tuple[int, ...]]:
        """
        All descendant vectors (block types) in lexicographic order. Entry ``p`` ranges over ``0..n_p`` and the
        all-zero vector is excluded.
        """
        ranges = [range(int(n_pop) + 1) for n_pop in self.lineage_config.lineages]

        return [config for config in product(*ranges) if sum(config) > 0]

    @cached_property
    def block_index(self) -> Dict[Tuple[int, ...], int]:
        """
        Lookup from descendant vector to block index.
        """
        index = {}
        for position, config in enumerate(self.block_configs):
            index[config] = position

        return index

    @cached_property
    def block_vectors(self) -> np.ndarray:
        """
        Descendant vectors as an integer array of shape ``(n_blocks, n_pops)``.
        """
        configs = self.block_configs

        return np.array(configs, dtype=int)

    def _numba_block_vectors(self, n_blocks: int) -> np.ndarray:
        """
        The kernel's block labels are the descendant vectors: a coalescence yields the block whose vector is the
        sum of the merging ones.
        """
        vectors = self.block_vectors

        return np.asarray(vectors, dtype=np.int64)

    @cached_property
    def n_blocks(self) -> int:
        """
        Number of block types (descendant vectors).
        """
        return len(self.block_configs)

    def _get_initial(self) -> 'State':
        """
        Get the initial state. The ``n_p`` lineages sampled from population ``p`` start in deme ``p`` with the
        unit descendant vector ``e_p``, i.e. each subtends a single sample from population ``p``.
        """
        n_pops = self.lineage_config.n_pops
        shape = (self.locus_config.n, n_pops, self.n_blocks)

        lineages = np.zeros(shape, dtype=int)
        linked = np.zeros(shape, dtype=int)

        for pop in range(n_pops):
            n_pop = self.lineage_config.lineages[pop]

            if not n_pop > 0:
                continue

            unit = tuple(1 if other == pop else 0 for other in range(n_pops))
            lineages[0, pop, self.block_index[unit]] = n_pop

        return State((lineages, linked))

    @cached_property
    def alpha(self) -> np.ndarray:
        """
        Initial state vector. Only the single origin-aware initial state (see :meth:`_get_initial`) has mass, so
        this is its indicator vector.
        """
        initial = self._get_initial()
        indicator = [state == initial for state in self.states]

        return np.array(indicator, dtype=float)

    def _get_old(self) -> OldStateSpace:
        """
        There is no legacy counterpart of the joint block-counting state space.

        :raises NotImplementedError: Always.
        """
        message = 'The joint block-counting state space has no legacy equivalent.'

        raise NotImplementedError(message)
class TwoLocusBlockCountingStateSpace(JointBlockCountingStateSpace):
    r"""
    Block-counting state space for two loci separated by recombination, used to compute the two-locus SFS.

    Each physical ancestral lineage is described by a vector ``(a_0, a_1)`` giving the number of sampled lineages it
    subtends at locus 0 and locus 1 (``a_l = 0`` meaning it is not ancestral at locus ``l``). This is the same
    representation as the joint (multi-population) state space, with **locus** playing the role of **population**, so
    coalescence is again vector addition (and reuses the model-agnostic merger rates, supporting Beta/Dirac too).

    The new ingredient is **recombination**: a linked lineage ``(a_0, a_1)`` with ``a_0 > 0`` and ``a_1 > 0`` splits
    into ``(a_0, 0)`` and ``(0, a_1)`` at rate ``r`` per linked lineage. The process is absorbed once both loci have
    reached their MRCA.

    A single population is currently supported (no migration), so the state shape collapses to ``(1, 1, n_blocks)``;
    linkage is encoded in the block vector itself, so the ``linked`` array is unused.

    .. note::
        ``n_blocks = (n + 1)^2 - 1`` and the number of reachable states grows quickly, so this is only practical for
        small sample sizes.
    """

    def __init__(
            self,
            lineage_config: LineageConfig,
            locus_config: LocusConfig = None,
            model: CoalescentModel = None,
            epoch: Epoch = None
    ):
        """
        Create the two-locus block-counting state space.

        :param lineage_config: Population configuration (a single population is currently supported).
        :param locus_config: Locus configuration; must specify exactly two loci.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param epoch: The epoch.
        :raises ValueError: If ``locus_config`` is missing or does not specify exactly two loci.
        :raises NotImplementedError: If more than one population is specified.
        """
        if locus_config is None or locus_config.n != 2:
            raise ValueError('The two-locus block-counting state space requires exactly two loci.')

        if lineage_config.n_pops != 1:
            raise NotImplementedError('The two-locus block-counting state space currently supports a single '
                                      'population (no migration).')

        # bypass JointBlockCountingStateSpace.__init__ (which rejects more than one locus)
        StateSpace.__init__(self, lineage_config=lineage_config, locus_config=locus_config, model=model, epoch=epoch)

    @cached_property
    def block_configs(self) -> List[Tuple[int, ...]]:
        """
        Ordered list of all two-locus descendant vectors ``(a_0, a_1)`` with ``0 <= a_l <= n`` and at least one
        non-zero entry.
        """
        n = int(self.lineage_config.n)

        return [c for c in product(range(n + 1), repeat=2) if sum(c) >= 1]

    def _get_initial(self) -> 'State':
        """
        Get the initial state: ``n - n_unlinked`` lineages of type ``(1, 1)`` (linked across both loci) plus, for
        each of the ``n_unlinked`` initially unlinked samples, a ``(1, 0)`` and a ``(0, 1)`` lineage.

        ``n_unlinked`` is clamped to ``[0, n]``. Otherwise the initial state would subtend more than ``n`` samples
        per locus, and merged descendant vectors would fall outside :attr:`block_configs`.
        """
        n = int(self.lineage_config.n)

        # at most all n samples can be unlinked
        n_unlinked = min(max(int(self.locus_config.n_unlinked), 0), n)
        n_linked = n - n_unlinked

        data = tuple(np.zeros((1, 1, self.n_blocks), dtype=int) for _ in range(2))

        if n_linked > 0:
            data[0][0, 0, self.block_index[(1, 1)]] = n_linked

        if n_unlinked > 0:
            data[0][0, 0, self.block_index[(1, 0)]] += n_unlinked
            data[0][0, 0, self.block_index[(0, 1)]] += n_unlinked

        return State(data)

    def _is_absorbing(self, state: 'State') -> bool:
        """
        A two-locus state is absorbing once both loci have reached their MRCA, i.e. exactly one lineage carries
        ancestral material at locus 0 and exactly one carries it at locus 1 (covering both the single linked
        grand-MRCA ``(n, n)`` and the unlinked pair ``(n, 0) + (0, n)``).
        """
        lineages = state.lineages[0, 0]

        return all(int(lineages[self.block_vectors[:, locus] > 0].sum()) == 1 for locus in range(2))

    def _numba_kind(self) -> int:
        """
        Two-locus block-counting kernel.
        """
        return 2

    def _numba_recombination(self) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Recombination split maps: for each linked block ``(a_0, a_1)`` the indices of ``(a_0, 0)`` and ``(0, a_1)``.
        Entries for blocks that are not linked are left at 0.
        """
        recomb0 = np.zeros(self.n_blocks, dtype=np.int64)
        recomb1 = np.zeros(self.n_blocks, dtype=np.int64)

        for b in range(self.n_blocks):
            a0, a1 = int(self.block_vectors[b, 0]), int(self.block_vectors[b, 1])

            if a0 > 0 and a1 > 0:
                recomb0[b] = self.block_index[(a0, 0)]
                recomb1[b] = self.block_index[(0, a1)]

        return self.locus_config.recombination_rate, recomb0, recomb1
class Transition:
"""
Class representing a transition between two states.
"""
#: Colors for different types of transitions
_colors: Dict[str, str] = {
'recombination': 'orange',
'coalescence': 'darkgreen',
'locus_coalescence': 'darkgreen',
'linked_coalescence': 'darkgreen',
'unlinked_coalescence': 'darkgreen',
'mixed_coalescence': 'darkgreen',
'unlinked_coalescence+mixed_coalescence': 'darkgreen',
'mixed_coalescence+unlinked_coalescence': 'darkgreen',
'linked_migration': 'blue',
'unlinked_migration': 'blue',
'migration': 'blue',
'invalid': 'red'
}
def __init__(self, state_space: StateSpace):
    """
    Bind the transition generator to a state space.

    :param state_space: State space whose model, epoch and configurations drive the transitions.
    """
    #: State space.
    self.state_space: StateSpace = state_space
def transit(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Collect every target state reachable from ``source`` in a single transition.

    Migration is always considered. Coalescence and recombination are only considered when ``source`` is not
    absorbing according to the state space's ``_is_absorbing``.

    :param source: Source state.
    :return: Mapping from target state to ``(rate, kind)``.
    """
    targets: Dict['State', Tuple[float, str]] = dict(self.migrate(source))

    if not self.state_space._is_absorbing(source):
        targets.update(self.coalesce(source))
        targets.update(self.recombine(source))

    return targets
[docs]
@staticmethod
def add_target(targets: Dict['State', Tuple[float, str]], target: 'State', rate: float, kind: str):
"""
Add a target state to the list of targets.
:param targets: Dictionary of target states.
:param target: New target state.
:param rate: Rate of the transition.
:param kind: Kind of the transition.
"""
if target in targets:
targets[target] = (targets[target][0] + rate, targets[target][1] + '+' + kind)
else:
targets[target] = (rate, kind)
def coalesce(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Get all possible coalescent transitions from the given state.

    Joint (and two-locus) block-counting state spaces are delegated to ``_coalesce_joint``. For a single locus,
    the coalescent model enumerates the mergers within each deme. For two loci (lineage-counting state space
    under the standard coalescent only), lineages are binned into linked lineages and lineages unlinked at
    locus 1 and locus 2. Pairwise coalescences within and between these bins are then considered per deme.

    :param source: Source state.
    :return: Mapping from target state to ``(rate, kind)``.
    :raises NotImplementedError: For more than two loci, or for two loci with an unsupported state space or model.
    """
    if isinstance(self.state_space, JointBlockCountingStateSpace):
        return self._coalesce_joint(source)

    targets: Dict['State', Tuple[float, str]] = {}
    model = self.state_space.model
    pop_sizes = [self.state_space.epoch.pop_sizes[pop] for pop in self.state_space.lineage_config.pop_names]

    if source.n_loci == 1:
        locus = 0

        for deme in range(source.n_demes):
            blocks = model.coalesce(
                self.state_space.lineage_config.n,
                source.lineages[locus, deme]
            )

            for block, rate in blocks:
                target = source.copy()
                target.lineages[locus, deme] = block

                time_scale = model._get_timescale(pop_sizes[deme])
                self.add_target(targets, target, rate / time_scale, 'coalescence')

        return targets

    if source.n_loci == 2:
        if not isinstance(self.state_space, LineageCountingStateSpace):
            raise NotImplementedError(
                'Coalescence with recombination is only implemented for LineageCountingStateSpace.'
            )

        if not isinstance(model, StandardCoalescent):
            raise NotImplementedError('Coalescence with recombination is only implemented for StandardCoalescent.')

        # per-deme counts (single block) of linked lineages and of lineages unlinked at locus 1 and locus 2
        bins = dict(
            linked=source.linked[0],
            unlinked1=source.unlinked[0],
            unlinked2=source.unlinked[1]
        )

        for deme in range(source.n_demes):
            time_scale = model._get_timescale(pop_sizes[deme])

            for (class1, counts1), (class2, counts2) in product(bins.items(), repeat=2):

                # lineage counting has a single block, so compare scalar counts
                n1, n2 = counts1[deme, 0], counts2[deme, 0]

                # linked or unlinked coalescence within one class
                if class1 == class2:

                    # we need at least 2 lineages to coalesce
                    if n1 < 2:
                        continue

                    rate = model._get_rate(b=n1, k=2)
                    target = source.copy()

                    if class1 == 'unlinked1':
                        # unlinked coalescence in locus 1
                        target.lineages[0, deme] -= 1
                        self.add_target(targets, target, rate / time_scale, 'unlinked_coalescence')

                    elif class1 == 'unlinked2':
                        # unlinked coalescence in locus 2
                        target.lineages[1, deme] -= 1
                        self.add_target(targets, target, rate / time_scale, 'unlinked_coalescence')

                    elif np.all(source.linked[:, deme] > 0):
                        # linked coalescence in both loci
                        target.lineages[:, deme] -= 1
                        target.linked[:, deme] -= 1
                        self.add_target(targets, target, rate / time_scale, 'linked_coalescence')

                # consider each unordered pair of distinct classes once ('linked' < 'unlinked1' < 'unlinked2')
                elif class1 < class2:

                    if n1 < 1 or n2 < 1:
                        continue

                    rate = n1 * n2
                    target = source.copy()

                    if class1 == 'linked':
                        # mixed coalescence of a linked lineage with a lineage unlinked at one locus
                        locus = 0 if class2 == 'unlinked1' else 1

                        # always holds here, as both classes carry a lineage at this locus
                        if target.lineages[locus, deme, 0] > 1:
                            target.lineages[locus, deme, 0] -= 1
                            self.add_target(targets, target, rate / time_scale, 'mixed_coalescence')

                    elif np.all(source.unlinked[:, deme, 0] > 0):
                        # locus coalescence: unlinked lineages at both loci join into one linked lineage
                        target.linked[:, deme, 0] += 1
                        self.add_target(targets, target, rate / time_scale, 'locus_coalescence')

        return targets

    raise NotImplementedError('Coalescence is not implemented for more than 2 loci.')
def _coalesce_joint(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Enumerate the coalescence transitions of the joint block-counting state space.

    Only lineages that share a deme can merge. For every deme holding at least two lineages, every way of
    drawing some number of lineages from each occupied block is considered; draws of fewer than two
    lineages are ignored. The rate of a draw comes from the coalescent model's block-counting rate, scaled
    by the deme's timescale. Draws for which the model reports a zero rate are dropped; for example, the
    standard coalescent only allows pairwise mergers. The merged lineage carries the vector sum of the
    descendant vectors of the lineages that were drawn. Only the first row (index 0) of
    ``source.lineages`` is used.

    :param source: Source state.
    :return: All possible coalescent transitions from the given state.
    """
    space = self.state_space
    coalescent = space.model
    sizes = [space.epoch.pop_sizes[name] for name in space.lineage_config.pop_names]
    vectors = space.block_vectors
    index_of = space.block_index
    transitions: Dict['State', Tuple[float, str]] = {}

    for deme in range(source.n_demes):
        row = source.lineages[0, deme]
        n_total = int(row.sum())
        # a merger needs at least two lineages in this deme
        if n_total < 2:
            continue

        occupied = np.flatnonzero(row > 0)
        scale = coalescent._get_timescale(sizes[deme])
        occupied_counts = row[occupied]
        occupied_vectors = vectors[occupied]

        # picks[i] is how many lineages are drawn from the i-th occupied block
        for picks in product(*(range(int(c) + 1) for c in occupied_counts)):
            picks = np.array(picks)
            if picks.sum() < 2:
                continue

            merging = picks > 0
            rate = coalescent._get_rate_block_counting(
                n=n_total,
                b=occupied_counts[merging],
                k=picks[merging]
            )
            # the model does not support this kind of merger
            if rate == 0:
                continue

            merged_vector = np.sum(picks[:, np.newaxis] * occupied_vectors, axis=0)
            merged_block = index_of[tuple(map(int, merged_vector))]

            successor = source.copy()
            successor.lineages[0, deme, occupied] -= picks
            successor.lineages[0, deme, merged_block] += 1
            self.add_target(transitions, successor, rate / scale, 'coalescence')

    return transitions
[docs]
def migrate(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Collect every migration transition out of the given state.

    Linked and unlinked migrations are enumerated separately and merged into a single mapping. If both
    produce the same target state, the unlinked entry takes precedence.

    :param source: Source state.
    :return: All possible migration transitions from the given state.
    """
    linked_moves = self.migrate_linked(source)
    unlinked_moves = self.migrate_unlinked(source)
    return {**linked_moves, **unlinked_moves}
[docs]
def migrate_unlinked(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Enumerate the unlinked migration transitions leaving the given state.

    For every locus, every ordered pair of distinct demes and every lineage block, one unlinked lineage
    may move from the first deme to the second. Only the lineage count changes; the linked count does
    not. The rate is the migration rate between the two populations multiplied by the number of unlinked
    lineages of that block in the source deme. Migration is also considered when there is only one
    locus; those transitions are labelled ``'migration'`` instead of ``'unlinked_migration'``.

    :param source: Source state.
    :return: All possible migration transitions from the given state.
    """
    transitions: Dict['State', Tuple[float, str]] = {}
    names = self.state_space.lineage_config.pop_names
    label = 'unlinked_migration' if source.n_loci != 1 else 'migration'
    deme_pairs = [(a, b) for a, b in product(range(source.n_demes), repeat=2) if a != b]

    for locus in range(source.n_loci):
        for origin, destination in deme_pairs:
            for block in range(source.n_blocks):
                # nothing to move unless the source deme holds an unlinked lineage of this block
                if source.lineages[locus, origin, block] <= 0 or source.unlinked[locus, origin, block] <= 0:
                    continue

                moved = source.copy()
                moved.lineages[locus, origin, block] -= 1
                moved.lineages[locus, destination, block] += 1

                per_lineage = self.state_space.epoch.migration_rates[(names[origin], names[destination])]
                # the rate scales with the number of unlinked lineages available in the source deme
                self.add_target(transitions, moved, per_lineage * source.unlinked[locus, origin, block], label)

    return transitions
[docs]
def migrate_linked(self, source: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Get all possible linked migration transitions from the given state.

    A linked migration moves one lineage that is linked across all loci from deme ``d1`` to deme ``d2``.
    It updates both the lineage counts and the linked counts of every locus. No linked migration is
    possible when there is only one locus.

    :param source: Source state.
    :return: All possible migration transitions from the given state.
    """
    targets: Dict['State', Tuple[float, str]] = {}

    # no linked migration if there is only one locus
    if source.n_loci == 1:
        return targets

    pop_names = self.state_space.lineage_config.pop_names

    for d1, d2 in product(range(source.n_demes), repeat=2):
        if d1 == d2:
            continue

        for block in range(source.n_blocks):
            # every locus needs a lineage in the source deme, and it has to be a linked one
            # (element-wise test; previously the comparison was applied to the result of np.all)
            if not (np.all(source.lineages[:, d1, block] > 0) and np.all(source.linked[:, d1, block] > 0)):
                continue

            target = source.copy()
            target.lineages[:, d1, block] -= 1
            target.lineages[:, d2, block] += 1
            target.linked[:, d1, block] -= 1
            target.linked[:, d2, block] += 1

            base_rate = self.state_space.epoch.migration_rates[(pop_names[d1], pop_names[d2])]

            # scale migration rate by number of linked lineages in source deme;
            # both loci are assumed to have the same number of linked lineages here
            rate = base_rate * source.linked[0, d1, block]

            self.add_target(targets, target, rate, 'linked_migration')

    return targets
[docs]
def recombine(self, state: 'State') -> Dict['State', Tuple[float, str]]:
    """
    Get all possible recombination transitions from the given state.

    Nothing happens when the locus configuration has a single locus. The behaviour depends on the state
    space:

    - Lineage-counting state space: in every deme where all loci carry at least one linked lineage, one
      linked lineage becomes unlinked. The rate is the recombination rate times the number of linked
      lineages at the first locus.
    - Two-locus block-counting state space: every lineage whose descendant vector ``(a_0, a_1)`` has both
      components positive may split into ``(a_0, 0)`` and ``(0, a_1)``. This is done in every deme (not
      only the first). The rate is the recombination rate times the number of lineages in that block.

    :param state: State.
    :return: All possible recombination transitions from the given state.
    :raises NotImplementedError: For state spaces other than the two above.
    """
    targets: Dict['State', Tuple[float, str]] = {}

    # only recombine if there is more than one locus
    if self.state_space.locus_config.n == 1:
        return targets

    r = self.state_space.locus_config.recombination_rate

    if isinstance(self.state_space, LineageCountingStateSpace):
        for deme in range(state.n_demes):
            # make sure we have linked lineages that can recombine
            if np.all(state.linked[:, deme] > 0):
                target = state.copy()
                target.linked[:, deme] -= 1
                rate = r * state.linked[0, deme, 0]
                self.add_target(targets, target, rate, 'recombination')

        return targets

    if isinstance(self.state_space, TwoLocusBlockCountingStateSpace):
        block_vectors = self.state_space.block_vectors
        block_index = self.state_space.block_index

        for deme in range(state.n_demes):
            lineages = state.lineages[0, deme]

            # each linked lineage (a_0, a_1) with both components > 0 splits into (a_0, 0) and (0, a_1)
            for block in range(self.state_space.n_blocks):
                a0, a1 = int(block_vectors[block, 0]), int(block_vectors[block, 1])

                if lineages[block] > 0 and a0 > 0 and a1 > 0:
                    target = state.copy()
                    target.lineages[0, deme, block] -= 1
                    target.lineages[0, deme, block_index[(a0, 0)]] += 1
                    target.lineages[0, deme, block_index[(0, a1)]] += 1
                    self.add_target(targets, target, r * int(lineages[block]), 'recombination')

        return targets

    raise NotImplementedError(f'Recombination is not yet implemented for {self.state_space.__class__.__name__}.')
[docs]
class State:
    """
    State utility class.

    A state wraps a pair of arrays ``(lineages, linked)`` that share the layout ``[locus, deme, block]``
    (cf. :attr:`n_loci`, :attr:`n_demes` and :attr:`n_blocks`). The number of unlinked lineages is not
    stored; it is derived as ``lineages - linked``. States are hashable and compare by the raw contents of
    both arrays, so they can be used as dictionary keys. The arrays are mutable, so a state must not be
    modified while it is used as a key.
    """

    #: Axis for linkage.
    LINKAGE = 0

    #: Axis for loci.
    LOCUS = 1

    #: Axis for demes.
    DEME = 2

    #: Axis for lineage blocks.
    BLOCK = 3

    def __init__(self, data: Tuple[np.ndarray, np.ndarray]):
        """
        Initialize a state.

        :param data: State data as a tuple ``(lineages, linked)``.
        """
        #: State data
        self.data: Tuple[np.ndarray, np.ndarray] = data

    def __hash__(self) -> int:
        """
        Hash function.

        :return: Hash of the raw bytes of both arrays.
        """
        return hash((self.data[0].tobytes(), self.data[1].tobytes()))

    def __eq__(self, other: object) -> bool:
        """
        Check if two states are equal.

        Two states are equal if both of their arrays have the same shape and the same raw contents. The
        arrays are compared directly rather than through their hashes, so hash collisions cannot make
        distinct states compare equal. Equal states always have equal hashes.

        :param other: Other state.
        :return: Whether the two states are equal, or ``NotImplemented`` if ``other`` is not a state.
        """
        if not isinstance(other, State):
            return NotImplemented

        return all(
            mine.shape == theirs.shape and mine.tobytes() == theirs.tobytes()
            for mine, theirs in zip(self.data, other.data)
        )

    def __repr__(self) -> str:
        """
        Debug representation.

        :return: String showing the lineage and linked counts.
        """
        return f'{type(self).__name__}(lineages={self.lineages.tolist()}, linked={self.linked.tolist()})'

    def copy(self) -> 'State':
        """
        Copy the state.

        :return: Copy of the state whose arrays are independent of this one's.
        """
        return State((self.data[0].copy(), self.data[1].copy()))

    def is_absorbing(self) -> bool:
        """
        Whether a state is absorbing.

        A state is absorbing when every locus has exactly one lineage left, summed over all demes and
        blocks.

        :return: Whether the state is absorbing.
        """
        return bool(np.all(self.lineages.sum(axis=(1, 2)) == 1))

    @property
    def n_demes(self) -> int:
        """
        Get the number of demes.

        :return: The number of demes.
        """
        return self.lineages.shape[1]

    @property
    def n_loci(self) -> int:
        """
        Get the number of loci.

        :return: The number of loci.
        """
        return self.lineages.shape[0]

    @property
    def n_blocks(self) -> int:
        """
        Get the number of lineage blocks.

        :return: The number of lineage blocks.
        """
        return self.lineages.shape[2]

    @property
    def lineages(self) -> np.ndarray:
        """
        Get the number of lineages.

        :return: The number of lineages.
        """
        return self.data[0]

    @property
    def linked(self) -> np.ndarray:
        """
        Get the number of linked lineages.

        :return: The number of linked lineages.
        """
        return self.data[1]

    @property
    def unlinked(self) -> np.ndarray:
        """
        Get the number of unlinked lineages.

        :return: The number of unlinked lineages, i.e. ``lineages - linked``.
        """
        return self.lineages - self.linked