"""
Coalescent models for simulating genealogies. Besides the standard Kingman coalescent,
this module implements the beta coalescent and Dirac coalescent models.
"""
import itertools
from abc import ABC, abstractmethod
from typing import List, Tuple, Sequence
import numpy as np
from scipy.special import comb, beta
from scipy.stats import binom
class CoalescentModel(ABC):
    """
    Abstract class for coalescent models.

    Subclasses define the merger rates of a concrete coalescent (``_get_rate``,
    ``_get_rate_block_counting``), its timescale (``_get_timescale``) and the
    transitions out of a state (``coalesce``).
    """

    def get_rate(self, s1: int, s2: int) -> float:
        """
        Get rate for a merger collapsing ``s1`` lineages into ``s2`` lineages.

        Going from ``s1`` to ``s2`` lineages means that ``s1 - s2 + 1`` lineages merge
        into a single one. Transitions that do not reduce the number of lineages
        (``s2 >= s1``) are not mergers and get rate 0, in line with
        :meth:`get_rate_block_counting`, which also returns 0 for identical states.

        :param s1: Number of lineages in the first state.
        :param s2: Number of lineages in the second state.
        :return: The rate.
        :meta private:
        """
        # no merger: the number of lineages must strictly decrease
        if s2 >= s1:
            return 0

        return self._get_rate(b=s1, k=s1 + 1 - s2)

    def get_rate_block_counting(self, n: int, s1: np.ndarray, s2: np.ndarray) -> float:
        r"""
        Get (positive) rate between two block-counting states.

        A block-counting state is a vector of length ``n`` where element ``i`` (0-based)
        is the number of blocks (lineages) subtending ``i + 1`` sampled lineages:

        .. math::
            (a_1,...,a_n) \in \mathbb{Z}_+^n : \sum_{i=1}^{n} i a_i = n.

        Pairs of states that are not connected by a single merger event get rate 0.

        :param n: Total number of lineages.
        :param s1: Block configuration 1, a vector of length n.
        :param s2: Block configuration 2, a vector of length n.
        :return: The rate.
        :meta private:
        """
        s1 = np.asarray(s1)
        s2 = np.asarray(s2)
        diff = s2 - s1

        # exactly one class must gain a lineage (the newly formed block)
        if np.sum(diff == 1) != 1 or n != s1.shape[0]:
            return 0

        # classes that lost blocks, i.e. the classes taking part in the merger
        where_less = np.where(diff < 0)[0]
        if len(where_less) == 0:
            return 0

        # number of blocks lost per participating class
        diff_less = -diff[where_less]

        # the new block's size is the total size of the merged blocks (class i holds size i + 1)
        i_more = np.dot(where_less + 1, diff_less) - 1

        # the new block must fit into the state vector and be the class that gained exactly one block
        if i_more >= diff.shape[0] or diff[i_more] != 1:
            return 0

        # number of blocks before the merger in each participating class
        b = s1[where_less]

        # number of blocks that merge in each participating class
        k = b - s2[where_less]

        return self._get_rate_block_counting(n=s1.sum(), b=b, k=k)

    @abstractmethod
    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The generation time.
        """
        pass

    @abstractmethod
    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.

        :param b: Number of lineages.
        :param k: Number of lineages that merge.
        :return: The rate.
        """
        pass

    @abstractmethod
    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages currently present in the block configuration.
        :param b: Number of lineages before merger for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        pass

    @abstractmethod
    def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        :meta private:
        """
        pass
class StandardCoalescent(CoalescentModel):
    """
    Standard (Kingman) coalescent model. Refer to the
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html#msprime.StandardCoalescent>`__
    for more information.
    """

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The generation time.
        """
        return N

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.

        Only binary mergers occur, so the rate is ``b choose 2`` for ``k == 2`` and 0 otherwise.

        :param b: Number of lineages.
        :param k: Number of lineages that merge.
        :return: The rate.
        """
        # two lineages can merge with a rate depending on b
        if k == 2:
            return b * (b - 1) / 2

        # no other mergers can happen
        return 0

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages currently present in the block configuration.
        :param b: Number of lineages before merger for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        # both merging blocks come from a single class
        if len(b) == 1:
            return self._get_rate(b=b[0], k=k[0])

        # one block from each of two classes: b[0] choose 1 times b[1] choose 1
        if len(b) == 2 and k[0] == 1 and k[1] == 1:
            return b[0] * b[1]

        # no other mergers possible
        return 0

    def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        With a single entry, ``blocks`` is a lineage-counting state; otherwise it is a
        block-counting state where entry ``i`` counts blocks of size ``i + 1``.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        :meta private:
        """
        n_blocks = len(blocks)

        # lineage-counting state space: two of the lineages merge into one
        if n_blocks == 1:
            if blocks[0] > 1:
                return [(np.array([blocks[0] - 1]), self._get_rate(b=blocks[0], k=2))]
            return []

        # block-counting state space
        states = []
        for i, j in itertools.product(range(n_blocks), repeat=2):
            if i == j and blocks[i] > 1:
                # two blocks of size i + 1 merge into one block of size 2 * (i + 1)
                new = blocks.copy()
                new[i] -= 2
                new[2 * i + 1] += 1
                states.append((new, self._get_rate_block_counting(n=n, b=[blocks[i]], k=[2])))
            elif i > j and blocks[i] > 0 and blocks[j] > 0:
                # blocks of sizes i + 1 and j + 1 merge into one block of size i + j + 2
                new = blocks.copy()
                new[i] -= 1
                new[j] -= 1
                new[i + j + 1] += 1
                rate = self._get_rate_block_counting(n=n, b=[blocks[i], blocks[j]], k=[1, 1])
                states.append((new, rate))

        return states

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the two coalescent models are equal.
        """
        return isinstance(other, StandardCoalescent)

    def __hash__(self):
        """
        Hash consistent with ``__eq__``: all standard coalescent instances compare equal.

        :return: The hash value.
        """
        return hash(StandardCoalescent)
class MultipleMergerCoalescent(CoalescentModel, ABC):
    """
    Base class for multiple merger coalescent models.

    :meta private:
    """

    def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        With a single entry, ``blocks`` is a lineage-counting state; otherwise it is a
        block-counting state where entry ``i`` counts blocks of size ``i + 1``. Any
        selection of at least two blocks may merge into one.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        :meta private:
        """
        n_blocks = len(blocks)

        # lineage-counting state space: k + 1 >= 2 lineages merge into one
        if n_blocks == 1:
            return [
                (np.array([blocks[0] - k]), self._get_rate(b=blocks[0], k=k + 1))
                for k in range(1, blocks[0])
            ]

        # block size of each class, and the total number of blocks present (loop-invariant)
        sizes = np.arange(1, n_blocks + 1)
        n_present = blocks.sum()

        states = []

        # choose how many blocks of each class take part in the merger
        for choice in itertools.product(*(range(count + 1) for count in blocks)):
            merged = np.array(choice)

            # a merger needs at least two blocks
            if merged.sum() < 2:
                continue

            new = blocks.copy()
            new -= merged

            # the new block's size is the total size of the merged blocks
            new[merged.dot(sizes) - 1] += 1

            involved = merged > 0
            rate = self._get_rate_block_counting(n=n_present, b=blocks[involved], k=merged[involved])
            states.append((new, rate))

        return states
class BetaCoalescent(MultipleMergerCoalescent):
    """
    Beta coalescent model. Refer to the
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html#msprime.BetaCoalescent>`__
    for more information.
    """

    def __init__(self, alpha: float, scale_time: bool = True):
        """
        Initialize the beta coalescent model.

        :param alpha: The alpha parameter of the beta coalescent model, strictly between 1 and 2.
        :param scale_time: Whether to scale coalescence time as described in
            `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html#msprime.BetaCoalescent>`__.
            If ``False``, the timescale is set to N.
        :raises ValueError: If ``alpha`` is not strictly between 1 and 2.
        """
        super().__init__()

        if not 1 < alpha < 2:
            raise ValueError("Alpha must be between 1 and 2.")

        #: Whether to scale coalescence time.
        self.scale_time: bool = scale_time

        #: The alpha parameter of the beta coalescent model.
        self.alpha: float = alpha

    def _get_base_rate(self, b: int, k: int) -> float:
        """
        Get base rate for a merger of k out of b lineages (without number of ways).

        Computes ``B(k - alpha, b - k + alpha) / B(alpha, 2 - alpha)``.

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate.
        """
        return beta(k - self.alpha, b - k + self.alpha) / beta(self.alpha, 2 - self.alpha)

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The generation time.
        """
        if not self.scale_time:
            return N

        # m = 1 + 1 / (2^(alpha - 1) (alpha - 1)); presumably msprime's haploid constant — confirm
        m = 1 + 1 / 2 ** (self.alpha - 1) / (self.alpha - 1)

        return m ** self.alpha * N ** (self.alpha - 1) / self.alpha / beta(2 - self.alpha, self.alpha)

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.
        Negative rates will be filled in later.

        A merger involves between 2 and b lineages; any other ``k`` yields 0.
        (For ``k == 1`` the beta function term would be evaluated at a negative
        argument and produce a negative, meaningless value.)

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate.
        """
        if k < 2 or k > b:
            return 0

        return comb(b, k, exact=True) * self._get_base_rate(b, k)

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages currently present in the block configuration.
        :param b: Number of lineages before merger for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        n_merged = sum(k)

        # fewer than two merging lineages is not a merger
        if n_merged < 2:
            return 0

        # number of ways to pick the merging blocks; accumulate as Python ints
        # so that large binomial products cannot overflow int64
        combinations = 1
        for b_i, k_i in zip(b, k):
            combinations *= comb(b_i, k_i, exact=True)

        return combinations * self._get_base_rate(b=n, k=n_merged)

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the two coalescent models are equal.
        """
        return (
            isinstance(other, BetaCoalescent) and
            self.alpha == other.alpha and
            self.scale_time == other.scale_time
        )
class DiracCoalescent(MultipleMergerCoalescent):
    """
    Dirac coalescent model. Refer to the
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html#msprime.DiracCoalescent>`__
    for more information.
    """

    def __init__(self, psi: float, c: float, scale_time: bool = True):
        """
        Initialize the Dirac coalescent model.

        :param psi: The fraction of the population replaced by offspring in one large reproduction event.
        :param c: The rate of potential multiple merger events (non-negative).
        :param scale_time: Whether to scale coalescence time as described in
            `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html#msprime.DiracCoalescent>`__.
            If ``False``, the timescale is set to N.
        :raises ValueError: If ``psi`` is not strictly between 0 and 1 or ``c`` is negative.
        """
        super().__init__()

        if not 0 < psi < 1:
            raise ValueError("Psi must be between 0 and 1.")

        # a negative rate of multiple merger events would produce negative transition rates
        if c < 0:
            raise ValueError("c must be non-negative.")

        #: The fraction of the population replaced by offspring in one large reproduction event.
        self.psi: float = psi

        #: The rate of potential multiple merger events.
        self.c: float = c

        #: Whether to scale coalescence time.
        self.scale_time: bool = scale_time

        #: The standard coalescent model, supplying the binary merger component.
        self._standard = StandardCoalescent()

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The generation time (``N ** 2`` when scaling time, otherwise ``N``).
        """
        if not self.scale_time:
            return N

        return N ** 2

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.
        Negative rates will be filled in later.

        The rate is the binary (Kingman) rate plus ``c`` times the binomial probability
        that exactly ``k`` of the ``b`` lineages take part in a large reproduction event.
        A merger involves between 2 and b lineages; any other ``k`` yields 0.

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate.
        """
        # fewer than two participating lineages is not a merger
        if k < 2 or k > b:
            return 0

        # rate of binary merger
        rate_binary = self._standard._get_rate(b=b, k=k)

        # rate of multiple merger: probability that exactly k of b lineages participate, times c
        rate_multi = binom.pmf(k=k, n=b, p=self.psi) * self.c

        return rate_binary + rate_multi

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages currently present in the block configuration.
        :param b: Number of lineages before merger for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        # fewer than two merging lineages is not a merger
        if sum(k) < 2:
            return 0

        # rate of binary merger
        rate_binary = self._standard._get_rate_block_counting(n=n, b=b, k=k)

        # probability that exactly k_i of the b_i blocks in each participating class take part
        p_psi = np.prod([binom.pmf(k=k_i, n=b_i, p=self.psi) for b_i, k_i in zip(b, k)])

        # all remaining blocks must stay out of the event
        n_rest = n - sum(b)
        if n_rest > 0:
            p_psi *= binom.pmf(k=0, n=n_rest, p=self.psi)

        # rate of multiple merger
        return rate_binary + p_psi * self.c

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the two coalescent models are equal.
        """
        return (
            isinstance(other, DiracCoalescent) and
            self.psi == other.psi and
            self.c == other.c and
            self.scale_time == other.scale_time
        )