Source code for phasegen.distributions

"""
Probability distributions.
"""

import copy
import functools
import itertools
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping
from functools import cached_property, cache
from math import factorial
from typing import Generator, List, Callable, Tuple, Dict, Collection, Iterable, Iterator, Optional, Sequence, Set, \
    Type, Union

import numpy as np
from scipy.ndimage import gaussian_filter1d
from tqdm import tqdm

from .coalescent_models import StandardCoalescent, CoalescentModel, BetaCoalescent, DiracCoalescent
from .demography import Demography, PopSizeChanges
from .expm import Backend
from .lineage import LineageConfig
from .locus import LocusConfig
from .rewards import Reward, TreeHeightReward, TotalBranchLengthReward, UnfoldedSFSReward, DemeReward, UnitReward, \
    LocusReward, CombinedReward, FoldedSFSReward, SFSReward, CustomReward
from .serialization import Serializable
from .settings import Settings
from .spectrum import SFS, SFS2
from .state_space import BlockCountingStateSpace, LineageCountingStateSpace, StateSpace
from .utils import parallelize, multiset_permutations

# Shorthand for the backend's ``expm``; used below to exponentiate (rate) matrices.
expm = Backend.expm

# Package-level logger; distribution classes derive child loggers from it.
logger = logging.getLogger('phasegen')


def _make_hashable(func: Callable) -> Callable:
    """
    Decorator that makes a function hashable by converting non-hashable arguments to hashable ones.
    """

    @functools.wraps(func)
    def wrapper(self, *args: tuple, **kwargs: dict):
        """
        Wrapper function.

        :param self: Self.
        :return: The result of the function.
        """
        args = list(args)

        for i, arg in enumerate(args):
            if isinstance(arg, (list, np.ndarray)):
                args[i] = tuple(arg)

        for key, value in kwargs.items():
            if isinstance(value, (list, np.ndarray)):
                kwargs[key] = tuple(value)

        return func(self, *args, **kwargs)

    return wrapper


class ProbabilityDistribution(ABC):
    """
    Abstract base class shared by all probability distributions in this module.
    """

    def __init__(self):
        """
        Set up the instance logger.
        """
        #: Logger, child of the package logger named after the concrete class
        self._logger = logger.getChild(self.__class__.__name__)

    def touch(self, **kwargs):
        """
        Evaluate every cached property once so that its value is stored on the instance.

        :param kwargs: Additional keyword arguments (not used here).
        """
        for klass in self.__class__.__mro__:
            # collect the cached properties declared directly on this class
            cached_names = [
                name for name, member in vars(klass).items()
                if isinstance(member, cached_property)
            ]

            for name in cached_names:
                getattr(self, name)


class MomentAwareDistribution(ProbabilityDistribution, ABC):
    """
    Interface for distributions that expose their first two moments.

    Concrete subclasses have to provide :attr:`mean`, :attr:`var` and :attr:`m2`.
    """

    @abstractmethod
    @cached_property
    def mean(self) -> float:
        """
        Mean, i.e. the first moment.
        """

    @abstractmethod
    @cached_property
    def var(self) -> float:
        """
        Variance, i.e. the second central moment.
        """

    @abstractmethod
    @cached_property
    def m2(self) -> float:
        """
        Second moment (not centered).
        """


class MarginalDistributions(Mapping, ABC):
    """
    Interface for a read-only mapping of marginal distributions.

    In addition to the mapping protocol, subclasses provide covariances and
    correlations between pairs of marginal distributions.
    """

    @abstractmethod
    @cached_property
    def cov(self) -> np.ndarray:
        """
        Matrix of covariances.
        """

    @abstractmethod
    @cached_property
    def corr(self) -> np.ndarray:
        """
        Matrix of correlation coefficients.
        """

    @abstractmethod
    def get_cov(self, d1, d2) -> float:
        """
        Covariance between a pair of marginal distributions.

        :param d1: Index of the first marginal distribution.
        :param d2: Index of the second marginal distribution.
        :return: The covariance.
        """

    @abstractmethod
    def get_corr(self, d1, d2) -> float:
        """
        Correlation coefficient between a pair of marginal distributions.

        :param d1: Index of the first marginal distribution.
        :param d2: Index of the second marginal distribution.
        :return: The correlation coefficient.
        """


class MarginalLocusDistributions(MarginalDistributions):
    """
    Marginal locus distributions.

    Read-only mapping from locus index to the distribution obtained by combining the reward of
    the underlying distribution with a :class:`LocusReward` for that locus.
    """

    def __init__(self, dist: 'PhaseTypeDistribution'):
        """
        Initialize the distributions.

        :param dist: The distribution.
        """
        self.dist = dist

    def __getitem__(self, item):
        """
        Get the distribution for the given locus.

        :param item: Locus index.
        :return: Distribution.
        """
        return self.loci[item]

    def __iter__(self) -> Iterator:
        """
        Iterate over locus indices.

        :return: Iterator.
        """
        return iter(self.loci)

    def __len__(self) -> int:
        """
        Get the number of distributions.

        :return: Number of distributions.
        """
        return len(self.loci)

    @cached_property
    def loci(self) -> Dict[int, 'PhaseTypeDistribution']:
        """
        Distributions marginalized over loci, keyed by locus index.
        """
        # get class of distribution but use PhaseTypeDistribution
        # if this is a TreeHeightDistribution as TreeHeightDistribution
        # only works with default rewards
        if isinstance(self.dist, TreeHeightDistribution):
            cls = PhaseTypeDistribution
        else:
            cls = self.dist.__class__

        return {
            locus: cls(
                state_space=self.dist.state_space,
                tree_height=self.dist.tree_height,
                demography=self.dist.demography,
                reward=CombinedReward([self.dist.reward, LocusReward(locus)])
            )
            for locus in range(self.dist.locus_config.n)
        }

    def get_cov(self, locus1: int, locus2: int) -> float:
        """
        Get the covariance between two loci.

        :param locus1: The first locus.
        :param locus2: The second locus.
        :return: The covariance.
        :raises ValueError: If one of the loci does not exist.
        """
        locus1 = int(locus1)
        locus2 = int(locus2)

        valid = range(self.dist.locus_config.n)

        if locus1 not in valid or locus2 not in valid:
            raise ValueError(f"Locus {locus1} or {locus2} does not exist.")

        # centered second-order cross-moment of the two locus-specific rewards
        return self.dist.moment(
            k=2,
            rewards=(
                CombinedReward([self.dist.reward, LocusReward(locus1)]),
                CombinedReward([self.dist.reward, LocusReward(locus2)])
            ),
            center=True
        )

    @cached_property
    def cov(self) -> np.ndarray:
        """
        Covariance matrix across loci.
        """
        n_loci = self.dist.locus_config.n

        return np.array([[self.get_cov(i, j) for i in range(n_loci)] for j in range(n_loci)])

    def get_corr(self, locus1: int, locus2: int) -> float:
        """
        Get the correlation coefficient between two loci.

        :param locus1: The first locus.
        :param locus2: The second locus.
        :return: The correlation coefficient.
        """
        locus1 = int(locus1)
        locus2 = int(locus2)

        return self.get_cov(locus1, locus2) / (self.loci[locus1].std * self.loci[locus2].std)

    @cached_property
    def corr(self) -> np.ndarray:
        """
        Correlation matrix across loci.
        """
        n_loci = self.dist.locus_config.n

        return np.array([[self.get_corr(i, j) for i in range(n_loci)] for j in range(n_loci)])
class MarginalDemeDistributions(MarginalDistributions):
    """
    Marginal deme distributions.

    Read-only mapping from deme name to the distribution obtained by combining the reward of
    the underlying distribution with a :class:`DemeReward` for that deme.
    """

    def __init__(self, dist: 'PhaseTypeDistribution'):
        """
        Initialize the distributions.

        :param dist: The distribution.
        """
        self.dist = dist

    def __getitem__(self, item):
        """
        Get the distribution for the given deme.

        :param item: Deme name.
        :return: Distribution.
        """
        return self.demes[item]

    def __iter__(self) -> Iterator:
        """
        Iterate over deme names.

        :return: Iterator.
        """
        return iter(self.demes)

    def __len__(self) -> int:
        """
        Get the number of distributions.

        :return: Number of distributions.
        """
        return len(self.demes)

    @cached_property
    def demes(self) -> Dict[str, 'PhaseTypeDistribution']:
        """
        Distributions marginalized over demes, keyed by deme name.
        """
        # get class of distribution but use PhaseTypeDistribution
        # if this is a TreeHeightDistribution as TreeHeightDistribution
        # only works with default rewards
        if isinstance(self.dist, TreeHeightDistribution):
            cls = PhaseTypeDistribution
        else:
            cls = self.dist.__class__

        return {
            pop: cls(
                state_space=self.dist.state_space,
                tree_height=self.dist.tree_height,
                demography=self.dist.demography,
                reward=CombinedReward([self.dist.reward, DemeReward(pop)])
            )
            for pop in self.dist.lineage_config.pop_names
        }

    def get_cov(self, pop1: str, pop2: str) -> float:
        """
        Get the covariance between two demes.

        :param pop1: The first deme.
        :param pop2: The second deme.
        :return: The covariance.
        :raises ValueError: If one of the demes does not exist.
        """
        pop_names = self.dist.lineage_config.pop_names

        if pop1 not in pop_names or pop2 not in pop_names:
            raise ValueError(f"Population {pop1} or {pop2} does not exist.")

        # centered second-order cross-moment of the two deme-specific rewards
        return self.dist.moment(
            k=2,
            rewards=(
                CombinedReward([self.dist.reward, DemeReward(pop1)]),
                CombinedReward([self.dist.reward, DemeReward(pop2)])
            ),
            center=True
        )

    @cached_property
    def cov(self) -> np.ndarray:
        """
        Covariance matrix across demes.
        """
        pops = self.dist.lineage_config.pop_names

        return np.array([[self.get_cov(p1, p2) for p1 in pops] for p2 in pops])

    def get_corr(self, pop1: str, pop2: str) -> float:
        """
        Get the correlation coefficient between two demes.

        :param pop1: The first deme.
        :param pop2: The second deme.
        :return: The correlation coefficient.
        """
        return self.get_cov(pop1, pop2) / (self.demes[pop1].std * self.demes[pop2].std)

    @cached_property
    def corr(self) -> np.ndarray:
        """
        Correlation matrix across demes.
        """
        pops = self.dist.lineage_config.pop_names

        return np.array([[self.get_corr(p1, p2) for p1 in pops] for p2 in pops])
class DensityAwareDistribution(MomentAwareDistribution, ABC):
    """
    Interface for distributions that provide a CDF, quantiles and a density in addition to moments.
    """

    @abstractmethod
    def cdf(self, t: float | Sequence[float]) -> float | np.ndarray:
        """
        Cumulative distribution function.

        :param t: Point or points at which to evaluate the CDF.
        :return: CDF.
        """

    @abstractmethod
    def quantile(self, q: float) -> float:
        """
        Get the qth quantile.

        :param q: The quantile.
        :return: The value at quantile ``q``.
        """

    @abstractmethod
    def pdf(self, t: float | Sequence[float], **kwargs) -> float | np.ndarray:
        """
        Density function.

        :param t: Point or points at which to evaluate the density.
        :param kwargs: Additional keyword arguments.
        :return: Density.
        """

    def plot_cdf(
            self,
            ax: 'plt.Axes' = None,
            t: np.ndarray = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = 'Tree height CDF'
    ) -> 'plt.Axes':
        """
        Plot the cumulative distribution function.

        :param ax: Axes to plot on.
        :param t: Values to evaluate the CDF at. By default, 200 evenly spaced values between 0 and the
            99th percentile.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :return: Axes.
        """
        from .visualization import Visualization

        # fall back to an evenly spaced grid up to the 99th percentile
        grid = np.linspace(0, self.quantile(0.99), 200) if t is None else t

        return Visualization.plot(
            ax=ax,
            x=grid,
            y=self.cdf(grid),
            xlabel='t',
            ylabel='F(t)',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )

    def plot_pdf(
            self,
            ax: 'plt.Axes' = None,
            t: np.ndarray = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = 'Tree height PDF',
            dx: float = None
    ) -> 'plt.Axes':
        """
        Plot the density function.

        :param ax: The axes to plot on.
        :param t: Values to evaluate the density function at. By default, 200 evenly spaced values between 0
            and the 99th percentile.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :param dx: Step size for numerical differentiation. By default, the 99th percentile divided by 1e10.
        :return: Axes.
        """
        from .visualization import Visualization

        # default step size is derived from the 99th percentile
        step = self.quantile(0.99) / 1e10 if dx is None else dx

        # fall back to an evenly spaced grid up to the 99th percentile
        grid = np.linspace(0, self.quantile(0.99), 200) if t is None else t

        return Visualization.plot(
            ax=ax,
            x=grid,
            y=self.pdf(grid, dx=step),
            xlabel='t',
            ylabel='f(t)',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )
class PhaseTypeDistribution(MomentAwareDistribution):
    """
    Phase-type distribution for a piecewise time-homogeneous process.
    """

    def __init__(
            self,
            state_space: StateSpace,
            tree_height: 'TreeHeightDistribution',
            demography: Demography = None,
            reward: Reward = None
    ):
        """
        Initialize the distribution.

        :param state_space: The state space.
        :param tree_height: The tree height distribution.
        :param demography: The demography. By default, a default-constructed :class:`Demography`.
        :param reward: The reward. By default, the tree height reward.
        """
        if demography is None:
            demography = Demography()

        if reward is None:
            reward = TreeHeightReward()

        super().__init__()

        #: Population configuration
        self.lineage_config: LineageConfig = state_space.lineage_config

        #: Locus configuration
        self.locus_config: LocusConfig = state_space.locus_config

        #: Reward
        self.reward: Reward = reward

        #: State space
        self.state_space: StateSpace = state_space

        #: Demography
        self.demography: Demography = demography

        #: Tree height distribution
        self.tree_height: TreeHeightDistribution = tree_height

    @staticmethod
    def _get_van_loan_matrix(R: List[np.ndarray], S: np.ndarray, k: int = 1) -> np.ndarray:
        """
        Get the block matrix for the given reward matrices and transition matrix.

        The result has ``S`` on the block diagonal, ``R[i]`` on the block super-diagonal
        and zeros everywhere else.

        :param R: List of length k of reward matrices
        :param S: Intensity matrix
        :param k: The order of the moment.
        :return: Van Loan matrix which is a block matrix of size (k + 1) * (k + 1)
        """
        # matrix of zeros
        O = np.zeros_like(S)

        # create compound matrix
        return np.block([[S if i == j else R[i] if i == j - 1 else O for j in range(k + 1)] for i in range(k + 1)])

    @cached_property
    def mean(self) -> float | SFS:
        """
        First moment / mean.
        """
        return self.moment(k=1)

    @cached_property
    def var(self) -> float | SFS:
        """
        Second central moment / variance.
        """
        return self.moment(k=2, center=True)

    @cached_property
    def std(self) -> float | SFS:
        """
        Standard deviation.
        """
        return self.var ** 0.5

    @cached_property
    def m2(self) -> float | SFS:
        """
        Second (non-central) moment.
        """
        return self.moment(k=2, center=False)

    @cached_property
    def demes(self) -> MarginalDemeDistributions:
        """
        Marginal distributions over each deme.
        """
        return MarginalDemeDistributions(self)

    @cached_property
    def loci(self) -> MarginalLocusDistributions:
        """
        Marginal distributions over each locus.
        """
        return MarginalLocusDistributions(self)

    @_make_hashable
    @cache
    def moment(
            self,
            k: int,
            rewards: Sequence[Reward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> float:
        """
        Get the kth (cross-)moment, centered around the mean by default.

        :param k: The order of the moment.
        :param rewards: Iterable of k rewards. By default, the reward of the underlying distribution.
        :param start_time: Time when to start accumulation of moments. By default, the start time of the
            tree height distribution.
        :param end_time: Time when to end accumulation of moments. By default, ``t_max`` of the
            tree height distribution.
        :param center: Whether to center the moment around the mean (only has an effect for ``k > 1``).
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will
            be conditioned on the order of rewards.
        :return: The kth moment
        :raises ValueError: If the computed moment is NaN.
        """
        if start_time is None:
            start_time = self.tree_height.start_time

        if end_time is None:
            end_time = self.tree_height.t_max

        if start_time > 0:
            # moment accumulated in [start_time, end_time] as difference of two accumulations from 0
            m_start, m_end = PhaseTypeDistribution.accumulate(
                self,
                k=k,
                end_times=[start_time, end_time],
                rewards=rewards,
                center=center,
                permute=permute
            )

            m = float(m_end - m_start)
        else:
            m = float(PhaseTypeDistribution.accumulate(
                self,
                k=k,
                end_times=[end_time],
                rewards=rewards,
                center=center,
                permute=permute
            )[0])

        if np.isnan(m):
            raise ValueError(
                "NaN value encountered when computing moment. "
                "This is likely due to an ill-conditioned rate matrix."
            )

        return m

    @staticmethod
    def _get_regularization_factor(S: np.ndarray) -> float:
        """
        Get the regularization factor for the given intensity matrix. We multiply the intensity matrix by
        this factor to improve numerical stability when computing the matrix exponential of the Van Loan
        matrix. If regularization is disabled, or if the matrix has no positive rates, this factor is 1.

        :param S: Intensity matrix.
        :return: Regularization factor.
        """
        if not Settings.regularize:
            return 1.0

        # obtain positive rates
        rates = S[S > 0]

        # nothing to scale by; the mean of an empty array would be NaN
        if rates.size == 0:
            return 1.0

        # rewards in the Van Loan matrix are of order 1
        return 10 ** - np.log10(rates).mean()

    def _check_numerical_stability(self, S: np.ndarray, epoch: int):
        """
        Warn about potential numerical instability with very small or very large rates.

        :param S: (Regularized) intensity matrix.
        :param epoch: Epoch number.
        """
        rates = S[S > 0]

        # min / max are undefined for an empty array
        if rates.size == 0:
            return

        if rates.min() / rates.max() < 1e-10:
            self._logger.warning(
                f"Intensity matrix in epoch {epoch} contains rates that differ by more than 10 orders of magnitude: "
                f"min: {rates.min()}, max: {rates.max()}. "
                f"This may lead to numerical instability, despite matrix regularization."
            )

    def accumulate(
            self,
            k: int,
            end_times: Iterable[float],
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray:
        """
        Evaluate the kth moment at different end times.

        :param k: The order of the moment.
        :param end_times: End times at which to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean (only has an effect for ``k > 1``).
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will
            be conditioned on the order of rewards.
        :return: The moment accumulated at the specified times.
        :raises ValueError: If the number of rewards does not equal ``k``.
        """
        k = int(k)

        # materialize once: the times are iterated several times below,
        # which would silently exhaust a generator after the first pass
        end_times = tuple(end_times)

        if rewards is None:
            rewards = [self.reward] * k

        if k != len(rewards):
            raise ValueError(f"Number of specified rewards for moment of order {k} must be {k}.")

        if k == 0:
            return np.ones_like(list(end_times))

        # center moments around the mean
        if center and k > 1:
            components = []

            # first order moments
            means = [
                PhaseTypeDistribution.accumulate(
                    self,
                    k=1,
                    rewards=(rewards[i],),
                    end_times=end_times
                ) for i in range(k)
            ]

            for i in range(k + 1):

                # iterate over all possible subsets of rewards of size i
                for indices in itertools.combinations(range(k), i):
                    # joint moment
                    mu_i = PhaseTypeDistribution.accumulate(
                        self,
                        k=i,
                        rewards=tuple(rewards[j] for j in indices),
                        end_times=end_times,
                        center=False,
                        permute=permute
                    )

                    # product of means of remaining rewards
                    mu1 = np.prod([means[j] for j in range(k) if j not in indices], axis=0)

                    components += [(-1) ** (k - i) * mu_i * mu1]

            return np.sum(components, axis=0)

        if permute:
            # get all possible permutations of rewards
            permutations = list(itertools.permutations(rewards))

            # compute average over all permutations
            return np.sum([self._accumulate(k, end_times, r) for r in permutations], axis=0) / len(permutations)

        return self._accumulate(k, end_times, rewards)

    @_make_hashable
    @cache
    def _accumulate_flattened(
            self,
            k: int,
            end_times: Sequence[float],
            rewards: Sequence[Reward] = None
    ) -> np.ndarray:
        """
        Evaluate the kth (non-central) moment at different end times by delegating to the
        tree height distribution with a reward that is aggregated by number of lineages.

        :param k: The order of the moment.
        :param end_times: Sequence of end times when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :return: The moment accumulated at the specified times.
        :raises ValueError: If the state space is not a BlockCountingStateSpace, or if k is not 1, or if there
            are multiple populations or loci or if the coalescent model is not the standard coalescent.
        """
        if not isinstance(self.state_space, BlockCountingStateSpace):
            raise ValueError("Flattened accumulation is only supported for BlockCountingStateSpace.")

        if k != 1:
            raise ValueError("Flattened accumulation is only supported for k = 1.")

        if self.lineage_config.n_pops != 1 or self.locus_config.n != 1:
            raise ValueError("Flattened accumulation is only supported for a single population and a single locus.")

        if not isinstance(self.state_space.model, StandardCoalescent):
            raise ValueError("Flattened accumulation is only supported for standard coalescent.")

        reward = rewards[0] if rewards else self.reward
        r = reward._get(self.state_space)
        probs = self.state_space._state_probs

        # sum up weights for each state based on the number of lineages
        n = self.lineage_config.n
        weights = np.zeros(n)
        for i, s in enumerate(self.state_space.states):
            weights[n - s.lineages.sum()] += probs[i] * r[i]

        # Create a custom reward that returns the weights.
        weighted_reward = CustomReward(lambda _: weights)

        return self.tree_height._accumulate(k=k, end_times=end_times, rewards=(weighted_reward,))

    @_make_hashable
    @cache
    def _accumulate(
            self,
            k: int,
            end_times: Sequence[float],
            rewards: Sequence[Reward] = None
    ) -> np.ndarray:
        """
        Evaluate the kth (non-central) moment at different end times.

        :param k: The order of the moment.
        :param end_times: Sequence of end times when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :return: The moment accumulated at the specified times, in the order of ``end_times``.
        :raises ValueError: If an end time is negative or the number of rewards does not equal ``k``.
        """
        if (
                Settings.flatten_block_counting and
                k == 1 and
                isinstance(self.state_space, BlockCountingStateSpace) and
                isinstance(self.state_space.model, StandardCoalescent) and
                self.lineage_config.n_pops == 1 and
                self.locus_config.n == 1
        ):
            return self._accumulate_flattened(k, end_times, rewards)

        end_times = np.array(end_times)

        # check for negative values
        if np.any(end_times < 0):
            raise ValueError("Negative end times are not allowed.")

        # use default reward if not specified
        if rewards is None:
            rewards = (self.reward,) * k
        elif len(rewards) != k:
            raise ValueError(f"Number of rewards must be {k}.")

        # sort array in ascending order but keep track of original indices
        order = np.argsort(end_times)
        t_sorted: Collection[float] = end_times[order]

        epochs = enumerate(self.demography.epochs)
        i_epoch, epoch = next(epochs)

        # get state space for the first epoch
        self.state_space.update_epoch(epoch)

        # number of states
        n_states = self.state_space.k

        # initialize block matrix holding (rewarded) moments
        Q = np.eye(n_states * (k + 1))
        u_prev = 0

        # moments in ascending order of time
        moments_sorted = np.zeros_like(t_sorted, dtype=float)

        # regularization parameter
        lamb = self._get_regularization_factor(self.state_space.S)

        # regularized intensity matrix
        S = self.state_space.S * lamb

        # check numerical stability
        self._check_numerical_stability(S, i_epoch)

        # get reward matrix
        R = [np.diag(r._get(state_space=self.state_space)) for r in rewards]

        # get Van Loan matrix
        V = self._get_van_loan_matrix(S=S, R=R, k=k)

        # iterate through sorted values
        for i, u in enumerate(t_sorted):

            # iterate over epochs between u_prev and u
            while u > epoch.end_time:
                # update transition matrix with remaining time in current epoch
                Q @= expm(V * (epoch.end_time - u_prev) / lamb)

                # fetch and update for next epoch
                u_prev = epoch.end_time
                i_epoch, epoch = next(epochs)
                self.state_space.update_epoch(epoch)

                # compute Van Loan matrix for next epoch using regularized intensity matrix
                S = self.state_space.S * lamb

                # report the epoch we just entered rather than always epoch 0
                self._check_numerical_stability(S, i_epoch)

                V = self._get_van_loan_matrix(S=S, R=R, k=k)

            # update with remaining time in current epoch
            Q @= expm(V * (u - u_prev) / lamb)

            alpha = self.state_space.alpha
            e = self.state_space.e

            moments_sorted[i] = factorial(k) * lamb ** k * alpha @ Q[:n_states, -n_states:] @ e

            u_prev = u

        # scatter back to the caller's order: moments_sorted[i] belongs to end_times[order[i]]
        # (indexing with ``order`` again would apply the sort permutation twice instead of inverting it)
        moments = np.empty_like(moments_sorted)
        moments[order] = moments_sorted

        if np.isnan(moments).any():
            self._logger.warning(
                "NaN values encountered when computing moments. "
                f"Epoch: {i_epoch} at time: {epoch.start_time}. "
                "This is likely due to an ill-conditioned rate matrix."
            )

        return moments

    def _sample(
            self,
            n_samples: int,
            rewards: Sequence[Reward] = None,
            record_visits: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Generate samples from the mean reward distribution by simulating trajectories.

        :param n_samples: Number of trajectories to simulate.
        :param rewards: Rewards to sample from. By default, the reward of the underlying distribution.
        :param record_visits: Whether to additionally return how often each state was entered per sample.
        :return: Array of sampled rewards of size (n_samples, len(rewards)), and optionally an array holding
            the number of times each state was entered divided by ``n_samples``.
        """
        if rewards is None:
            rewards = [self.reward]

        n_rewards = len(rewards)
        samples = np.zeros((n_samples, n_rewards))
        absorbing = np.array([s.is_absorbing() for s in self.state_space.states])
        alpha = self.state_space.alpha
        states_visited = np.zeros_like(alpha)
        R = np.array([r._get(self.state_space) for r in rewards])

        # iterate over samples
        for i in tqdm(range(n_samples), disable=not Settings.use_pbar):
            mass = np.zeros(n_rewards)
            t = 0
            rate = 0
            state = np.random.choice(len(alpha), p=alpha)
            epochs = self.demography.epochs

            try:
                # find first non-zero rate epoch
                while rate == 0:
                    epoch = next(epochs)
                    self.state_space.update_epoch(epoch)
                    rate = -self.state_space.S[state, state]

                # sample next time step
                dt = np.random.exponential(1 / rate)

                # iterate over transitions
                while True:

                    # iterate over epochs
                    while t + dt >= epoch.end_time:
                        # reward until epoch boundary
                        mass += R[:, state] * (epoch.end_time - t)
                        dt -= (epoch.end_time - t)
                        t = epoch.end_time

                        # advance epoch
                        epoch = next(epochs)
                        self.state_space.update_epoch(epoch)
                        new_rate = -self.state_space.S[state, state]

                        if new_rate == 0:
                            t = epoch.end_time
                            continue

                        # rescale remaining time
                        dt *= rate / new_rate
                        rate = new_rate

                    # step completes in current epoch
                    mass += R[:, state] * dt
                    t += dt

                    # sample next state
                    probs = self.state_space.S[state].copy()
                    probs[state] = 0
                    state = np.random.choice(len(probs), p=probs / rate)
                    states_visited[state] += 1

                    # StopIteration doubles as the exit signal, also raised when epochs are exhausted
                    if absorbing[state]:
                        raise StopIteration

                    rate = -self.state_space.S[state, state]

                    # if rate is zero, we skip to the next epoch
                    if rate == 0:
                        t = epoch.end_time
                        continue

                    # sample next time step
                    dt = np.random.exponential(1 / rate)

            except StopIteration:
                pass

            samples[i] = mass

        # normalize states visited
        states_visited /= n_samples

        return (samples, states_visited) if record_visits else samples

    def plot_accumulation(
            self,
            k: int = 1,
            end_times: Iterable[float] = None,
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = None
    ) -> 'plt.Axes':
        """
        Plot accumulation of moments at different times.

        .. note::
            This is different from a CDF, as it shows the accumulation of moments rather than the probability
            of having reached absorption at a certain time.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0
            and the 99th percentile of the tree height.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will
            be conditioned on the order of rewards.
        :param ax: The axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot. By default, derived from the class names of the rewards.
        :return: Axes.
        """
        k = int(k)

        from .visualization import Visualization

        if end_times is None:
            end_times = np.linspace(0, self.tree_height.quantile(0.99), 200)
        elif not isinstance(end_times, np.ndarray):
            # materialize once as the times are used both for evaluation and as x-values
            end_times = np.array(tuple(end_times))

        if rewards is None:
            rewards = (self.reward,) * k

        if title is None:
            title = f"Moment accumulation ({', '.join(r.__class__.__name__.replace('Reward', '') for r in rewards)})"

        y = self.accumulate(k, end_times, rewards, center, permute)

        # hand the axes back to the caller as documented
        return Visualization.plot(
            ax=ax,
            x=end_times,
            y=y,
            xlabel='t',
            ylabel='moment',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )
[docs] class TreeHeightDistribution(PhaseTypeDistribution, DensityAwareDistribution): """ Phase-type distribution for a piecewise time-homogeneous process that allows the computation of the density function. This is currently only possible with default rewards. """ #: Maximum number of epochs to consider when determining time to almost sure absorption. max_epochs: int = 10000 #: Maximum number of time we double the end time when determining time to almost sure absorption. max_iter: int = 20 #: Probability of almost sure absorption. p_absorption: float = 1 - 1e-15
def __init__(
        self,
        state_space: LineageCountingStateSpace,
        demography: Demography = None,
        start_time: float = 0,
        end_time: float = None
):
    """
    Initialize the tree height distribution.

    :param state_space: The state space.
    :param demography: The demography.
    :param start_time: Time when to start accumulating moments.
    :param end_time: Time when to end accumulation of moments. By default, the time until
        almost sure absorption.
    :raises ValueError: If ``start_time`` or ``end_time`` is negative, or if ``end_time``
        is smaller than ``start_time``.
    """
    if start_time < 0:
        raise ValueError("Start time must be greater than or equal to 0.")

    if end_time is not None:
        if end_time < 0:
            raise ValueError("End time must be greater than or equal to 0.")

        if end_time < start_time:
            raise ValueError("End time must be greater than equal start time.")

    # the tree height distribution serves as its own tree height reference
    super().__init__(
        state_space=state_space,
        tree_height=self,
        demography=demography,
        reward=TreeHeightReward()
    )

    #: State space
    self.state_space: LineageCountingStateSpace = state_space

    #: Start time
    self.start_time: float = start_time

    #: End time
    self.end_time: float | None = end_time
def cdf(self, t: float | Sequence[float]) -> float | np.ndarray:
    """
    Cumulative distribution function.

    :param t: Value or values to evaluate the CDF at.
    :return: Cumulative probability, in the order of ``t``.
    :raises NotImplementedError: if rewards are not default
    :raises ValueError: if ``t`` contains negative values
    """
    # raise error if rewards are not default
    if not isinstance(self.reward, TreeHeightReward):
        raise NotImplementedError("CDF not implemented for non-default rewards.")

    # assume scalar if not array
    if not isinstance(t, Iterable):
        return self.cdf(np.array([t]))[0]

    # convert up front so that plain lists / tuples / generators support the comparisons below
    times = np.asarray(t if isinstance(t, np.ndarray) else list(t), dtype=float)

    # check for negative values
    if np.any(times < 0):
        raise ValueError("Negative values are not allowed.")

    # sort array in ascending order but keep track of original indices
    order = np.argsort(times)
    t_sorted: Collection[float] = times[order]

    epochs = enumerate(self.demography.epochs)
    i_epoch, epoch = next(epochs)

    # get the transition matrix for the first epoch
    self.state_space.update_epoch(epoch)

    # initialize transition matrix
    T = np.eye(self.state_space.k)
    u_prev = 0

    # probabilities in ascending order of time
    probs_sorted = np.zeros_like(t_sorted)

    # take reward vector as exit vector
    e = self.reward._get(self.state_space)

    # iterate through sorted values
    for i, u in enumerate(t_sorted):

        # iterate over epochs between u_prev and u
        while u > epoch.end_time:
            self._check_numerical_stability(self.state_space.S, i_epoch)

            # update transition matrix with remaining time in current epoch
            T @= expm(self.state_space.S * (epoch.end_time - u_prev))

            # fetch and update for next epoch
            u_prev = epoch.end_time
            i_epoch, epoch = next(epochs)
            self.state_space.update_epoch(epoch)

        self._check_numerical_stability(self.state_space.S, i_epoch)

        # update transition matrix with remaining time in current epoch
        T @= expm(self.state_space.S * (u - u_prev))

        probs_sorted[i] = 1 - self.state_space.alpha @ T @ e

        u_prev = u

    # scatter back to the caller's order: probs_sorted[i] belongs to times[order[i]]
    # (indexing with ``order`` again would apply the sort permutation twice instead of inverting it)
    probs = np.empty_like(probs_sorted)
    probs[order] = probs_sorted

    if np.isnan(probs).any():
        self._logger.critical(
            "NaN values in CDF. This is likely due to an ill-conditioned rate matrix."
        )

    return probs
def _update(
        self,
        u: float,
        u_prev: float,
        T: np.ndarray,
        epoch: 'Epoch'
) -> Tuple[float, np.ndarray, 'Epoch']:
    """
    Propagate a transition matrix from time ``u_prev`` to time ``u``.

    The state space is switched to ``epoch`` first. Every epoch that ends before ``u``
    contributes the matrix exponential over its remaining duration, after which the
    following epoch is fetched from the demography and loaded into the state space.

    :param u: Time to update to.
    :param u_prev: Time the given transition matrix corresponds to.
    :param T: Transition matrix at time ``u_prev``.
    :param epoch: Epoch containing ``u_prev``.
    :return: The time ``u``, the transition matrix at ``u`` and the epoch reached.
    """
    self.state_space.update_epoch(epoch)

    # consume all epochs that end before the target time
    while u > epoch.end_time:
        boundary = epoch.end_time

        T = T @ expm(self.state_space.S * (boundary - u_prev))
        u_prev = boundary

        # move on to the epoch starting at the boundary
        epoch = self.demography.get_epoch(boundary)
        self.state_space.update_epoch(epoch)

    # remaining stretch within the epoch that contains u
    T = T @ expm(self.state_space.S * (u - u_prev))

    return u, T, epoch

@cached_property
def _e(self) -> np.ndarray:
    """
    Exit vector, i.e. the reward vector of this distribution's reward for its state space.
    """
    return self.reward._get(self.state_space)

def _cum(self, T: np.ndarray) -> float:
    """
    Cumulative probability implied by a transition matrix.

    :param T: Transition matrix.
    :return: One minus the product of initial distribution, ``T`` and exit vector.
    """
    remaining = self.state_space.alpha @ T @ self._e

    return float(1 - remaining)
@cache
def quantile(
        self,
        q: float,
        expansion_factor: float = 2,
        precision: float = 1e-5,
        max_iter: int = 1000
) -> float:
    """
    Find the specified quantile of a CDF using an adaptive bisection method.

    The upper bound is first multiplied by ``expansion_factor`` until its cumulative
    probability reaches ``q``; the bracket is then bisected until the cumulative
    probabilities at its two ends differ by at most ``precision``. Both phases share the
    same iteration budget. Results are memoized by :func:`functools.cache`.

    :param q: The desired quantile (between 0 and 1).
    :param expansion_factor: Factor by which to multiply the upper bound that does not yet contain the quantile.
    :param precision: Precision for quantile, i.e. ``F(b) - F(a) < precision``.
    :param max_iter: Maximum number of iterations.
    :return: The quantile.
    :raises ValueError: If ``q`` is outside ``[0, 1]`` or ``expansion_factor`` is not greater than 1.
    :raises RuntimeError: If the iteration budget is exhausted before the bracket contains
        the quantile or before it is narrowed down to the requested precision.
    """
    if q < 0 or q > 1:
        raise ValueError("Specified quantile must be between 0 and 1.")

    if expansion_factor <= 1:
        raise ValueError("Expansion factor must be greater than 1.")

    # initialize bounds
    a, b = 0, 1
    T_a = np.eye(self.state_space.k)
    epoch_a, epoch_b = self.demography.get_epoch(0), self.demography.get_epoch(0)
    b, T_b, epoch_b = self._update(b, a, T_a, epoch_b)

    # cumulative probabilities at the bounds, kept in sync with the transition matrices
    F_a, F_b = self._cum(T_a), self._cum(T_b)

    i = 0

    # expand upper bound until it contains the quantile
    while F_b < q and i < max_iter:
        b, T_b, epoch_b = self._update(b * expansion_factor, b, T_b, epoch_b)
        F_b = self._cum(T_b)

        i += 1

    # use bisection method within the determined bounds
    while F_b - F_a > precision and i < max_iter:
        m, T_m, epoch_m = self._update((a + b) / 2, a, T_a, epoch_a)
        F_m = self._cum(T_m)

        if F_m < q:
            a, T_a, epoch_a, F_a = m, T_m, epoch_m, F_m
        else:
            b, T_b, epoch_b, F_b = m, T_m, epoch_m, F_m

        i += 1

    # both loops stop at i == max_iter at the latest, so test the convergence criteria
    # themselves rather than the counter to detect an exhausted iteration budget
    if F_b < q or F_b - F_a > precision:
        raise RuntimeError("Maximum number of iterations reached when determining quantile.")

    return (a + b) / 2
[docs] def pdf(self, t: float | Sequence[float], dx: float = None) -> float | np.ndarray: """ Density function. We use numerical differentiation of the CDF to calculate the density. This provides good results as the CDF is exact and continuous. :param t: Value or values to evaluate the density function at. :param dx: Step size for numerical differentiation. By default, the 99th percentile divided by 1e10. :return: Density """ if dx is None: dx = self.quantile(0.99) / 1e10 if isinstance(t, Iterable): t = np.array(t) # determine (non-negative) evaluation points x1 = np.max([t - dx / 2, np.zeros_like(t)], axis=0) x2 = x1 + dx return (self.cdf(x2) - self.cdf(x1)) / dx
@cached_property
def t_max(self) -> float:
    """
    Time until which computations are performed. This is either the end time specified when
    initializing the distribution or the time until almost sure absorption.

    :raises ValueError: If the determined absorption time is smaller than the start time.
    """
    if self.end_time is not None:
        return self.end_time

    t_abs = self._get_absorption_time()

    if t_abs < self.start_time:
        raise ValueError(
            f"Determined time of almost sure absorption ({t_abs:.1f}) "
            f"is smaller than start time ({self.start_time:.1f}). "
            "The start time may be too large or the demography not well-defined."
        )

    return t_abs

def _get_absorption_time(self) -> float:
    """
    Get a time estimate for when we have reached absorption almost surely.
    We base this computation on the transition matrix rather than the moments, because here we
    have a good idea about how likely absorption is, and can warn the user if necessary.
    Stopping the computation when no more rewards are accumulated is not a good idea, as this can
    happen before almost sure absorption (exponential runaway growth, temporary isolation in different demes).

    :return: The first time tried at which the cumulative probability reaches ``self.p_absorption``,
        or the last time tried if this did not happen within ``self.max_iter`` expansions.
    """
    i = 0
    T = np.eye(self.state_space.k)
    epoch = self.demography.get_epoch(0)

    # start from a power of two derived from the mean population size of the first epoch
    t = 2 ** int(np.log2(np.mean(list(epoch.pop_sizes.values()))))
    expansion_factor = 2

    t, T, epoch = self._update(t, 0, T, epoch)
    p = self._cum(T)

    # multiply time by expansion_factor until we reach p_absorption
    while p < self.p_absorption and i < self.max_iter:
        t, T, epoch = self._update(t * expansion_factor, t, T, epoch)
        p = self._cum(T)

        i += 1

    # a NaN probability ends the loop above immediately (comparisons with NaN are false),
    # so checking once here also covers a NaN produced by the very first step
    if np.isnan(p):
        self._logger.critical(
            "Could not reliably find time of almost sure absorption "
            "as probability of absorption is NaN. "
            "This is likely due to an ill-conditioned rate matrix. "
            f"Using time {t:.1f}. "
        )

    # the loop can only end with p below the threshold when the iteration budget is exhausted
    if i >= self.max_iter and p < self.p_absorption:
        self._logger.warning(
            "Could not reliably find time of almost sure absorption after maximum number of iterations. "
            f"Using time {t:.1f} with probability of absorption 1 - {1 - p:.1e}. "
            "This could be due to numerical imprecision, unreachable states or very large or small "
            "absorption times. You can set the end time manually (see `Coalescent.end_time`) or increase "
            "the maximum number of iterations (`TreeHeightDistribution.max_iter`)."
        )

    return t

def _empirical_cdf(self, n_samples: int, reward: Reward = None, t: float | Sequence[float] = None) -> np.ndarray:
    """
    Generate an empirical cumulative distribution function (CDF) by sampling from the distribution.

    :param n_samples: Number of samples to generate.
    :param reward: Reward function to use for sampling. If not specified, the default reward of the
        distribution is used.
    :param t: Values at which to evaluate the CDF. Default to 100 evenly spaced values between 0 and
        the 99th percentile.
    :return: The empirical CDF, linearly interpolated at ``t``.
    """
    if t is None:
        t = np.linspace(0, self.tree_height.quantile(0.99), 100)

    # flattened to one dimension, so no further dimensionality check is needed
    samples = self._sample(n_samples, reward).reshape(n_samples)

    x = np.sort(samples)
    y = np.arange(1, n_samples + 1) / n_samples

    return np.interp(t, x, y)

def _plot_empirical_cdf(
        self,
        n_samples: int = 1000,
        reward: Reward = None,
        t: float | Sequence[float] = None,
        ax: 'plt.Axes' = None,
        show: bool = True,
        file: str = None,
        clear: bool = True,
        label: str = None,
        title: str = 'Empirical CDF'
) -> 'plt.Axes':
    """
    Plot the empirical cumulative distribution function (CDF).

    :param n_samples: Number of samples to generate.
    :param reward: Reward function to use for sampling. If not specified, the default reward of the
        distribution is used.
    :param t: Values at which to evaluate the CDF. Default to 100 evenly spaced values between 0 and
        the 99th percentile.
    :param ax: Axes to plot on.
    :param show: Whether to show the plot.
    :param file: File to save the plot to.
    :param clear: Whether to clear the plot before plotting.
    :param label: Label for the plot.
    :param title: Title of the plot.
    :return: Axes.
    """
    from .visualization import Visualization

    # resolve the default here as well, since the same values are used as x-axis
    if t is None:
        t = np.linspace(0, self.tree_height.quantile(0.99), 100)

    y = self._empirical_cdf(n_samples, reward, t)

    return Visualization.plot(
        ax=ax,
        x=t,
        y=y,
        xlabel='t',
        ylabel='F(t)',
        label=label,
        file=file,
        show=show,
        clear=clear,
        title=title
    )
class SFSDistribution(PhaseTypeDistribution, ABC):
    """
    Base class for site-frequency spectrum distributions.
    """

    def __init__(
            self,
            state_space: BlockCountingStateSpace,
            tree_height: TreeHeightDistribution,
            demography: Demography,
            reward: Reward = None
    ):
        """
        Initialize the distribution.

        :param state_space: Block-counting state space.
        :param tree_height: The tree height distribution.
        :param demography: The demography.
        :param reward: The reward to multiply the SFS reward with. By default, the unit reward is used,
            which has no effect.
        """
        if reward is None:
            reward = UnitReward()

        super().__init__(
            state_space=state_space,
            tree_height=tree_height,
            demography=demography,
            reward=reward
        )

        #: Generated probability mass by iterator returned from :meth:`get_mutation_configs`.
        self.generated_mass = 0

    @abstractmethod
    def _get_sfs_reward(self, i: int) -> SFSReward:
        """
        Get the reward for the ith site-frequency count.

        :param i: The ith site-frequency count.
        :return: The reward.
        """
        pass

    @abstractmethod
    def _get_indices(self) -> np.ndarray:
        """
        Get the indices for the site-frequency spectrum.

        :return: The indices.
        """
        pass

    @staticmethod
    @abstractmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Get all possible mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: All possible mutational configurations.
        """
        pass

    @_make_hashable
    @cache
    def moment(
            self,
            k: int,
            rewards: Sequence[SFSReward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> SFS:
        """
        Get the kth moments of the site-frequency spectrum.

        :param k: The order of the moment
        :param rewards: Sequence of k rewards
        :param start_time: Time when to start accumulation of moments. By default, the start time specified when
            initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified when
            initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned
            on the order of rewards.
        :return: A site-frequency spectrum of kth order moments.
        """
        if rewards is None:
            rewards = (self.reward,) * k

        # optionally parallelize the moment computation of the SFS bins
        moments = parallelize(
            func=lambda x: self._moment(*x),
            data=[[k, i, rewards, start_time, end_time, center, permute] for i in self._get_indices()],
            desc=f"Calculating {k}-moments",
            pbar=Settings.use_pbar,
            parallelize=Settings.parallelize
        )

        # pad with zeros so that the spectrum has n + 1 entries
        return SFS([0] + list(moments) + [0] * (self.lineage_config.n - len(moments)))

    def _moment(
            self,
            k: int,
            i: int,
            rewards: Sequence[SFSReward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> float:
        """
        Get the kth moment for the ith site-frequency count.

        :param k: The order of the moment
        :param i: The ith site-frequency count
        :param rewards: Sequence of k rewards
        :param start_time: Time when to start accumulation of moments. By default, the start time specified when
            initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified when
            initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned
            on the order of rewards.
        :return: The kth SFS (cross)-moment at the ith site-frequency count
        """
        # each reward is combined with the reward of the ith frequency count
        return PhaseTypeDistribution.moment(
            self,
            k=k,
            rewards=tuple([CombinedReward([r, self._get_sfs_reward(i)]) for r in rewards]),
            start_time=start_time,
            end_time=end_time,
            center=center,
            permute=permute
        )

    def accumulate(
            self,
            k: int,
            end_times: Iterable[float],
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray:
        """
        Evaluate the kth moments for the site-frequency spectrum at different end times.

        :param k: The order of the moment.
        :param end_times: Times or time when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned
            on the order of rewards.
        :return: Array of moments accumulated at the specified times, one row for each site-frequency count.
        """
        k = int(k)

        indices = self._get_indices()
        end_times = np.array(list(end_times))

        # forward `center` and `permute` so that they are not silently replaced by the defaults
        accumulation = parallelize(
            func=lambda x: self.get_accumulation(*x),
            data=[[k, i, end_times, rewards, center, permute] for i in indices],
            desc=f"Calculating accumulation of {k}-moments",
            pbar=Settings.use_pbar,
            parallelize=Settings.parallelize
        )

        # pad with zeros
        return np.concatenate([
            np.zeros((1, len(end_times))),
            accumulation,
            np.zeros((self.lineage_config.n - len(indices), len(end_times)))
        ])

    def plot_accumulation(
            self,
            k: int = 1,
            end_times: Iterable[float] = None,
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = None
    ) -> 'plt.Axes':
        """
        Plot accumulation of SFS moments at different times.

        .. note:: This is different from a CDF, as it shows the accumulation of moments rather than the
            probability of having reached absorption at a certain time.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0 and
            the 99th percentile.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned
            on the order of rewards.
        :param ax: The axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot. Currently unused, each curve is labelled with its frequency count.
        :param title: Title of the plot.
        :return: Axes.
        """
        import matplotlib.pyplot as plt
        from .visualization import Visualization

        k = int(k)

        if ax is None:
            ax = plt.gca()

        if end_times is None:
            end_times = np.linspace(0, self.tree_height.quantile(0.99), 200)

        if rewards is None:
            rewards = (self.reward,) * k

        if title is None:
            title = (f"SFS Moment accumulation "
                     f"({', '.join(r.__class__.__name__.replace('Reward', '') for r in rewards)})")

        # get accumulation of moments
        accumulation = self.accumulate(k, end_times, rewards, center, permute)

        indices = self._get_indices()

        # skip the zero-padded first row; zip stops after the last frequency count
        for i, acc in zip(indices, accumulation[1: -1]):
            Visualization.plot(
                ax=ax,
                x=end_times,
                y=acc,
                xlabel='t',
                ylabel='moment',
                label=f'{i}',
                file=file,
                # only show the figure once all curves have been drawn
                show=i == indices[-1] and show,
                clear=clear,
                title=title
            )

        return ax

    def get_accumulation(
            self,
            k: int,
            i: int,
            end_times: Iterable[float] | float,
            rewards: Sequence[SFSReward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray | float:
        """
        Get accumulation of moments for the ith site-frequency count.

        :param k: The order of the moment
        :param i: The ith site-frequency count.
        :param end_times: Times or time when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned
            on the order of rewards.
        :return: The kth SFS (cross)-moment accumulations at the ith site-frequency count
        """
        if rewards is None:
            rewards = [self.reward] * k

        return super().accumulate(
            k=k,
            end_times=end_times,
            rewards=tuple([CombinedReward([r, self._get_sfs_reward(i)]) for r in rewards]),
            center=center,
            permute=permute
        )

    @cached_property
    def cov(self) -> SFS2:
        """
        Covariance matrix across site-frequency counts.
        """
        # create list of arguments for each combination of i, j
        indices = [(i, j) for i in self._get_indices() for j in self._get_indices()]

        # get sfs using parallelized function
        sfs_results = parallelize(
            func=lambda x: (
                PhaseTypeDistribution.moment(self, k=2, permute=False, center=False, rewards=(
                    CombinedReward([self.reward, self._get_sfs_reward(x[0])]),
                    CombinedReward([self.reward, self._get_sfs_reward(x[1])])
                ))
            ),
            data=indices,
            desc="Calculating covariance",
            pbar=Settings.use_pbar,
            parallelize=Settings.parallelize
        )

        # re-structure the results to a matrix form
        sfs = np.zeros((self.lineage_config.n + 1, self.lineage_config.n + 1))
        for ((i, j), result) in zip(indices, sfs_results):
            sfs[i, j] = result

        # get matrix of marginal moments
        m2 = np.outer(self.mean.data, self.mean.data)

        # calculate covariances, averaging over both orders of the two rewards
        cov = (sfs + sfs.T) / 2 - m2

        return SFS2(cov)

    def get_cov(self, i: int, j: int) -> float:
        """
        Get the covariance between the ith and jth site-frequency.

        :param i: The ith frequency count
        :param j: The jth frequency count
        :return: covariance
        """
        # the first and last frequency counts are reported as having zero covariance
        if i in (0, self.lineage_config.n) or j in (0, self.lineage_config.n):
            return 0

        return super().moment(
            k=2,
            rewards=(
                CombinedReward([self.reward, self._get_sfs_reward(i)]),
                CombinedReward([self.reward, self._get_sfs_reward(j)])
            ),
            center=True
        )

    @cached_property
    def corr(self) -> SFS2:
        """
        Correlation matrix across site-frequency counts.
        """
        # get standard deviations
        std = np.sqrt(self.var.data)

        sfs = SFS2(self.cov.data / np.outer(std, std))

        # replace NaNs with zeros
        sfs.data[np.isnan(sfs.data)] = 0

        return sfs

    def get_corr(self, i: int, j: int) -> float:
        """
        Get the correlation coefficient between the ith and jth site-frequency.

        :param i: The ith frequency count
        :param j: The jth frequency count
        :return: Correlation coefficient
        """
        if i in (0, self.lineage_config.n) or j in (0, self.lineage_config.n):
            return 0

        return self.get_cov(i, j) / (np.sqrt(self.get_cov(i, i)) * np.sqrt(self.get_cov(j, j)))

    @cache
    def _get_P(self, n: int, theta: float) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get transition matrix for mutational configuration probabilities.

        :param n: The number of frequency bins.
        :param theta: The mutation rate.
        :return: Transition matrices, one per frequency bin, and exit vector.
        """
        # get non-absorbing states
        non_absorbing = TreeHeightReward()._get(self.state_space).astype(bool)

        e = self.state_space.e[non_absorbing]

        # reward vectors of the frequency bins, restricted to the non-absorbing states
        R = np.array([self._get_sfs_reward(i)._get(self.state_space) for i in range(1, n + 1)])[:, non_absorbing]

        # total reward per state, summed over all frequency bins
        r_total = R.T @ np.ones(n)

        S = self.state_space.S[non_absorbing, :][:, non_absorbing]
        I = np.eye(S.shape[0])

        P_total = np.linalg.inv(I - np.diag(1 / r_total) / theta @ S)
        p_total = (I - P_total) @ e

        P = np.array([P_total @ np.diag(R[i] / r_total) for i in range(n)])

        return P, p_total

    def get_mutation_config(self, config: Sequence[int], theta: float) -> float:
        """
        Get the probabilities of observing the given mutational configurations according to the
        infinite sites model.

        .. note:: This currently only works for a single epoch, i.e. a time-homogeneous demography, and
            recombination is not supported.

        :param config: The mutational configuration. A sequence of integers of length n - 1 for unfolded
            configurations and n // 2 for folded configurations, where n is the number of lineages. Each element
            in the sequence is an integer representing the number of mutations at each frequency count starting
            from 1. For example, the unfolded configuration [2, 1, 0] represents two singleton, one doubleton and
            zero tripleton mutations for a sample size of 4 lineages. Similarly, the folded configuration [2, 1]
            represents two singleton or tripleton and one doubleton mutation for the same number of lineages.
        :param theta: The mutation rate.
        :return: The probability of observing the given mutational configuration.
        :raises NotImplementedError: If the demography has more than one epoch.
        :raises ValueError: If ``theta`` is negative or ``config`` has the wrong length.
        """
        # raise not implemented error if more than one epoch
        if self.demography.has_n_epochs(2):
            raise NotImplementedError("Sampling not implemented for more than one epoch.")

        # make sure theta is non-negative
        if theta < 0:
            raise ValueError("Theta must be greater than or equal to 0.")

        # number of frequency bins
        n = len(self._get_configs(self.lineage_config.n, 0)[0])

        if len(config) != n:
            raise ValueError(
                "The length of the configuration must be equal to the number of frequency bins. "
                f"Expected {n}, got {len(config)}."
            )

        # explicitly convert to tuple of integers
        config = tuple(int(c) for c in config)

        # handle special case when theta = 0
        if theta == 0:
            if sum(config) == 0:
                return 1

            return 0

        # get non-absorbing states
        non_absorbing = TreeHeightReward()._get(self.state_space).astype(bool)

        # number of non-absorbing states
        k = non_absorbing.sum()

        alpha = self.state_space.alpha[non_absorbing]

        P, p_total = self._get_P(n, theta)

        # one entry per mutation, holding the frequency count it falls on
        q = list(itertools.chain(*[[i + 1] * j for i, j in enumerate(config)]))

        # iterate over permutations of q
        Q = np.zeros((k, k))
        for perm in multiset_permutations(q):
            U = np.eye(k)
            for i in perm:
                U @= P[i - 1]

            Q += U

        return alpha @ Q @ p_total

    def get_mutation_configs(self, theta: float) -> Iterator[Tuple[Tuple[int, ...], float]]:
        """
        An iterator over the probabilities of observing mutational configurations according to the
        infinite sites model. The order of the mutational configurations generated ascends in the number of
        mutations observed. See :meth:`get_mutation_config` for more information on mutational configurations.

        .. note:: This currently only works for a single epoch, i.e. a time-homogeneous demography, and
            recombination is not supported. Also note that the number of configurations is infinite, so this
            iterator will never stop. However, depending on the mutation rate, the probability of observing
            configurations of higher mutation counts will decrease over time. You can keep track of the generated
            probability mass by checking the :attr:`~.generated_mass` attribute, which is reset every time this
            method is called. A good approach is thus to keep generating configurations until the generated mass
            is above a certain threshold. More complex demographic models and larger sample sizes increase the
            number of configurations and higher mutation rates, the number of generated configurations necessary
            to reach a certain mass.

        Code example:

        ::

            coal = pg.Coalescent(n=5)
            it = coal.sfs.get_mutation_configs(theta=1)

            # continue until generated mass is above 0.8
            samples = list(pg.takewhile_inclusive(lambda _: coal.sfs.generated_mass < 0.8, it))

        :param theta: The mutation rate.
        :return: An iterator over pairs of mutational configuration and its probability.
        """
        # reset generated mass
        self.generated_mass = 0

        # iterate over number of mutations
        i = 0
        while True:

            # iterate over configurations
            for config in self._get_configs(self.lineage_config.n, i):
                p = self.get_mutation_config(config=config, theta=theta)
                self.generated_mass += p

                yield config, p

            # increase counter for number of mutations
            i += 1
class UnfoldedSFSDistribution(SFSDistribution):
    """
    Site-frequency spectrum distribution in which every frequency count has its own bin.
    """

    def _get_sfs_reward(self, i: int) -> UnfoldedSFSReward:
        """
        Build the unfolded SFS reward for a frequency count.

        :param i: The site-frequency count.
        :return: The reward for that count.
        """
        reward = UnfoldedSFSReward(i)

        return reward

    def _get_indices(self) -> np.ndarray:
        """
        Frequency counts covered by the spectrum.

        :return: The counts from 1 up to, but excluding, the number of lineages.
        """
        n_lineages = self.lineage_config.n

        return np.arange(1, n_lineages)

    @staticmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Enumerate the mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: The partitions of ``k`` mutations over the ``n - 1`` frequency bins.
        """
        n_bins = n - 1

        return StateSpace._get_partitions(n=k, k=n_bins)
class FoldedSFSDistribution(SFSDistribution):
    """
    Site-frequency spectrum distribution with folded frequency counts.
    """

    def _get_sfs_reward(self, i: int) -> FoldedSFSReward:
        """
        Build the folded SFS reward for a frequency count.

        :param i: The site-frequency count.
        :return: The reward for that count.
        """
        reward = FoldedSFSReward(i)

        return reward

    def _get_indices(self) -> np.ndarray:
        """
        Frequency counts covered by the spectrum.

        :return: The counts from 1 up to and including half the number of lineages (rounded down).
        """
        n_bins = self.lineage_config.n // 2

        return np.arange(1, n_bins + 1)

    @staticmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Enumerate the mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: The partitions of ``k`` mutations over the ``n // 2`` frequency bins.
        """
        n_bins = n // 2

        return StateSpace._get_partitions(n=k, k=n_bins)

    def _unfold(self, config: Sequence[int]) -> Set[Tuple[int, ...]]:
        """
        Expand a folded configuration into every unfolded configuration compatible with it.

        Each folded bin that pairs a low with a high frequency count has its mutations split
        in all possible ways between the two. For an even number of lineages the last
        folded bin has no counterpart and is carried over unchanged.

        :param config: The folded configuration. A sequence of integers of length n // 2 where n is the number
            of lineages.
        :return: The unfolded configurations.
        :raises ValueError: If the configuration does not have length n // 2.
        """
        n = self.lineage_config.n

        if len(config) != n // 2:
            raise ValueError("The length of the configuration must equal n // 2 where n is the number of lineages.")

        is_odd = n % 2 == 1

        # number of folded bins that have a distinct high-frequency counterpart
        n_paired = len(config) if is_odd else len(config) - 1

        # admissible numbers of mutations on the low-frequency side of each bin
        choices = [range(c + 1) for c in config[:n_paired]]

        if not is_odd:
            # the central bin is not split
            choices.append([config[-1]])

        folded = np.array(config)
        result = set()

        # iterate over unfolded configurations
        for lower in itertools.product(*choices):
            # what is left of each paired bin goes to the mirrored high-frequency count
            higher = (folded - np.array(lower))[:n_paired][::-1]

            result.add(tuple(list(lower) + list(higher)))

        return result
class EmpiricalDistribution(DensityAwareDistribution):  # pragma: no cover
    """
    Probability distribution based on realisations.
    """

    def __init__(self, samples: np.ndarray | list):
        """
        Create object.

        :param samples: Array of samples. Either 1-D, or 2-D with one row per realisation and one
            column per variable (the moments below reduce over axis 0).
        """
        super().__init__()

        self._cache = None

        #: Samples
        self.samples = np.array(samples, dtype=float)

    def touch(self, t: np.ndarray):
        """
        Touch all cached properties.

        :param t: Times to cache properties for.
        """
        super().touch()

        # store CDF and PDF at the given times so they survive :meth:`drop`
        self._cache = dict(
            t=t,
            cdf=self.cdf(t),
            pdf=self.pdf(t)
        )

    def drop(self):
        """
        Drop simulated samples.
        """
        self.samples = None

    @cached_property
    def mean(self) -> float | np.ndarray:
        """
        First moment / mean.
        """
        return np.mean(self.samples, axis=0)

    @cached_property
    def var(self) -> float | np.ndarray:
        """
        Second central moment / variance.
        """
        return np.var(self.samples, axis=0)

    @cached_property
    def m2(self) -> float | np.ndarray:
        """
        Second non-central moment.
        """
        return np.mean(self.samples ** 2, axis=0)

    @cached_property
    def m3(self) -> float | np.ndarray:
        """
        Third non-central moment.
        """
        return np.mean(self.samples ** 3, axis=0)

    @cached_property
    def m4(self) -> float | np.ndarray:
        """
        Fourth non-central moment.
        """
        return np.mean(self.samples ** 4, axis=0)

    @cached_property
    def cov(self) -> float | np.ndarray:
        """
        Covariance matrix (NaN entries are replaced by 0).
        """
        return np.nan_to_num(np.cov(self.samples, rowvar=False))

    @cached_property
    def corr(self) -> float | np.ndarray:
        """
        Correlation matrix (NaN entries are replaced by 0).
        """
        return np.nan_to_num(np.corrcoef(self.samples, rowvar=False))

    def moment(self, k: int) -> float | np.ndarray:
        """
        Get the kth (non-central) moment.

        :param k: The order of the moment
        :return: The kth moment
        """
        return np.mean(self.samples ** k, axis=0)

    def cdf(self, t: float | Sequence[float]) -> float | np.ndarray:
        """
        Empirical cumulative distribution function, linearly interpolated between the samples.

        :param t: Time.
        :return: Cumulative probability. For 2-D samples, one row per variable (column of the samples).
        :raises ValueError: If the samples are neither 1 nor 2 dimensional.
        """
        # sort along the realisation axis: the default (last axis) would sort across the
        # variables of 2-D samples and leave the per-variable columns unsorted for np.interp
        x = np.sort(self.samples, axis=0)
        y = np.arange(1, len(self.samples) + 1) / len(self.samples)

        if x.ndim == 1:
            return np.interp(t, x, y)

        if x.ndim == 2:
            return np.array([np.interp(t, x_, y) for x_ in x.T])

        raise ValueError("Samples must be 1 or 2 dimensional.")

    def quantile(self, q: float) -> float:
        """
        Get the qth quantile.

        :param q: Quantile.
        :return: Quantile.
        """
        return np.quantile(self.samples, q=q)

    def pdf(
            self,
            t: float | np.ndarray,
            n_bins: int = 10000,
            sigma: float = None,
            samples: np.ndarray = None
    ) -> float | np.ndarray:
        """
        Density function, estimated from a histogram of the samples.

        :param t: Time. The bin lookup indexes ``t[:, None]``, so an array is expected.
        :param n_bins: Number of bins.
        :param sigma: Sigma for Gaussian filter. No smoothing is applied if ``None``.
        :param samples: Samples. By default, the samples of this distribution.
        :return: Density. For 2-D samples, one row per variable (column of the samples).
        """
        samples = self.samples if samples is None else samples

        # recurse over the columns of 2-D samples
        if samples.ndim == 2:
            return np.array([self.pdf(t, n_bins=n_bins, sigma=sigma, samples=s) for s in samples.T])

        hist, bin_edges = np.histogram(samples, range=(0, max(samples)), bins=n_bins, density=True)

        # determine bins for t, clamping to the last bin
        bins = np.minimum(np.sum(bin_edges <= t[:, None], axis=1) - 1, np.full_like(t, n_bins - 1, dtype=int))

        # use proper bins for y values
        y = hist[bins]

        # smooth using gaussian filter
        if sigma is not None:
            y = gaussian_filter1d(y, sigma=sigma)

        return y


class EmpiricalSFSDistribution(EmpiricalDistribution):  # pragma: no cover
    """
    SFS probability distribution based on realisations.
    """

    def __init__(self, samples: np.ndarray | list):
        """
        Create object.

        :param samples: 2-D array of samples.
        """
        super().__init__(samples)

    @cached_property
    def mean(self) -> SFS:
        """
        First moment / mean.
        """
        return SFS(super().mean)

    @cached_property
    def var(self) -> SFS:
        """
        Second central moment / variance.
        """
        return SFS(super().var)

    @cached_property
    def m2(self) -> SFS:
        """
        Second non-central moment.
        """
        return SFS(super().m2)

    @cached_property
    def cov(self) -> SFS2:
        """
        Covariance matrix (NaN entries are replaced by 0).
        """
        return SFS2(np.nan_to_num(np.cov(self.samples, rowvar=False)))

    @cached_property
    def corr(self) -> SFS2:
        """
        Correlation matrix (NaN entries are replaced by 0).
        """
        return SFS2(np.nan_to_num(np.corrcoef(self.samples, rowvar=False)))


class DictContainer(dict):  # pragma: no cover
    """
    Dictionary container. Unlike a plain ``dict``, instances accept extra attributes.
    """
    pass


class EmpiricalPhaseTypeDistribution(EmpiricalDistribution):  # pragma: no cover
    """
    Phase-type distribution based on realisations.
    """

    def __init__(
            self,
            samples: np.ndarray | list,
            pops: List[str],
            locus_agg: Callable = lambda x: x.sum(axis=0)
    ):
        """
        Create object.

        :param samples: 3-D array of samples.
        :param pops: List of population names.
        :param locus_agg: Aggregation function for loci.
        """
        over_loci = locus_agg(samples).astype(float)
        over_demes = samples.sum(axis=1).astype(float)

        super().__init__(over_loci.sum(axis=0))

        #: Population names
        self.pops = pops

        #: Samples by deme and locus
        self._samples = samples

        #: Covariance matrix for the demes
        self.pops_cov: np.ndarray = np.cov(over_loci)

        #: Correlation matrix for the demes
        self.pops_corr: np.ndarray = np.corrcoef(over_loci)

        #: Correlation matrix for the loci
        self.loci_corr: np.ndarray = np.corrcoef(over_demes)

        #: Covariance matrix for the loci
        self.loci_cov: np.ndarray = np.cov(over_demes)

    def touch(self, t: np.ndarray):
        """
        Touch all cached properties, including those of the per-deme and per-locus distributions.

        :param t: Times to cache properties for.
        """
        super().touch(t)

        for deme in self.demes.values():
            deme.touch(t)

        for locus in self.loci.values():
            locus.touch(t)

    def drop(self):
        """
        Drop simulated samples, including those of the per-deme and per-locus distributions.
        """
        super().drop()

        self._samples = None

        for deme in self.demes.values():
            deme.drop()

        for locus in self.loci.values():
            locus.drop()

    @cached_property
    def demes(self) -> Dict[str, EmpiricalDistribution]:
        """
        Get the distribution for each deme.

        :return: Dictionary of distributions, with ``cov`` and ``corr`` attached as attributes.
        """
        demes = DictContainer(
            {pop: EmpiricalDistribution(self._samples.sum(axis=0)[i]) for i, pop in enumerate(self.pops)}
        )

        # TODO this is the covariance in the tree height but phasegen
        #  provides the covariance in the number of lineages per deme
        demes.cov = self.pops_cov
        demes.corr = self.pops_corr

        return demes

    @cached_property
    def loci(self) -> Dict[int, EmpiricalDistribution]:
        """
        Get the distribution for each locus.

        :return: Dictionary of distributions, with ``cov`` and ``corr`` attached as attributes.
        """
        loci = DictContainer(
            {i: EmpiricalDistribution(self._samples[i].sum(axis=0)) for i in range(self._samples.shape[0])}
        )

        loci.cov = self.loci_cov
        loci.corr = self.loci_corr

        return loci


class EmpiricalPhaseTypeSFSDistribution(EmpiricalPhaseTypeDistribution):  # pragma: no cover
    """
    SFS phase-type distribution based on realisations.
    """

    def __init__(
            self,
            branch_lengths: np.ndarray,
            mutations: np.ndarray,
            pops: List[str],
            sfs_dist: Type[SFSDistribution],
            locus_agg: Callable = lambda x: x.sum(axis=0),
    ):
        """
        Create object.

        :param branch_lengths: 4-D array of branch length samples.
        :param mutations: 4-D array of mutation counts.
        :param pops: List of population names.
        :param sfs_dist: SFS distribution class.
        :param locus_agg: Aggregation function for loci.
        """
        over_loci = locus_agg(branch_lengths).astype(float)

        # deliberately bypasses EmpiricalPhaseTypeDistribution.__init__, whose statistics
        # are replaced by the per-frequency versions computed below
        EmpiricalDistribution.__init__(self, over_loci.sum(axis=0))

        #: Population names
        self.pops = pops

        #: Number of lineages
        self.n = branch_lengths.shape[-1] - 1

        #: SFS distribution class
        self._sfs_dist = sfs_dist

        #: Branch length samples by deme and locus
        self._samples = branch_lengths

        #: Mutation counts by deme and locus
        self._mutations = mutations

        #: Correlation matrix for the demes
        self.pops_corr = self._get_stat_pops(over_loci, np.corrcoef)

        #: Covariance matrix for the demes
        self.pops_cov: np.ndarray = self._get_stat_pops(over_loci, np.cov)

        #: Correlation matrix for the loci
        self.loci_corr: np.ndarray = None

        #: Covariance matrix for the loci
        self.loci_cov: np.ndarray = None

        #: Generated probability mass by iterator returned from :meth:`get_mutation_configs`.
        self.generated_mass = 0

    def drop(self):
        """
        Drop simulated samples and mutation counts.
        """
        super().drop()

        self._mutations = None

    @staticmethod
    def _get_stat_pops(samples: np.ndarray, callback: Callable) -> np.ndarray:
        """
        Apply a statistic (e.g. ``np.cov`` or ``np.corrcoef``) across the demes for the
        interior entries of the last axis.

        :param samples: 3-D array of samples whose first axis indexes the demes.
        :param callback: Callback function to apply to the samples.
        :return: Array of shape ``(samples.shape[0], samples.shape[0], samples.shape[2], samples.shape[2])``.
            Entries whose last two indices are first or last along their axis remain zero.
        """
        stats = np.zeros((samples.shape[0], samples.shape[0], samples.shape[2], samples.shape[2]))

        # NOTE(review): ``j`` is not passed to the callback, so all j-slices for a given i
        #  receive the same matrix -- confirm this is intended
        for i, j in itertools.product(range(1, samples.shape[2] - 1), range(1, samples.shape[2] - 1)):
            stats[:, :, i, j] = callback(samples[:, :, i])

        return stats

    @cached_property
    def demes(self) -> Dict[str, EmpiricalDistribution]:
        """
        Get the distribution for each deme.

        :return: Dictionary of distributions.
        """
        return {pop: EmpiricalSFSDistribution(self._samples.sum(axis=0)[i]) for i, pop in enumerate(self.pops)}

    @cached_property
    def mutation_configs(self) -> Dict[Tuple[float, ...], float]:
        """
        Get a dictionary of all mutation configurations and their probabilities.
        Only the mutation counts at index ``[0, 0]`` are used.

        :return: Dictionary mapping configurations to relative frequencies. Unobserved
            configurations default to 0.
        """
        configs = defaultdict(lambda: 0)

        for config in self._mutations[0, 0]:
            configs[tuple(config)] += 1 / self._mutations.shape[2]

        return configs

    def get_mutation_config(self, config: Sequence[int]) -> float:
        """
        Get the probability of observing the given mutational configuration.

        :param config: The mutational configuration.
        :return: The probability of observing the given mutational configuration.
        """
        return self.mutation_configs[tuple(config)]

    def get_mutation_configs(self) -> Iterator[Tuple[Tuple[float, ...], float]]:
        """
        An iterator over the probabilities of observing mutational configurations according to the
        infinite sites model. The order of the mutational configurations generated ascends in the
        number of mutations observed. The iterator never terminates on its own; the mass generated
        so far is tracked in :attr:`generated_mass`.

        :return: An iterator over the probabilities of observing mutational configurations.
        """
        # reset generated mass
        self.generated_mass = 0

        # iterate over number of mutations
        i = 0
        while True:

            # iterate over configurations
            for config in self._sfs_dist._get_configs(self.n, i):
                p = self.get_mutation_config(config=config)
                self.generated_mass += p

                yield config, p

            # increase counter for number of mutations
            i += 1


class AbstractCoalescent(ABC):
    """
    Abstract base class for coalescent distributions. This class provides probability distributions for the
    tree height, total branch length and site frequency spectrum.
    """

    def __init__(
            self,
            n: int | Dict[str, int] | List[int] | LineageConfig,
            model: CoalescentModel = None,
            demography: Demography = None,
            loci: int | LocusConfig = 1,
            recombination_rate: float = None,
            end_time: float = None
    ):
        """
        Create object.

        :param n: Number of lineages. Either a single integer if only one population, or a list of integers
            or a dictionary with population names as keys and number of lineages as values. Alternatively, a
            :class:`~phasegen.lineage.LineageConfig` object can be passed.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param demography: Demography. By default, every population has size 1. Note that populations
            missing from a passed demography are added to it in place.
        :param loci: Number of loci or locus configuration.
        :param recombination_rate: Recombination rate. If given together with a locus configuration,
            the configuration's recombination rate is overwritten.
        :param end_time: Time when to end the computation. If ``None``, the end time is taken to be the
            time of almost sure absorption. Note that unnecessarily large end times can lead to numerical
            errors.
        """
        self._logger = logger.getChild(self.__class__.__name__)

        # set up default coalescent model
        if model is None:
            model = StandardCoalescent()

        if not isinstance(n, LineageConfig):
            #: Population configuration
            self.lineage_config: LineageConfig = LineageConfig(n)
        else:
            #: Population configuration
            self.lineage_config: LineageConfig = n

        # set up demography
        if demography is None:
            demography = Demography(pop_sizes={p: 1 for p in self.lineage_config.pop_names})

        # set up locus configuration
        if isinstance(loci, int):
            #: Locus configuration
            self.locus_config: LocusConfig = LocusConfig(
                n=loci,
                recombination_rate=recombination_rate if recombination_rate is not None else 0
            )
        else:
            #: Locus configuration
            self.locus_config: LocusConfig = loci

            if recombination_rate is not None:
                self.locus_config.recombination_rate = recombination_rate

        # population names present in the population configuration but not in the demography
        initial_sizes = {p: {0: 1} for p in self.lineage_config.pop_names if p not in demography.pop_names}

        # add missing population sizes to demography
        if len(initial_sizes) > 0:
            demography.add_event(
                PopSizeChanges(initial_sizes)
            )

            # warn if population names are present in the population configuration but not in the demography
            self._logger.warning(
                f"The following population names are present in the population configuration but not "
                f"in the demography: {list(initial_sizes.keys())}. "
                f"Adding these populations with population size of 1."
            )

        # determine population names that are present in the demography but not in the population configuration
        unspecified_lineages = set(demography.pop_names) - set(self.lineage_config.pop_names)

        # warn if population names are present in the demography but not in the population configuration
        if len(unspecified_lineages) > 0:
            self._logger.warning(
                f"The following population names are present in the demography but not "
                f"in the population configuration: {list(unspecified_lineages)}. "
                f"Adding these populations with 0 lineages."
            )

            self.lineage_config = LineageConfig(self.lineage_config.lineage_dict | {p: 0 for p in unspecified_lineages})

        #: Coalescent model
        self.model: CoalescentModel = model

        #: Demography
        self.demography: Demography = demography

        #: End time
        self.end_time: float = end_time

    @property
    @abstractmethod
    def tree_height(self) -> DensityAwareDistribution:
        """
        Tree height distribution.
        """
        pass

    @property
    @abstractmethod
    def total_branch_length(self) -> MomentAwareDistribution:
        """
        Total branch length distribution.
        """
        pass

    @property
    @abstractmethod
    def sfs(self) -> MomentAwareDistribution:
        """
        Unfolded site-frequency spectrum distribution.
        """
        pass

    @property
    @abstractmethod
    def fsfs(self) -> MomentAwareDistribution:
        """
        Folded site-frequency spectrum distribution.
        """
        pass
class Coalescent(AbstractCoalescent, Serializable):
    """
    Coalescent distribution.
    """

    def __init__(
            self,
            n: int | Dict[str, int] | List[int] | LineageConfig,
            model: CoalescentModel = None,
            demography: Demography = None,
            loci: int | LocusConfig = 1,
            recombination_rate: float = None,
            start_time: float = 0,
            end_time: float = None,
    ):
        """
        Create object.

        :param n: Number of lineages. Either a single integer if only one population, or a list of integers
            or dictionary with population names as keys and number of lineages as values for multiple
            populations. Alternatively, a :class:`~phasegen.lineage.LineageConfig` object can be passed.
        :param model: Coalescent model. Default is the standard coalescent.
        :param demography: Demography.
        :param loci: Number of loci or locus configuration.
        :param recombination_rate: Recombination rate.
        :param start_time: Time when to start accumulating moments. By default, this is 0.
        :param end_time: Time when to end accumulating moments. If ``None``, the end time is taken to be
            the time of almost sure absorption. Note that unnecessarily long end times can lead to
            numerical errors.
        """
        super().__init__(
            n=n,
            model=model,
            loci=loci,
            recombination_rate=recombination_rate,
            demography=demography,
            end_time=end_time
        )

        #: Time when to start accumulating moments
        self.start_time: float = start_time

    @cached_property
    def lineage_counting_state_space(self) -> LineageCountingStateSpace:
        """
        The lineage-counting state space.
        """
        return LineageCountingStateSpace(
            lineage_config=self.lineage_config,
            locus_config=self.locus_config,
            model=self.model,
            epoch=self.demography.get_epoch(0)
        )

    @cached_property
    def block_counting_state_space(self) -> BlockCountingStateSpace:
        """
        The block-counting state space.
        """
        return BlockCountingStateSpace(
            lineage_config=self.lineage_config,
            locus_config=self.locus_config,
            model=self.model,
            epoch=self.demography.get_epoch(0)
        )

    @cached_property
    def tree_height(self) -> TreeHeightDistribution:
        """
        Tree height distribution.
        """
        return TreeHeightDistribution(
            state_space=self.lineage_counting_state_space,
            demography=self.demography,
            start_time=self.start_time,
            end_time=self.end_time
        )

    @cached_property
    def total_branch_length(self) -> PhaseTypeDistribution:
        """
        Total branch length distribution.
        """
        return PhaseTypeDistribution(
            reward=TotalBranchLengthReward(),
            tree_height=self.tree_height,
            state_space=self.lineage_counting_state_space,
            demography=self.demography
        )

    @cached_property
    def sfs(self) -> UnfoldedSFSDistribution:
        """
        Unfolded site-frequency spectrum distribution.
        """
        return UnfoldedSFSDistribution(
            state_space=self.block_counting_state_space,
            tree_height=self.tree_height,
            demography=self.demography
        )

    @cached_property
    def fsfs(self) -> FoldedSFSDistribution:
        """
        Folded site-frequency spectrum distribution.
        """
        return FoldedSFSDistribution(
            state_space=self.block_counting_state_space,
            tree_height=self.tree_height,
            demography=self.demography
        )

    def _get_dist(self, k: int, rewards: Iterable[Reward] = None) -> PhaseTypeDistribution:
        """
        Get the kth-order phase-type distribution with state space inferred from the rewards.
        The returned phase-type distribution is configured with the unit reward.

        :param k: Order of the moment.
        :param rewards: Sequence of k rewards. By default, tree height rewards are used.
        :return: Distribution.
        """
        if rewards is None:
            rewards = [TreeHeightReward()] * k

        # use the lineage-counting state space whenever all rewards support it,
        # and fall back to the block-counting state space otherwise
        if Reward.support(LineageCountingStateSpace, rewards):
            state_space = self.lineage_counting_state_space
        else:
            state_space = self.block_counting_state_space

        return PhaseTypeDistribution(
            reward=UnitReward(),
            tree_height=self.tree_height,
            state_space=state_space,
            demography=self.demography
        )

    @_make_hashable
    @cache
    def moment(
            self,
            k: int = 1,
            rewards: Sequence[Reward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> float:
        """
        Get the kth (non-central) moment using the specified rewards and state space.
        Results are cached; list and array arguments are converted to tuples beforehand
        so that they can be hashed.

        :param k: The order of the moment
        :param rewards: Sequence of k rewards. By default, tree height rewards are used.
        :param start_time: Time when to start accumulation of moments. By default, the start time specified when
            initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified when
            initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :return: The kth moment
        """
        return self._get_dist(k, rewards).moment(
            k=k,
            rewards=rewards,
            start_time=start_time,
            end_time=end_time,
            center=center,
            permute=permute
        )

    def _sample(
            self,
            n_samples: int,
            rewards: Sequence[Reward] = None,
            record_visits: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Generate samples from the mean reward distribution by simulating trajectories.

        :param n_samples: Number of trajectories to simulate.
        :param rewards: Rewards to sample from. Default is the tree height reward.
        :param record_visits: Whether to record which states were visited during the sampling.
        :return: Array of sampled rewards of size (n_samples, len(rewards)), and optionally an array of
            probabilities of visiting each state.
        """
        return self._get_dist(k=1, rewards=rewards)._sample(
            n_samples=n_samples,
            rewards=rewards,
            record_visits=record_visits
        )

    def _raw_moment(
            self,
            k: int,
            rewards: Sequence[Reward] = None,
            start_time: float = None,
            end_time: float = None
    ) -> float:
        """
        Get the kth raw moment using the specified rewards and state space, i.e. :meth:`moment`
        without centering and without permuting the rewards.

        :param k: The order of the moment
        :param rewards: Sequence of k rewards. By default, tree height rewards are used.
        :param start_time: Time when to start accumulation of moments. By default, the start time specified when
            initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified when
            initializing the distribution or the time until almost sure absorption.
        :return: The kth raw moment
        """
        return self.moment(
            k=k,
            rewards=rewards,
            start_time=start_time,
            end_time=end_time,
            center=False,
            permute=False
        )

    def accumulate(
            self,
            k: int,
            end_times: Iterable[float],
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray:
        """
        Accumulate moments at different times.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :return: Accumulation of moments.
        """
        return self._get_dist(k, rewards).accumulate(
            k=k,
            end_times=end_times,
            rewards=rewards,
            center=center,
            permute=permute
        )

    def plot_accumulation(
            self,
            k: int = 1,
            end_times: Iterable[float] = None,
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = False,
            label: str = None,
            title: str = None
    ) -> 'plt.Axes':
        """
        Plot the accumulation of moments.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0 and the
            99th percentile.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :param ax: Axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :return: Axes, as returned by the underlying distribution's ``plot_accumulation``.
        """
        # hand the result back to the caller as promised by the signature
        return self._get_dist(k, rewards).plot_accumulation(
            k=k,
            end_times=end_times,
            rewards=rewards,
            center=center,
            permute=permute,
            ax=ax,
            show=show,
            file=file,
            clear=clear,
            label=label,
            title=title
        )

    def drop_cache(self):
        """
        Drop state space cache.
        """
        self.lineage_counting_state_space.drop_cache()
        self.block_counting_state_space.drop_cache()

    def __setstate__(self, state: dict):
        """
        Restore the state of the object from a serialized state.

        :param state: State.
        """
        self.__dict__.update(state)

    def __getstate__(self) -> dict:
        """
        Get the state of the object for serialization. The state space caches are
        dropped from the returned copy; the object itself is left untouched.

        :return: State.
        """
        # create deep copy of object without causing infinite recursion
        other = copy.deepcopy(self.__dict__)

        # only state spaces that have already been created are present in the instance dict
        if 'lineage_counting_state_space' in other:
            other['lineage_counting_state_space'].drop_cache()

        if 'block_counting_state_space' in other:
            other['block_counting_state_space'].drop_cache()

        return other

    def to_json(self) -> str:
        """
        Serialize to JSON. Drop cache before serializing.

        :return: JSON string.
        """
        # copy object to avoid modifying the original
        other = copy.deepcopy(self)

        # drop cache
        other.drop_cache()

        # name the class explicitly: ``super(self.__class__, ...)`` would resolve to this
        # very method again when called on a subclass instance and recurse without end
        return super(Coalescent, other).to_json()

    def to_msprime(
            self,
            num_replicates: int = 10000,
            n_threads: int = 10,
            parallelize: bool = True,
            record_migration: bool = False,
            simulate_mutations: bool = False,
            mutation_rate: float = None,
            seed: int = None
    ) -> 'MsprimeCoalescent':
        """
        Convert to msprime coalescent. A warning is logged if the start time is
        non-zero, as it is not passed on.

        :param num_replicates: Number of replicates.
        :param n_threads: Number of threads.
        :param parallelize: Whether to parallelize.
        :param record_migration: Whether to record migrations which is necessary to calculate statistics per deme.
        :param simulate_mutations: Whether to simulate mutations.
        :param mutation_rate: Mutation rate.
        :param seed: Random seed.
        :return: msprime coalescent.
        """
        if self.start_time != 0:
            self._logger.warning("Non-zero start times are not supported by MsprimeCoalescent.")

        return MsprimeCoalescent(
            n=self.lineage_config,
            demography=self.demography,
            model=self.model,
            loci=self.locus_config,
            recombination_rate=self.locus_config.recombination_rate,
            mutation_rate=mutation_rate,
            end_time=self.end_time,
            num_replicates=num_replicates,
            n_threads=n_threads,
            parallelize=parallelize,
            record_migration=record_migration,
            simulate_mutations=simulate_mutations,
            seed=seed
        )
class MsprimeCoalescent(AbstractCoalescent):
    """
    Empirical coalescent distribution based on `msprime` simulations. This is used for testing purposes.
    Note that the results are stochastic.
    """

    def __init__(
            self,
            n: int | Dict[str, int] | List[int] | LineageConfig,
            demography: Demography = None,
            model: CoalescentModel = StandardCoalescent(),
            loci: int | LocusConfig = 1,
            recombination_rate: float = None,
            mutation_rate: float = None,
            end_time: float = None,
            num_replicates: int = 10000,
            n_threads: int = 100,
            parallelize: bool = True,
            record_migration: bool = False,
            simulate_mutations: bool = False,
            seed: int = None
    ):
        """
        Simulate data using msprime.

        :param n: Number of Lineages.
        :param demography: Demography.
        :param model: Coalescent model.
        :param loci: Number of loci or locus configuration.
        :param recombination_rate: Recombination rate.
        :param mutation_rate: Mutation rate.
        :param end_time: Time when to end the simulation.
        :param num_replicates: Number of replicates.
        :param n_threads: Number of threads.
        :param parallelize: Whether to parallelize.
        :param record_migration: Whether to record migrations which is necessary to calculate statistics per deme.
        :param simulate_mutations: Whether to simulate mutations.
        :param seed: Random seed.
        """
        super().__init__(
            n=n,
            model=model,
            loci=loci,
            recombination_rate=recombination_rate,
            demography=demography,
            end_time=end_time
        )

        # a mutation rate has no effect unless mutations are simulated
        if mutation_rate is not None and not simulate_mutations:
            self._logger.warning("Mutation rate is set but mutations are not simulated.")

        #: Site frequency spectrum counts per locus, deme and replicate.
        self.sfs_lengths: np.ndarray | None = None

        #: Total branch lengths per locus, deme and replicate.
        self.total_branch_lengths: np.ndarray | None = None

        #: Tree heights per locus, deme and replicate.
        self.heights: np.ndarray | None = None

        #: Mutations per locus, deme and replicate.
        self.mutations: np.ndarray | None = None

        #: Number of replicates.
        self.num_replicates: int = num_replicates

        #: Mutation rate.
        self.mutation_rate: float | None = mutation_rate

        #: Number of threads.
        self.n_threads: int = n_threads

        #: Whether to parallelize computations.
        self.parallelize: bool = parallelize

        #: Whether to record migrations.
        self.record_migration: bool = record_migration

        #: Whether to simulate mutations.
        self.simulate_mutations: bool = simulate_mutations

        #: Random seed.
        self.seed: int | None = seed

    def get_coalescent_model(self) -> 'msprime.AncestryModel':
        """
        Get the msprime ancestry model matching the configured coalescent model.

        :return: msprime coalescent model.
        :raises NotImplementedError: If the configured model is none of the models handled here.
        """
        import msprime as ms

        if isinstance(self.model, StandardCoalescent):
            return ms.StandardCoalescent()

        if isinstance(self.model, BetaCoalescent):
            return ms.BetaCoalescent(alpha=self.model.alpha)

        if isinstance(self.model, DiracCoalescent):
            return ms.DiracCoalescent(psi=self.model.psi, c=self.model.c)

        # falling through would return None, which msprime's sim_ancestry interprets as
        # the standard coalescent, i.e. the wrong model would be simulated silently
        raise NotImplementedError(
            f"Coalescent model {type(self.model).__name__} is not supported by MsprimeCoalescent."
        )

    # NOTE: the cache is keyed on ``self``, so the simulation runs at most once per
    # instance and is not repeated after :meth:`drop`
    @cache
    def simulate(self) -> None:
        """
        Simulate data using msprime and store the resulting statistics in
        :attr:`heights`, :attr:`total_branch_lengths`, :attr:`sfs_lengths` and :attr:`mutations`.

        The replicates are split into ``n_threads`` batches of ``num_replicates // n_threads``
        replicates each. If a seed is set, batch ``i`` is seeded with ``seed + i``.
        """
        # number of replicates for one thread
        num_replicates = self.num_replicates // self.n_threads
        samples = self.lineage_config.lineage_dict
        demography = self.demography.to_msprime()
        model = self.get_coalescent_model()
        end_time = self.end_time
        n_pops = self.demography.n_pops
        sample_size = self.lineage_config.n

        def simulate_batch(seed: Optional[int]) -> np.ndarray:
            """
            Simulate one batch of replicates and compute its statistics.

            :param seed: Random seed.
            :return: Statistics, stacked along the first axis: tree heights, total branch lengths,
                ``sample_size + 1`` SFS branch length classes and ``sample_size + 1`` mutation count classes.
            """
            import msprime as ms
            import tskit

            # simulate trees
            g: Generator = ms.sim_ancestry(
                sequence_length=self.locus_config.n,
                recombination_rate=self.locus_config.recombination_rate,
                samples=samples,
                num_replicates=num_replicates,
                record_migrations=self.record_migration,
                demography=demography,
                model=model,
                ploidy=1,
                end_time=end_time,
                random_seed=seed
            )

            # initialize variables
            heights = np.zeros((self.locus_config.n, n_pops, num_replicates), dtype=float)
            total_branch_lengths = np.zeros((self.locus_config.n, n_pops, num_replicates), dtype=float)
            sfs = np.zeros((self.locus_config.n, n_pops, num_replicates, sample_size + 1), dtype=float)
            mutations = np.zeros((self.locus_config.n, n_pops, num_replicates, sample_size + 1), dtype=int)

            # iterate over trees and compute statistics
            ts: tskit.TreeSequence
            for i, ts in enumerate(g):

                tree: tskit.Tree
                for j, tree in enumerate(self._expand_trees(ts)):

                    # TODO record_migration only appears to work for relatively simple scenarios
                    if self.record_migration:
                        lineages = np.array(list(samples.values()))
                        t_coal = ts.tables.nodes.time[sample_size:]
                        node = sample_size - 1
                        t_migration = ts.migrations_time
                        i_migration = 0
                        time = 0

                        # population state per leave
                        pop_states = {n: tree.population(n) for n in range(sample_size)}

                        # iterate over coalescence events
                        for coal_time in t_coal:

                            # iterate over migration events within this coalescence event
                            while i_migration < len(t_migration) and time < t_migration[i_migration] <= coal_time:
                                delta = t_migration[i_migration] - time

                                # update statistics
                                heights[j, :, i] += delta * lineages / sum(lineages)
                                total_branch_lengths[j, :, i] += delta * lineages

                                for n, pop in pop_states.items():
                                    sfs[j, pop, i, tree.get_num_leaves(n)] += delta

                                # update lineages with migrations
                                lineages[ts.migrations_source[i_migration]] -= 1
                                lineages[ts.migrations_dest[i_migration]] += 1
                                pop_states[ts.migrations_node[i_migration]] = ts.migrations_dest[i_migration]

                                i_migration += 1
                                time += delta

                            # remaining time to next coalescence event
                            delta = coal_time - time

                            # update statistics
                            heights[j, :, i] += delta * lineages / sum(lineages)
                            total_branch_lengths[j, :, i] += delta * lineages

                            for n, pop in pop_states.items():
                                sfs[j, pop, i, tree.get_num_leaves(n)] += delta

                            # reduce by number of coalesced lineages
                            lineages[tree.population(node + 1)] -= len(tree.get_children(node + 1)) - 1

                            # delete children from pop_states
                            for child in tree.get_children(node + 1):
                                del pop_states[child]

                            # add parent to pop_states
                            pop_states[node + 1] = tree.population(node + 1)

                            time += delta
                            node += 1

                    else:
                        # without recorded migrations everything is attributed to the first deme
                        heights[j, 0, i] = tree.time(tree.roots[0])
                        total_branch_lengths[j, 0, i] = tree.total_branch_length

                        for node in tree.nodes():
                            t = tree.get_branch_length(node)
                            n = tree.get_num_leaves(node)

                            sfs[j, 0, i, n] += t

                # simulate mutations if specified
                if self.simulate_mutations:
                    mts = ms.sim_mutations(ts, rate=self.mutation_rate, random_seed=seed)

                    # mutation counts are recorded for the first locus and first deme only,
                    # using the first tree of the mutated tree sequence
                    tree = next(mts.trees())
                    for node in mts.mutations_node:
                        mutations[0, 0, i, tree.get_num_leaves(node)] += 1

            # transpose so that the replicate axis comes second, which lets the
            # batches be stacked along it below
            return np.concatenate([[heights.T], [total_branch_lengths.T], sfs.T, mutations.T])

        # parallelize and add up results
        res = np.hstack(parallelize(
            func=simulate_batch,
            data=[self.seed + i if self.seed is not None else None for i in range(self.n_threads)],
            parallelize=self.parallelize,
            batch_size=num_replicates,
            desc="Simulating trees"
        ))

        # store results, undoing the transposition applied per batch
        self.heights = res[0].T
        self.total_branch_lengths = res[1].T
        self.sfs_lengths = res[2:sample_size + 3].T
        self.mutations = res[sample_size + 3:].T.astype(int)

    @staticmethod
    def _expand_trees(ts: 'tskit.TreeSequence') -> Iterator['tskit.Tree']:
        """
        Expand tree sequence to `n` trees where `n` is the number of loci.

        :param ts: Tree sequence.
        :return: Iterator yielding each tree once per unit of its (integer-truncated) length.
        """
        for tree in ts.trees():
            for _ in range(int(tree.length)):
                yield tree

    def _get_cached_times(self) -> np.ndarray:
        """
        Get cached times.

        :return: 100 evenly spaced times from 0 to the largest tree height, where heights are
            summed over axis 1 (the deme axis of :attr:`heights`).
        """
        t_max = self.heights.sum(axis=1).max()

        return np.linspace(0, t_max, 100)

    def touch(self, **kwargs):
        """
        Touch cached properties.

        :param kwargs: Additional keyword arguments.
        """
        self.simulate()

        t = self._get_cached_times()

        self.tree_height.touch(t)
        self.total_tree_height.touch(t)
        self.total_branch_length.touch(t)
        self.sfs.touch(t)
        self.fsfs.touch(t)

    def drop(self):
        """
        Drop simulated data.
        """
        self.heights = None
        self.total_branch_lengths = None
        self.sfs_lengths = None
        self.mutations = None

        self.tree_height.drop()
        self.total_tree_height.drop()
        self.total_branch_length.drop()
        self.sfs.drop()
        self.fsfs.drop()

        # caused problems when serializing
        self.demography = None

    @cached_property
    def tree_height(self) -> EmpiricalPhaseTypeDistribution:
        """
        Tree height distribution.
        """
        self.simulate()

        # aggregate over loci by taking the maximum along the first axis
        return EmpiricalPhaseTypeDistribution(
            self.heights,
            pops=self.demography.pop_names,
            locus_agg=lambda x: x.max(axis=0)
        )

    @cached_property
    def total_tree_height(self) -> EmpiricalPhaseTypeDistribution:
        """
        Total tree height distribution.
        """
        self.simulate()

        return EmpiricalPhaseTypeDistribution(self.heights, pops=self.demography.pop_names)

    @cached_property
    def total_branch_length(self) -> EmpiricalPhaseTypeDistribution:
        """
        Total branch length distribution.
        """
        self.simulate()

        return EmpiricalPhaseTypeDistribution(self.total_branch_lengths, pops=self.demography.pop_names)

    @cached_property
    def sfs(self) -> EmpiricalPhaseTypeSFSDistribution:
        """
        Unfolded site-frequency spectrum distribution.
        """
        self.simulate()

        # the first and last frequency class are removed from the mutation counts
        return EmpiricalPhaseTypeSFSDistribution(
            branch_lengths=self.sfs_lengths,
            mutations=self.mutations.T[1:-1].T,
            pops=self.demography.pop_names,
            sfs_dist=UnfoldedSFSDistribution
        )

    @cached_property
    def fsfs(self) -> EmpiricalPhaseTypeSFSDistribution:
        """
        Folded site-frequency spectrum distribution.
        """
        self.simulate()

        # number of frequency classes folded onto their mirrored counterpart
        mid = (self.lineage_config.n + 1) // 2

        # fold SFS branch lengths
        lengths = self.sfs_lengths.copy().T
        lengths[:mid] += lengths[-mid:][::-1]
        lengths[-mid:] = 0

        # fold SFS mutations
        mutations = self.mutations.copy().T
        mutations[:mid] += mutations[-mid:][::-1]
        mutations = mutations[1:self.lineage_config.n // 2 + 1]

        return EmpiricalPhaseTypeSFSDistribution(
            branch_lengths=lengths.T,
            mutations=mutations.T,
            pops=self.demography.pop_names,
            sfs_dist=FoldedSFSDistribution
        )

    def to_phasegen(self) -> Coalescent:
        """
        Convert to native phasegen coalescent.

        :return: phasegen coalescent.
        """
        return Coalescent(
            n=self.lineage_config,
            model=self.model,
            demography=self.demography,
            loci=self.locus_config,
            recombination_rate=self.locus_config.recombination_rate,
            end_time=self.end_time
        )