"""
Coalescent models.
"""
import itertools
from abc import ABC, abstractmethod
from typing import List, Tuple, Sequence
import numpy as np
from scipy.special import comb, beta
from scipy.stats import binom
[docs]
class CoalescentModel(ABC):
"""
Abstract class for coalescent models.
"""
def get_rate(self, s1: int, s2: int) -> float:
"""
Get rate for a merger collapsing k1 lineages into k2 lineages.
:param s1: Number of lineages in the first state.
:param s2: Number of lineages in the second state.
:return: The rate.
:meta private:
"""
# not possible
if s2 > s1:
return 0
return self._get_rate(b=s1, k=s1 + 1 - s2)
def get_rate_block_counting(self, n: int, s1: np.ndarray, s2: np.ndarray) -> float:
r"""
Get (positive) rate between two block-counting states.
A block-counting state is a vector of length ``n`` where each element represents the number of lineages
subtending ``i`` lineages in the coalescent tree.
.. math::
(a_1,...,a_n) \in \mathbb{Z}_+^n : \sum_{i=1}^{n} i a_i = n.
:param n: Total number of lineages.
:param s1: Block configuration 1, a vector of length n.
:param s2: Block configuration 2, a vector of length n.
:return: The rate.
:meta private:
"""
diff = s2 - s1
# make sure only one class has one more lineage
if np.sum(diff == 1) == 1 and n == s1.shape[0]:
# get the index for the class that lost lineages
where_less = np.where(diff < 0)[0]
# only continue if there is a class that lost lineages
if len(where_less) > 0:
# get the number of lineages that were lost
diff_less = -diff[where_less]
# determine the index of the class that gained lineages
i_more = np.dot(where_less + 1, diff_less) - 1
# make sure that the class that gained lineages only gained one lineage
if diff[i_more] == 1:
# number of lineages before the merger
b = s1[where_less]
# determine number of lineages that coalesce
k = b - s2[where_less]
# get rate
rate = self._get_rate_block_counting(n=s1.sum(), b=b, k=k)
return rate
return 0
@abstractmethod
def _get_timescale(self, N: float) -> float:
"""
Get the timescale.
:param N: The effective population size.
:return: The generation time.
"""
pass
@abstractmethod
def _get_rate(self, b: int, k: int) -> float:
"""
Get positive rate for a merger of k out of b lineages.
:param b: Number of lineages.
:param k: Number of lineages that merge.
:return: The rate.
"""
pass
@abstractmethod
def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
"""
Get positive rate for a merger of k_i out of b_i lineages for all i.
:param n: Number of lineages currently present in the block configuration.
:param b: Number of lineages before merger for blocks that experience a merger.
:param k: Number of lineages that merge for blocks that experience a merger.
:return: The rate.
"""
pass
@abstractmethod
def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
"""
Coalesce a state.
:param n: The total number of lineages.
:param blocks: The lineages in each block.
:return: List of coalesced states and their rates.
:meta private:
"""
pass
[docs]
class StandardCoalescent(CoalescentModel):
"""
Standard (Kingman) coalescent model. Refer to the
`Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
highlight=standard+coalescent#msprime.StandardCoalescent>`__
for more information.
"""
def _get_timescale(self, N: float) -> float:
"""
Get the timescale.
:param N: The effective population size.
:return: The generation time.
"""
return N
def _get_rate(self, b: int, k: int) -> float:
"""
Get positive rate for a merger of k out of b lineages.
:param b: Number of lineages.
:param k: Number of lineages that merge.
:return: The rate.
"""
# two lineages can merge with a rate depending on b
if k == 2:
return b * (b - 1) / 2
# no other mergers can happen
return 0
def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
"""
Get positive rate for a merger of k_i out of b_i lineages for all i.
:param n: Number of lineages currently present in the block configuration.
:param b: Number of lineages before merger for blocks that experience a merger.
:param k: Number of lineages that merge for blocks that experience a merger.
:return: The rate.
"""
# if we have a single class
if len(b) == 1:
return self._get_rate(b=b[0], k=k[0])
# if we have a merger from two classes
if len(b) == 2:
if k[0] == 1 and k[1] == 1:
# same as b[0] choose k[0] times b[1] choose k[1]
return b[0] * b[1]
# no other mergers possible
return 0
def coalesce(self, n: int, blocks: np.ndarray[int]) -> List[Tuple[np.ndarray, float]]:
"""
Coalesce a state.
:param n: The total number of lineages.
:param blocks: The lineages in each block.
:return: List of coalesced states and their rates.
:meta private:
"""
n_blocks = len(blocks)
states = []
# lineage-counting state space
if n_blocks == 1:
if blocks[0] > 1:
states += [(np.array([blocks[0] - 1]), self._get_rate(b=blocks[0], k=2))]
return states
# block-counting state space
for i, j in itertools.product(range(n_blocks), repeat=2):
if i == j:
if blocks[i] > 1:
new = blocks.copy()
new[i] -= 2
new[2 * (i + 1) - 1] += 1
states += [(new, self._get_rate_block_counting(n=n, b=[blocks[i]], k=[2]))]
elif i > j:
if blocks[i] > 0 and blocks[j] > 0:
new = blocks.copy()
new[i] -= 1
new[j] -= 1
new[i + j + 1] += 1
rate = self._get_rate_block_counting(n=n, b=[blocks[i], blocks[j]], k=[1, 1])
states += [(new, rate)]
return states
def __eq__(self, other):
"""
Check if two coalescent models are equal.
:param other: The other coalescent model.
:return: Whether the two coalescent models are equal.
"""
return isinstance(other, StandardCoalescent)
class MultipleMergerCoalescent(CoalescentModel, ABC):
"""
Base class for multiple merger coalescent models.
:meta private:
"""
def coalesce(self, n: int, blocks: np.ndarray[int]) -> List[Tuple[np.ndarray, float]]:
"""
Coalesce a state.
:param n: The total number of lineages.
:param blocks: The lineages in each block.
:return: List of coalesced states and their rates.
:meta private:
"""
n_blocks = len(blocks)
states = []
# lineage-counting state space
if n_blocks == 1:
for k in range(1, blocks[0]):
states += [(np.array([blocks[0] - k]), self._get_rate(b=blocks[0], k=k + 1))]
return states
# block-counting state space
for comb in itertools.product(*[list(range(blocks[i] + 1)) for i in range(n_blocks)]):
comb = np.array(comb)
if comb.sum() > 1:
new = blocks.copy()
new -= comb
new[comb.dot(np.arange(1, n_blocks + 1)) - 1] += 1
rate = self._get_rate_block_counting(n=blocks.sum(), b=blocks[comb > 0], k=comb[comb > 0])
states += [(new, rate)]
return states
[docs]
class BetaCoalescent(MultipleMergerCoalescent):
"""
Beta coalescent model. Refer to the
`Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?highlight=beta+coalescent#msprime.BetaCoalescent>`__
for more information.
"""
[docs]
def __init__(self, alpha: float, scale_time: bool = True):
"""
Initialize the beta coalescent model.
:param alpha: The alpha parameter of the beta coalescent model.
:param scale_time: Whether to scale coalescence time as described in
`Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
highlight=beta+coalescent#msprime.BetaCoalescent>`__. If ``False``, the timescale is set to N.
"""
if alpha < 1 or alpha > 2:
raise ValueError("Alpha must be between 1 and 2.")
#: Whether to scale coalescence time.
self.scale_time: bool = scale_time
#: The alpha parameter of the beta coalescent model.
self.alpha: float = alpha
def _get_base_rate(self, b: int, k: int) -> float:
"""
Get base rate for a merger of k out of b lineages (without number of ways).
:param b: The number of lineages before the merger.
:param k: The number of lineages that merge.
:return: The rate.
"""
rate = beta(k - self.alpha, b - k + self.alpha) / beta(self.alpha, 2 - self.alpha)
return rate
def _get_timescale(self, N: float) -> float:
"""
Get the timescale.
:param N: The effective population size.
:return: The generation time.
"""
if not self.scale_time:
return N
m = 1 + 1 / 2 ** (self.alpha - 1) / (self.alpha - 1)
scale = m ** self.alpha * N ** (self.alpha - 1) / self.alpha / beta(2 - self.alpha, self.alpha)
return scale
def _get_rate(self, b: int, k: int) -> float:
"""
Get positive rate for a merger of k out of b lineages.
Negative rates will be filled in later.
:param b: The number of lineages before the merger.
:param k: The number of lineages that merge.
:return: The rate.
"""
if k < 1 or k > b:
return 0
return comb(b, k, exact=True) * self._get_base_rate(b, k)
def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
"""
Get positive rate for a merger of k_i out of b_i lineages for all i.
:param n: Number of lineages currently present in the block configuration.
:param b: Number of lineages before merger for blocks that experience a merger.
:param k: Number of lineages that merge for blocks that experience a merger.
:return: The rate.
"""
combinations = np.prod([comb(N=b_i, k=k_i, exact=True) for b_i, k_i in zip(b, k)])
return combinations * self._get_base_rate(b=n, k=sum(k))
def __eq__(self, other):
"""
Check if two coalescent models are equal.
:param other: The other coalescent model.
:return: Whether the two coalescent models are equal.
"""
return (
isinstance(other, BetaCoalescent) and
self.alpha == other.alpha and
self.scale_time == other.scale_time
)
[docs]
class DiracCoalescent(MultipleMergerCoalescent):
"""
Dirac coalescent model. Refer to the
`Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?highlight=dirac+coalescent#msprime.DiracCoalescent>`__
for more information.
"""
[docs]
def __init__(self, psi: float, c: float, scale_time: bool = True):
"""
Initialize the Dirac coalescent model.
:param psi: The fraction of the population replaced by offspring in one large reproduction event.
:param c: The rate of potential multiple merger events.
:param scale_time: Whether to scale coalescence time as described in
`Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
highlight=dirac+coalescent#msprime.DiracCoalescent>`__. If `False`, the timescale is set to N.
"""
super().__init__()
if not 0 < psi < 1:
raise ValueError("Psi must be between 0 and 1.")
#: The fraction of the population replaced by offspring in one large reproduction event.
self.psi: float = psi
#: The rate of potential multiple merger events.
self.c: float = c
#: Whether to scale coalescence time.
self.scale_time: bool = scale_time
#: The standard coalescent model.
self._standard = StandardCoalescent()
def _get_timescale(self, N: float) -> float:
"""
Get the timescale.
:param N: The effective population size.
:return: The generation time.
"""
if not self.scale_time:
return N
return N ** 2
def _get_rate(self, b: int, k: int) -> float:
"""
Get positive rate for a merger of k out of b lineages.
Negative rates will be filled in later.
:param b: The number of lineages before the merger.
:param k: The number of lineages that merge.
:return: The rate.
"""
# rate of binary merger
rate_binary = self._standard._get_rate(b=b, k=k)
# probability of multiple merger of k out of b lineages
p_psi = binom.pmf(k=k, n=b, p=self.psi)
# rate of multiple merger
rate_multi = p_psi * self.c
return rate_binary + rate_multi
def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
"""
Get positive rate for a merger of k_i out of b_i lineages for all i.
:param n: Number of lineages currently present in the block configuration.
:param b: Number of lineages before merger for blocks that experience a merger.
:param k: Number of lineages that merge for blocks that experience a merger.
:return: The rate.
"""
# rate of binary merger
rate_binary = self._standard._get_rate_block_counting(n=n, b=b, k=k)
# probability of multiple merger of k out of n lineages
# p_psi = binom.pmf(k=k.sum(), n=n, p=self.psi)
p_psi = np.prod([binom.pmf(k=k[i], n=b[i], p=self.psi) for i in range(len(k))])
# account for probability of no merger
if sum(b) < n:
p_psi *= binom.pmf(k=0, n=n - sum(b), p=self.psi)
# rate of multiple merger
rate_multi = p_psi * self.c
rate = rate_binary + rate_multi
return rate
def __eq__(self, other):
"""
Check if two coalescent models are equal.
:param other: The other coalescent model.
:return: Whether the two coalescent models are equal.
"""
return (
isinstance(other, DiracCoalescent) and
self.psi == other.psi and
self.c == other.c and
self.scale_time == other.scale_time
)