# Source code for phasegen.distributions

"""
Probability distributions.
"""

import copy
import functools
import itertools
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping
from functools import cached_property, cache
from math import factorial
from typing import Generator, List, Callable, Tuple, Dict, Collection, Iterable, Iterator, Optional, Sequence, Set, \
    Type, Union

import numpy as np
import scipy.sparse as sp
from scipy.ndimage import gaussian_filter1d
from tqdm import tqdm

from .coalescent_models import StandardCoalescent, CoalescentModel, BetaCoalescent, DiracCoalescent
from .demography import Demography, PopSizeChanges
from .expm import Backend
from .lineage import LineageConfig
from .locus import LocusConfig
from .rewards import Reward, TreeHeightReward, TotalBranchLengthReward, UnfoldedSFSReward, DemeReward, UnitReward, \
    LocusReward, CombinedReward, FoldedSFSReward, SFSReward, CustomReward, JointSFSReward, TwoLocusSFSReward
from .serialization import Serializable
from .settings import Settings
from .spectrum import SFS, SFS2, JointSFS, TwoLocusSFS
from .state_space import BlockCountingStateSpace, LineageCountingStateSpace, StateSpace, JointBlockCountingStateSpace, \
    TwoLocusBlockCountingStateSpace
from .utils import parallelize, multiset_permutations

#: Matrix exponential function, taken from ``Backend`` (``.expm`` module) once, when this module is imported.
#: NOTE(review): later reassignment of ``Backend.expm`` is not reflected in this alias -- confirm that is intended.
expm = Backend.expm

#: Package logger ('phasegen'); distributions log through child loggers named after their class.
logger = logging.getLogger('phasegen')


def _make_hashable(func: Callable) -> Callable:
    """
    Decorator that makes a function hashable by converting non-hashable arguments to hashable ones.

    Lists and NumPy arrays passed positionally or by keyword are converted to tuples so that the wrapped
    function can be combined with :func:`functools.cache`. The conversion is recursive, so nested lists and
    multidimensional arrays become nested tuples. 0-d arrays are unwrapped to their scalar value. Any other
    argument (including tuples) is passed through unchanged.

    :param func: The function (method) to wrap. Its first positional argument (``self``) is passed through as is.
    :return: The wrapped function.
    """

    def to_hashable(value):
        """
        Convert lists and arrays to (nested) tuples and return every other value unchanged.

        :param value: Argument value.
        :return: Hashable counterpart of ``value`` where a conversion applies, otherwise ``value`` itself.
        """
        # a 0-d array cannot be iterated, so take its scalar value instead
        if isinstance(value, np.ndarray) and value.ndim == 0:
            return value.item()

        # iterating a 1-D array yields NumPy scalars (left unchanged); higher dimensions yield sub-arrays (recursed)
        if isinstance(value, (list, np.ndarray)):
            return tuple(to_hashable(v) for v in value)

        return value

    @functools.wraps(func)
    def wrapper(self, *args: tuple, **kwargs: dict):
        """
        Wrapper function.

        :param self: Self.
        :param args: Positional arguments, converted with ``to_hashable``.
        :param kwargs: Keyword arguments, values converted with ``to_hashable``.
        :return: The result of the function.
        """
        args = tuple(to_hashable(arg) for arg in args)
        kwargs = {key: to_hashable(value) for key, value in kwargs.items()}

        return func(self, *args, **kwargs)

    return wrapper


class ProbabilityDistribution(ABC):
    """
    Abstract base class for probability distributions.

    Every instance gets its own logger, named after the concrete subclass.
    """

    def __init__(self):
        """
        Create object.
        """
        #: Logger, a child of the package logger named after the concrete class
        self._logger = logger.getChild(type(self).__name__)

    def touch(self, **kwargs: dict):
        """
        Evaluate every cached property defined anywhere in the class hierarchy, so that its value is stored.

        :param kwargs: Additional keyword arguments (not used here).
        """
        cached_names = (
            name
            for klass in type(self).__mro__
            for name, member in vars(klass).items()
            if isinstance(member, cached_property)
        )

        for name in cached_names:
            getattr(self, name)


class MomentAwareDistribution(ProbabilityDistribution, ABC):
    """
    Abstract interface for probability distributions that expose their first two moments.

    Subclasses must provide the mean, the variance and the second raw moment as (cached) properties.
    """

    @abstractmethod
    @cached_property
    def mean(self) -> float:
        """
        Expectation, i.e. the first moment.
        """
        ...

    @abstractmethod
    @cached_property
    def var(self) -> float:
        """
        Variance, i.e. the second moment centered around the mean.
        """
        ...

    @abstractmethod
    @cached_property
    def m2(self) -> float:
        """
        Second raw (non-central) moment.
        """
        ...


class MarginalDistributions(Mapping, ABC):
    """
    Abstract read-only mapping from a key (e.g. a locus or a deme) to a marginal distribution.

    It also gives access to the covariance and correlation between these marginals.
    """

    @abstractmethod
    def get_cov(self, d1, d2) -> float:
        """
        Covariance between two marginal distributions.

        :param d1: Key of the first marginal distribution.
        :param d2: Key of the second marginal distribution.
        :return: Covariance.
        """
        ...

    @abstractmethod
    @cached_property
    def cov(self) -> np.ndarray:
        """
        Matrix of pairwise covariances between all marginal distributions.
        """
        ...

    @abstractmethod
    def get_corr(self, d1, d2) -> float:
        """
        Pearson correlation coefficient between two marginal distributions.

        :param d1: Key of the first marginal distribution.
        :param d2: Key of the second marginal distribution.
        :return: Correlation coefficient.
        """
        ...

    @abstractmethod
    @cached_property
    def corr(self) -> np.ndarray:
        """
        Matrix of pairwise correlation coefficients between all marginal distributions.
        """
        ...


class MarginalLocusDistributions(MarginalDistributions):
    """
    Marginal locus distributions.

    Read-only mapping from locus index to the distribution of the underlying reward restricted to that locus.
    """

    def __init__(self, dist: 'PhaseTypeDistribution'):
        """
        Initialize the distributions.

        :param dist: The distribution.
        """
        self.dist = dist

    def __getitem__(self, item):
        """
        Get the distribution for the given locus.

        :param item: Locus index.
        :return: Distribution.
        """
        return self.loci[item]

    def __iter__(self) -> Iterator:
        """
        Iterate over the locus indices.

        :return: Iterator.
        """
        return iter(self.loci)

    def __len__(self) -> int:
        """
        Get the number of distributions.

        :return: Number of distributions.
        """
        return len(self.loci)

    @cached_property
    def loci(self) -> Dict[int, 'PhaseTypeDistribution']:
        """
        Distributions marginalized over loci, keyed by locus index.
        """
        # get class of distribution but use PhaseTypeDistribution
        # if this is a TreeHeightDistribution as TreeHeightDistribution
        # only works with default rewards
        cls = self.dist.__class__ if not isinstance(self.dist, TreeHeightDistribution) else PhaseTypeDistribution

        loci = {}
        for locus in range(self.dist.locus_config.n):
            loci[locus] = cls(
                state_space=self.dist.state_space,
                tree_height=self.dist.tree_height,
                demography=self.dist.demography,
                reward=CombinedReward([self.dist.reward, LocusReward(locus)])
            )

        return loci

    def get_cov(self, locus1: int, locus2: int) -> float:
        """
        Get the covariance between two loci.

        :param locus1: The first locus.
        :param locus2: The second locus.
        :return: The covariance.
        :raises ValueError: If either locus does not exist.
        """
        locus1 = int(locus1)
        locus2 = int(locus2)

        if locus1 not in range(self.dist.locus_config.n) or locus2 not in range(self.dist.locus_config.n):
            raise ValueError(f"Locus {locus1} or {locus2} does not exist.")

        return self.dist.moment(
            k=2,
            rewards=(
                CombinedReward([self.dist.reward, LocusReward(locus1)]),
                CombinedReward([self.dist.reward, LocusReward(locus2)])
            ),
            center=True
        )

    @cached_property
    def cov(self) -> np.ndarray:
        """
        Covariance matrix across loci.
        """
        n_loci = self.dist.locus_config.n
        cov = np.empty((n_loci, n_loci))

        # covariance is symmetric, so compute the upper triangle only (each entry is a costly moment) and mirror it
        for i in range(n_loci):
            for j in range(i, n_loci):
                cov[i, j] = cov[j, i] = self.get_cov(i, j)

        return cov

    def get_corr(self, locus1: int, locus2: int) -> float:
        """
        Get the correlation coefficient between two loci.

        :param locus1: The first locus.
        :param locus2: The second locus.
        :return: The correlation coefficient.
        """
        locus1 = int(locus1)
        locus2 = int(locus2)

        return self.get_cov(locus1, locus2) / (self.loci[locus1].std * self.loci[locus2].std)

    @cached_property
    def corr(self) -> np.ndarray:
        """
        Correlation matrix across loci.
        """
        n_loci = self.dist.locus_config.n
        corr = np.empty((n_loci, n_loci))

        # correlation is symmetric: compute the upper triangle and mirror it
        for i in range(n_loci):
            for j in range(i, n_loci):
                corr[i, j] = corr[j, i] = self.get_corr(i, j)

        return corr
class MarginalDemeDistributions(MarginalDistributions):
    """
    Marginal deme distributions.

    Read-only mapping from population (deme) name to the distribution of the underlying reward restricted to
    that deme.
    """

    def __init__(self, dist: 'PhaseTypeDistribution'):
        """
        Initialize the distributions.

        :param dist: The distribution.
        """
        self.dist = dist

    def __getitem__(self, item):
        """
        Get the distribution for the given deme.

        :param item: Deme name.
        :return: Distribution.
        """
        return self.demes[item]

    def __iter__(self) -> Iterator:
        """
        Iterate over the deme names.

        :return: Iterator.
        """
        return iter(self.demes)

    def __len__(self) -> int:
        """
        Get the number of distributions.

        :return: Number of distributions.
        """
        return len(self.demes)

    @cached_property
    def demes(self) -> Dict[str, 'PhaseTypeDistribution']:
        """
        Distributions marginalized over demes, keyed by population name.
        """
        # get class of distribution but use PhaseTypeDistribution
        # if this is a TreeHeightDistribution as TreeHeightDistribution
        # only works with default rewards
        cls = self.dist.__class__ if not isinstance(self.dist, TreeHeightDistribution) else PhaseTypeDistribution

        demes = {}
        for pop in self.dist.lineage_config.pop_names:
            demes[pop] = cls(
                state_space=self.dist.state_space,
                tree_height=self.dist.tree_height,
                demography=self.dist.demography,
                reward=CombinedReward([self.dist.reward, DemeReward(pop)])
            )

        return demes

    def get_cov(self, pop1: str, pop2: str) -> float:
        """
        Get the covariance between two demes.

        :param pop1: The first deme.
        :param pop2: The second deme.
        :return: The covariance.
        :raises ValueError: If either population does not exist.
        """
        if pop1 not in self.dist.lineage_config.pop_names or pop2 not in self.dist.lineage_config.pop_names:
            raise ValueError(f"Population {pop1} or {pop2} does not exist.")

        return self.dist.moment(
            k=2,
            rewards=(
                CombinedReward([self.dist.reward, DemeReward(pop1)]),
                CombinedReward([self.dist.reward, DemeReward(pop2)])
            ),
            center=True
        )

    @cached_property
    def cov(self) -> np.ndarray:
        """
        Covariance matrix across demes (rows and columns ordered as ``lineage_config.pop_names``).
        """
        pops = list(self.dist.lineage_config.pop_names)
        n_pops = len(pops)
        cov = np.empty((n_pops, n_pops))

        # covariance is symmetric, so compute the upper triangle only (each entry is a costly moment) and mirror it
        for i in range(n_pops):
            for j in range(i, n_pops):
                cov[i, j] = cov[j, i] = self.get_cov(pops[i], pops[j])

        return cov

    def get_corr(self, pop1: str, pop2: str) -> float:
        """
        Get the correlation coefficient between two demes.

        :param pop1: The first deme.
        :param pop2: The second deme.
        :return: The correlation coefficient.
        """
        return self.get_cov(pop1, pop2) / (self.demes[pop1].std * self.demes[pop2].std)

    @cached_property
    def corr(self) -> np.ndarray:
        """
        Correlation matrix across demes (rows and columns ordered as ``lineage_config.pop_names``).
        """
        pops = list(self.dist.lineage_config.pop_names)
        n_pops = len(pops)
        corr = np.empty((n_pops, n_pops))

        # correlation is symmetric: compute the upper triangle and mirror it
        for i in range(n_pops):
            for j in range(i, n_pops):
                corr[i, j] = corr[j, i] = self.get_corr(pops[i], pops[j])

        return corr
class DensityAwareDistribution(MomentAwareDistribution, ABC):
    """
    Abstract base for distributions that, in addition to moments, provide a CDF, a density and quantiles.
    """

    @abstractmethod
    def cdf(self, t: float | Sequence[float]) -> float | np.ndarray:
        """
        Evaluate the cumulative distribution function.

        :param t: Point or points at which to evaluate the CDF.
        :return: CDF value(s).
        """

    @abstractmethod
    def quantile(self, q: float) -> float:
        """
        Return the ``q``-th quantile.
        """

    @abstractmethod
    def pdf(self, t: float | Sequence[float], **kwargs: dict) -> float | np.ndarray:
        """
        Evaluate the density function.

        :param t: Point or points at which to evaluate the density.
        :param kwargs: Additional keyword arguments.
        :return: Density value(s).
        """

    def plot_cdf(
            self,
            ax: 'plt.Axes' = None,
            t: np.ndarray = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = 'Tree height CDF'
    ) -> 'plt.Axes':
        """
        Draw the cumulative distribution function.

        :param ax: Axes to draw on.
        :param t: Evaluation points. Defaults to 200 equally spaced points from 0 to the 99th percentile.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :return: Axes.
        """
        from .visualization import Visualization

        grid = t if t is not None else np.linspace(0, self.quantile(0.99), 200)

        return Visualization.plot(
            ax=ax,
            x=grid,
            y=self.cdf(grid),
            xlabel='t',
            ylabel='F(t)',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )

    def plot_pdf(
            self,
            ax: 'plt.Axes' = None,
            t: np.ndarray = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = 'Tree height PDF',
            dx: float = None
    ) -> 'plt.Axes':
        """
        Draw the density function.

        :param ax: Axes to draw on.
        :param t: Evaluation points. Defaults to 200 equally spaced points from 0 to the 99th percentile.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :param dx: Step size for numerical differentiation. Defaults to the 99th percentile divided by 1e10.
        :return: Axes.
        """
        from .visualization import Visualization

        step = dx if dx is not None else self.quantile(0.99) / 1e10
        grid = t if t is not None else np.linspace(0, self.quantile(0.99), 200)

        return Visualization.plot(
            ax=ax,
            x=grid,
            y=self.pdf(grid, dx=step),
            xlabel='t',
            ylabel='f(t)',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )
class PhaseTypeDistribution(MomentAwareDistribution):
    """
    Phase-type distribution for a piecewise time-homogeneous process.
    """

    def __init__(
            self,
            state_space: StateSpace,
            tree_height: 'TreeHeightDistribution',
            demography: Demography = None,
            reward: Reward = None
    ):
        """
        Initialize the distribution.

        :param state_space: The state space.
        :param tree_height: The tree height distribution.
        :param demography: The demography.
        :param reward: The reward. By default, the tree height reward.
        """
        if demography is None:
            demography = Demography()

        if reward is None:
            reward = TreeHeightReward()

        super().__init__()

        #: Population configuration
        self.lineage_config: LineageConfig = state_space.lineage_config

        #: Locus configuration
        self.locus_config: LocusConfig = state_space.locus_config

        #: Reward
        self.reward: Reward = reward

        #: State space
        self.state_space: StateSpace = state_space

        #: Demography
        self.demography: Demography = demography

        #: Tree height distribution
        self.tree_height: TreeHeightDistribution = tree_height

    @staticmethod
    def _get_van_loan_matrix(R: List[np.ndarray], S: np.ndarray, k: int = 1) -> np.ndarray:
        """
        Get the block matrix for the given reward matrices and transition matrix.

        :param R: List of length k of reward matrices
        :param S: Intensity matrix
        :param k: The order of the moment.
        :return: Van Loan matrix which is a block matrix of size (k + 1) * (k + 1)
        """
        # matrix of zeros
        O = np.zeros_like(S)

        # create compound matrix
        return np.block([[S if i == j else R[i] if i == j - 1 else O for j in range(k + 1)] for i in range(k + 1)])

    @staticmethod
    def _get_van_loan_matrix_sparse(R: List[np.ndarray], S: 'sp.spmatrix', k: int = 1) -> 'sp.spmatrix':
        """
        Sparse, block-bidiagonal Van Loan matrix: the (sparse) intensity matrix ``S`` on the diagonal and the
        (diagonal) reward matrices on the super-diagonal. Built directly as a sparse matrix to avoid materializing
        the dense ``(k + 1) * n`` block matrix.

        :param R: List of length k of reward vectors (the diagonals of the reward matrices).
        :param S: Sparse intensity matrix.
        :param k: The order of the moment.
        :return: Sparse Van Loan matrix of size ``(k + 1) * (k + 1)`` blocks.
        """
        blocks = [[None] * (k + 1) for _ in range(k + 1)]

        for i in range(k + 1):
            blocks[i][i] = S

            if i < k:
                blocks[i][i + 1] = sp.diags(R[i])

        return sp.bmat(blocks, format='csr')

    @cached_property
    def mean(self) -> float | SFS:
        """
        First moment / mean.
        """
        return self.moment(k=1)

    @cached_property
    def var(self) -> float | SFS:
        """
        Second central moment / variance.
        """
        return self.moment(k=2, center=True)

    @cached_property
    def std(self) -> float | SFS:
        """
        Standard deviation.
        """
        return self.var ** 0.5

    @cached_property
    def m2(self) -> float | SFS:
        """
        Second (non-central) moment.
        """
        return self.moment(k=2, center=False)

    @cached_property
    def demes(self) -> MarginalDemeDistributions:
        """
        Marginal distributions over each deme.
        """
        return MarginalDemeDistributions(self)

    @cached_property
    def loci(self) -> MarginalLocusDistributions:
        """
        Marginal distributions over each locus.
        """
        return MarginalLocusDistributions(self)

    @_make_hashable
    @cache
    def moment(
            self,
            k: int,
            rewards: Sequence[Reward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> float:
        """
        Get the kth (cross-)moment, centered around the mean by default (see ``center``).

        :param k: The order of the moment.
        :param rewards: Iterable of k rewards. By default, the reward of the underlying distribution.
        :param start_time: Time when to start accumulation of moments. By default, the start time specified when
            initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified when
            initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment around the mean. If ``False``, the non-central moment is
            returned.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :return: The kth moment
        :raises ValueError: If the computed moment is NaN.
        """
        if start_time is None:
            start_time = self.tree_height.start_time

        if end_time is None:
            end_time = self.tree_height.t_max

        if start_time > 0:
            m_start, m_end = PhaseTypeDistribution.accumulate(
                self,
                k=k,
                end_times=[start_time, end_time],
                rewards=rewards,
                center=center,
                permute=permute
            )

            m = float(m_end - m_start)
        else:
            m = float(PhaseTypeDistribution.accumulate(
                self,
                k=k,
                end_times=[end_time],
                rewards=rewards,
                center=center,
                permute=permute
            )[0])

        if np.isnan(m):
            raise ValueError(
                "NaN value encountered when computing moment. "
                "This is likely due to an ill-conditioned rate matrix."
            )

        return m

    @staticmethod
    def _get_regularization_factor(S: np.ndarray) -> float:
        """
        Get the regularization factor for the given intensity matrix. We multiply the intensity matrix by this
        factor to improve numerical stability when computing the matrix exponential of the Van Loan matrix.
        If regularization is disabled, or ``S`` has no positive entry, this factor is 1.

        :param S: Intensity matrix.
        :return: Regularization factor.
        """
        if not Settings.regularize:
            return 1.0

        # obtain positive rates
        rates = S[S > 0]

        # without any positive rate there is nothing to rescale (the log-mean of an empty array would be NaN)
        if rates.size == 0:
            return 1.0

        # rewards in the Van Loan matrix are of order 1
        return 10 ** - np.log10(rates).mean()

    def _check_numerical_stability(self, S: np.ndarray, epoch: int):
        """
        Warn about potential numerical instability with very small or very large rates.

        :param S: (Regularized) intensity matrix.
        :param epoch: Epoch number.
        """
        rates = S[S > 0]

        # nothing to compare if no entry of S is positive (min/max of an empty array would raise)
        if rates.size == 0:
            return

        if rates.min() / rates.max() < 1e-10:
            self._logger.warning(
                f"Intensity matrix in epoch {epoch} contains rates that differ by more than 10 orders of magnitude: "
                f"min: {rates.min()}, max: {rates.max()}. "
                f"This may lead to numerical instability, despite matrix regularization."
            )

    def accumulate(
            self,
            k: int,
            end_times: Iterable[float],
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray:
        """
        Evaluate the kth moment at different end times.

        :param k: The order of the moment.
        :param end_times: Iterable of end times, or a single end time, at which to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :return: The moment accumulated at the specified times or time (always an array aligned with
            ``end_times``).
        :raises ValueError: If the number of rewards does not match ``k``.
        """
        k = int(k)

        # materialize the end times once: they are used several times below (a generator could only be consumed
        # once), and a single scalar end time is wrapped into a one-element tuple
        if not isinstance(end_times, Iterable):
            end_times = (end_times,)
        end_times = tuple(end_times)

        if rewards is None:
            rewards = [self.reward] * k

        if k != len(rewards):
            raise ValueError(f"Number of specified rewards for moment of order {k} must be {k}.")

        if k == 0:
            return np.ones_like(end_times)

        # center moments around the mean
        if center and k > 1:
            components = []

            # first order moments
            means = [
                PhaseTypeDistribution.accumulate(
                    self,
                    k=1,
                    rewards=(rewards[i],),
                    end_times=end_times
                ) for i in range(k)
            ]

            for i in range(k + 1):
                # iterate over all possible subsets of rewards of size i
                for indices in itertools.combinations(range(k), i):
                    # joint moment
                    mu_i = PhaseTypeDistribution.accumulate(
                        self,
                        k=i,
                        rewards=tuple(rewards[j] for j in indices),
                        end_times=end_times,
                        center=False,
                        permute=permute
                    )

                    # product of means of remaining rewards
                    mu1 = np.prod([means[j] for j in range(k) if j not in indices], axis=0)

                    components += [(-1) ** (k - i) * mu_i * mu1]

            return np.sum(components, axis=0)

        if permute:
            # get all possible permutations of rewards
            permutations = list(itertools.permutations(rewards))

            # compute average over all permutations
            return np.sum([self._accumulate(k, end_times, r) for r in permutations], axis=0) / len(permutations)

        # copy so that callers cannot mutate the array stored in the cache of ``_accumulate``
        return self._accumulate(k, end_times, rewards).copy()

    @_make_hashable
    @cache
    def _accumulate_flattened(
            self,
            k: int,
            end_times: Sequence[float],
            rewards: Sequence[Reward] = None
    ) -> np.ndarray:
        """
        Evaluate the kth (non-central) moment at different end times using the lineage counting state space.

        :param k: The order of the moment.
        :param end_times: Sequence of end times or end time when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :return: The moment accumulated at the specified times or time.
        :raises ValueError: If the state space is not a BlockCountingStateSpace, or if k is not 1, or if there are
            multiple populations or loci or if the coalescent model is not the standard coalescent.
        """
        if not isinstance(self.state_space, BlockCountingStateSpace):
            raise ValueError("Flattened accumulation is only supported for BlockCountingStateSpace.")

        if k != 1:
            raise ValueError("Flattened accumulation is only supported for k = 1.")

        if self.lineage_config.n_pops != 1 or self.locus_config.n != 1:
            raise ValueError("Flattened accumulation is only supported for a single population and a single locus.")

        if not isinstance(self.state_space.model, StandardCoalescent):
            raise ValueError("Flattened accumulation is only supported for standard coalescent.")

        reward = rewards[0] if rewards else self.reward
        r = reward._get(self.state_space)
        probs = self.state_space._state_probs

        # sum up weights for each state based on the number of lineages
        n = self.lineage_config.n
        weights = np.zeros(n)
        for i, s in enumerate(self.state_space.states):
            weights[n - s.lineages.sum()] += probs[i] * r[i]

        # Create a custom reward that returns the weights.
        weighted_reward = CustomReward(lambda _: weights)

        return self.tree_height._accumulate(k=k, end_times=end_times, rewards=(weighted_reward,))

    @_make_hashable
    @cache
    def _accumulate(
            self,
            k: int,
            end_times: Sequence[float],
            rewards: Sequence[Reward] = None
    ) -> np.ndarray:
        """
        Evaluate the kth (non-central) moment at different end times.

        :param k: The order of the moment.
        :param end_times: Sequence of end times at which to evaluate the moment (need not be sorted).
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :return: The moment accumulated at the specified times, in the order of ``end_times``.
        :raises ValueError: If an end time is negative or the number of rewards does not match ``k``.
        """
        if (
                Settings.flatten_block_counting and
                k == 1 and
                isinstance(self.state_space, BlockCountingStateSpace) and
                isinstance(self.state_space.model, StandardCoalescent) and
                self.lineage_config.n_pops == 1 and
                self.locus_config.n == 1
        ):
            return self._accumulate_flattened(k, end_times, rewards)

        end_times = np.array(end_times)

        # check for negative values
        if np.any(end_times < 0):
            raise ValueError("Negative end times are not allowed.")

        # use default reward if not specified
        if rewards is None:
            rewards = (self.reward,) * k
        else:
            if len(rewards) != k:
                raise ValueError(f"Number of rewards must be {k}.")

        # sort array in ascending order but keep track of original indices
        t_sorted: Collection[float] = np.sort(end_times)

        epochs = enumerate(self.demography.epochs)
        i_epoch, epoch = next(epochs)

        # get state space for the first epoch
        self.state_space.update_epoch(epoch)

        # number of states
        n_states = self.state_space.k

        # for large (sparse) state spaces, compute the moment via the action of the matrix exponential on a vector
        # (threading through the epochs) instead of forming the dense Van Loan propagator
        if (k + 1) * n_states >= Settings.expm_action_min_dim:
            return self._accumulate_action(k, end_times, t_sorted, rewards)

        # initialize block matrix holding (rewarded) moments
        Q = np.eye(n_states * (k + 1))
        u_prev = 0

        # initialize probabilities
        moments = np.zeros_like(t_sorted, dtype=float)

        # regularization parameter
        lamb = self._get_regularization_factor(self.state_space.S)

        # regularized intensity matrix
        S = self.state_space.S * lamb

        # check numerical stability
        self._check_numerical_stability(S, i_epoch)

        # get reward matrix
        R = [np.diag(r._get(state_space=self.state_space)) for r in rewards]

        # get Van Loan matrix
        V = self._get_van_loan_matrix(S=S, R=R, k=k)

        # iterate through sorted values
        for i, u in enumerate(t_sorted):

            # iterate over epochs between u_prev and u
            while u > epoch.end_time:
                # update transition matrix with remaining time in current epoch
                Q @= expm(V * (epoch.end_time - u_prev) / lamb)

                # fetch and update for next epoch
                u_prev = epoch.end_time
                i_epoch, epoch = next(epochs)
                self.state_space.update_epoch(epoch)

                # compute Van Loan matrix for next epoch using regularized intensity matrix
                S = self.state_space.S * lamb
                self._check_numerical_stability(S, i_epoch)
                V = self._get_van_loan_matrix(S=S, R=R, k=k)

            # update with remaining time in current epoch
            Q @= expm(V * (u - u_prev) / lamb)

            alpha = self.state_space.alpha
            e = self.state_space.e
            moments[i] = factorial(k) * lamb ** k * alpha @ Q[:n_states, -n_states:] @ e

            u_prev = u

        # restore the original order: moments[j] belongs to t_sorted[j] == end_times[order[j]],
        # so the values must be scattered to positions ``order`` (gathering with ``order`` is wrong)
        order = np.argsort(end_times)
        restored = np.empty_like(moments)
        restored[order] = moments
        moments = restored

        if np.isnan(moments).any():
            self._logger.warning(
                "NaN values encountered when computing moments. "
                f"Epoch: {i_epoch} at time: {epoch.start_time}. "
                "This is likely due to an ill-conditioned rate matrix."
            )

        return moments

    def _accumulate_action(
            self,
            k: int,
            end_times: np.ndarray,
            t_sorted: np.ndarray,
            rewards: Sequence[Reward]
    ) -> np.ndarray:
        """
        Sparse-action variant of :meth:`_accumulate` for large state spaces.

        Instead of forming the dense Van Loan propagator ``Q = prod_i exp(V_i tau_i)`` and reading off
        ``alpha @ Q[:n, -n:] @ e``, this threads the vector ``w = alpha_ext`` through the epochs via the action
        of the matrix exponential on the (sparse) Van Loan matrix (``scipy.sparse.linalg.expm_multiply``),
        reading off ``w @ e_ext`` at each end time. This is exact (a product applied to a vector is a sequence
        of matrix-vector actions) and exploits the rate matrix sparsity.

        :param k: The order of the moment.
        :param end_times: The (unsorted) end times, used to restore the original order.
        :param t_sorted: The sorted end times.
        :param rewards: Sequence of k rewards.
        :return: The moment accumulated at the specified times, in the order of ``end_times``.
        """
        epochs = enumerate(self.demography.epochs)
        i_epoch, epoch = next(epochs)
        self.state_space.update_epoch(epoch)

        n = self.state_space.k
        lamb = self._get_regularization_factor(self.state_space.S)

        def transposed_van_loan() -> 'sp.spmatrix':
            """Transposed sparse Van Loan matrix for the current epoch (transposed for the left vector action)."""
            S = self.state_space.S * lamb
            self._check_numerical_stability(S, i_epoch)
            r_vecs = [np.asarray(r._get(state_space=self.state_space), dtype=float) for r in rewards]
            return self._get_van_loan_matrix_sparse(R=r_vecs, S=sp.csr_matrix(S), k=k).T.tocsr()

        Vt = transposed_van_loan()

        # w = alpha_ext (alpha in the first block); e_ext = e in the last block, so w @ Q @ e_ext = alpha @ Q[:n,-n:] @ e
        w = np.zeros((k + 1) * n)
        w[:n] = self.state_space.alpha
        e_ext = np.zeros((k + 1) * n)
        e_ext[-n:] = self.state_space.e

        moments = np.zeros_like(t_sorted, dtype=float)
        u_prev = 0.0

        for i, u in enumerate(t_sorted):

            # advance through whole epochs between u_prev and u
            while u > epoch.end_time:
                w = Backend.expm_multiply(Vt * ((epoch.end_time - u_prev) / lamb), w)
                u_prev = epoch.end_time
                i_epoch, epoch = next(epochs)
                self.state_space.update_epoch(epoch)
                Vt = transposed_van_loan()

            # remaining time in the current epoch
            w = Backend.expm_multiply(Vt * ((u - u_prev) / lamb), w)
            moments[i] = factorial(k) * lamb ** k * float(w @ e_ext)
            u_prev = u

        # restore the original order (scatter the sorted results back to the positions of the end times)
        order = np.argsort(end_times)
        restored = np.empty_like(moments)
        restored[order] = moments
        moments = restored

        if np.isnan(moments).any():
            self._logger.warning(
                "NaN values encountered when computing moments via the matrix-exponential action. "
                f"Epoch: {i_epoch} at time: {epoch.start_time}. "
                "This is likely due to an ill-conditioned rate matrix."
            )

        return moments

    def _sample(
            self,
            n_samples: int,
            rewards: Sequence[Reward] = None,
            record_visits: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Generate samples from the mean reward distribution by simulating trajectories.

        :param n_samples: Number of trajectories to simulate.
        :param rewards: Rewards to sample from. Default is the reward of the distribution.
        :param record_visits: Whether to record which states were visited during the sampling.
        :return: Array of sampled rewards of size (n_samples, len(rewards)), and optionally an array of
            probabilities of visiting each state.
        """
        if rewards is None:
            rewards = [self.reward]

        n_rewards = len(rewards)
        samples = np.zeros((n_samples, n_rewards))
        absorbing = np.array([s.is_absorbing() for s in self.state_space.states])
        alpha = self.state_space.alpha
        states_visited = np.zeros_like(alpha)
        R = np.array([r._get(self.state_space) for r in rewards])

        # iterate over samples
        for i in tqdm(range(n_samples), disable=not Settings.use_pbar):
            mass = np.zeros(n_rewards)
            t = 0
            rate = 0
            state = np.random.choice(len(alpha), p=alpha)
            epochs = self.demography.epochs

            try:
                # find first non-zero rate epoch
                while rate == 0:
                    epoch = next(epochs)
                    self.state_space.update_epoch(epoch)
                    rate = -self.state_space.S[state, state]

                # sample next time step
                dt = np.random.exponential(1 / rate)

                # iterate over transitions
                while True:

                    # iterate over epochs
                    while t + dt >= epoch.end_time:
                        # reward until epoch boundary
                        mass += R[:, state] * (epoch.end_time - t)
                        dt -= (epoch.end_time - t)
                        t = epoch.end_time

                        # advance epoch
                        epoch = next(epochs)
                        self.state_space.update_epoch(epoch)
                        new_rate = -self.state_space.S[state, state]

                        if new_rate == 0:
                            # NOTE(review): the time spent in this zero-rate epoch is not added to ``mass``
                            # (t jumps to the epoch end without a reward update) -- confirm this is intended.
                            t = epoch.end_time
                            continue

                        # rescale remaining time
                        dt *= rate / new_rate
                        rate = new_rate

                    # step completes in current epoch
                    mass += R[:, state] * dt
                    t += dt

                    # sample next state
                    probs = self.state_space.S[state].copy()
                    probs[state] = 0
                    state = np.random.choice(len(probs), p=probs / rate)
                    states_visited[state] += 1

                    if absorbing[state]:
                        raise StopIteration

                    rate = -self.state_space.S[state, state]

                    # if rate is zero, we skip to the next epoch
                    # NOTE(review): ``dt`` is not resampled on this path, so the stale holding time of the previous
                    # state is reused by the epoch loop above -- verify against an exact simulation.
                    if rate == 0:
                        t = epoch.end_time
                        continue

                    # sample next time step
                    dt = np.random.exponential(1 / rate)

            except StopIteration:
                pass

            samples[i] = mass

        # normalize states visited
        states_visited /= n_samples

        return (samples, states_visited) if record_visits else samples

    def plot_accumulation(
            self,
            k: int = 1,
            end_times: Iterable[float] = None,
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = None
    ) -> 'plt.Axes':
        """
        Plot accumulation of moments at different times (centered by default, see ``center``).

        .. note::
            This is different from a CDF, as it shows the accumulation of moments rather than the probability of
            having reached absorption at a certain time.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0 and the
            99th percentile.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :param ax: The axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :return: Axes.
        """
        k = int(k)

        from .visualization import Visualization

        if end_times is None:
            end_times = np.linspace(0, self.tree_height.quantile(0.99), 200)

        if rewards is None:
            rewards = (self.reward,) * k

        if title is None:
            title = f"Moment accumulation ({', '.join(r.__class__.__name__.replace('Reward', '') for r in rewards)})"

        y = self.accumulate(k, end_times, rewards, center, permute)

        # return the axes, as documented and as done by the other plotting methods
        return Visualization.plot(
            ax=ax,
            x=end_times,
            y=y,
            xlabel='t',
            ylabel='moment',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )
[docs] class TreeHeightDistribution(PhaseTypeDistribution, DensityAwareDistribution): """ Phase-type distribution for a piecewise time-homogeneous process that allows the computation of the density function. This is currently only possible with default rewards. """ #: Maximum number of epochs to consider when determining time to almost sure absorption. max_epochs: int = 10000 #: Maximum number of time we double the end time when determining time to almost sure absorption. max_iter: int = 20 #: Probability of almost sure absorption. p_absorption: float = 1 - 1e-15
[docs] def __init__( self, state_space: LineageCountingStateSpace, demography: Demography = None, start_time: float = 0, end_time: float = None ): """ Initialize the distribution. :param state_space: The state space. :param demography: The demography. :param start_time: Time when to start accumulating moments. :param end_time: Time when to end accumulation of moments. By default, the time until almost sure absorption. """ if start_time < 0: raise ValueError("Start time must be greater than or equal to 0.") if end_time is not None and end_time < 0: raise ValueError("End time must be greater than or equal to 0.") if end_time is not None and end_time < start_time: raise ValueError("End time must be greater than equal start time.") super().__init__( state_space=state_space, tree_height=self, demography=demography, reward=TreeHeightReward() ) #: State space self.state_space: LineageCountingStateSpace = state_space #: Start time self.start_time: float = start_time #: End time self.end_time: float | None = end_time
[docs] def cdf(self, t: float | Sequence[float]) -> float | np.ndarray: """ Cumulative distribution function. :param t: Value or values to evaluate the CDF at. :return: Cumulative probability :raises NotImplementedError: if rewards are not default """ # raise error if rewards are not default if not isinstance(self.reward, TreeHeightReward): raise NotImplementedError("PDF not implemented for non-default rewards.") # assume scalar if not array if not isinstance(t, Iterable): return self.cdf(np.array([t]))[0] # check for negative values if np.any(t < 0): raise ValueError("Negative values are not allowed.") # sort array in ascending order but keep track of original indices t_sorted: Collection[float] = np.sort(t).astype(float) epochs = enumerate(self.demography.epochs) i_epoch, epoch = next(epochs) # get the transition matrix for the first epoch self.state_space.update_epoch(epoch) # initialize transition matrix T = np.eye(self.state_space.k) u_prev = 0 # initialize probabilities probs = np.zeros_like(t_sorted) # take reward vector as exit vector e = self.reward._get(self.state_space) # iterate through sorted values for i, u in enumerate(t_sorted): # iterate over epochs between u_prev and u while u > epoch.end_time: self._check_numerical_stability(self.state_space.S, i_epoch) # update transition matrix with remaining time in current epoch T @= expm(self.state_space.S * (epoch.end_time - u_prev)) # fetch and update for next epoch u_prev = epoch.end_time i_epoch, epoch = next(epochs) self.state_space.update_epoch(epoch) self._check_numerical_stability(self.state_space.S, i_epoch) # update transition matrix with remaining time in current epoch T @= expm(self.state_space.S * (u - u_prev)) probs[i] = 1 - self.state_space.alpha @ T @ e u_prev = u # sort probabilities back to original order probs = probs[np.argsort(t)] if np.isnan(probs).any(): self._logger.critical( "NaN values in CDF. This is likely due to an ill-conditioned rate matrix." ) return probs
def _update(
        self,
        u: float,
        u_prev: float,
        T: np.ndarray,
        epoch: 'Epoch'
) -> Tuple[float, np.ndarray, 'Epoch']:
    """
    Propagate the transition matrix forward from time ``u_prev`` to time ``u``.

    Every epoch boundary crossed along the way switches the state space to the following epoch.

    :param u: Time to update to.
    :param u_prev: Previous time.
    :param T: Transition matrix at time ``u_prev``.
    :param epoch: Epoch containing ``u_prev``.
    :return: The new time ``u``, the transition matrix at ``u`` and the epoch reached at ``u``.
    """
    self.state_space.update_epoch(epoch)

    # cross every epoch boundary lying before the target time
    while u > epoch.end_time:
        T = T @ expm(self.state_space.S * (epoch.end_time - u_prev))
        u_prev = epoch.end_time

        # switch to the following epoch
        epoch = self.demography.get_epoch(epoch.end_time)
        self.state_space.update_epoch(epoch)

    # propagate through the remaining time within the final epoch
    T = T @ expm(self.state_space.S * (u - u_prev))

    return u, T, epoch


@cached_property
def _e(self) -> np.ndarray:
    """
    Exit vector: the distribution's reward evaluated on the state space (computed once and cached).
    """
    return self.reward._get(self.state_space)


def _cum(self, T: np.ndarray) -> float:
    """
    Cumulative probability implied by a transition matrix.

    :param T: Transition matrix.
    :return: ``1 - alpha @ T @ e``, where ``alpha`` is the initial state vector and ``e`` the exit vector.
    """
    remaining = self.state_space.alpha @ T @ self._e

    return float(1 - remaining)
@cache
def quantile(
        self,
        q: float,
        expansion_factor: float = 2,
        precision: float = 1e-5,
        max_iter: int = 1000
):
    """
    Find the specified quantile of the CDF using an adaptive bisection method.

    First, the upper bound is expanded geometrically until it contains the quantile. Then the
    interval ``[a, b]`` is bisected until ``F(b) - F(a)`` falls below ``precision``.

    :param q: The desired quantile (between 0 and 1).
    :param expansion_factor: Factor by which to multiply the upper bound that does not yet contain the quantile.
    :param precision: Precision for quantile, i.e. ``F(b) - F(a) < precision``.
    :param max_iter: Maximum number of iterations, shared by the expansion and bisection phases.
    :return: The quantile.
    :raises ValueError: If ``q`` is not in ``[0, 1]`` or ``expansion_factor`` is not greater than 1.
    :raises RuntimeError: If the maximum number of iterations is exhausted before convergence.
    """
    if q < 0 or q > 1:
        raise ValueError("Specified quantile must be between 0 and 1.")

    if expansion_factor <= 1:
        raise ValueError("Expansion factor must be greater than 1.")

    # initialize bounds
    a, b = 0, 1
    T_a = np.eye(self.state_space.k)
    epoch_a, epoch_b = self.demography.get_epoch(0), self.demography.get_epoch(0)
    b, T_b, epoch_b = self._update(b, a, T_a, epoch_b)

    i = 0

    # expand upper bound until it contains the quantile
    while self._cum(T_b) < q and i < max_iter:
        b, T_b, epoch_b = self._update(b * expansion_factor, b, T_b, epoch_b)
        i += 1

    # use bisection method within the determined bounds
    while self._cum(T_b) - self._cum(T_a) > precision and i < max_iter:
        m, T_m, epoch_m = self._update((a + b) / 2, a, T_a, epoch_a)

        if self._cum(T_m) < q:
            a, T_a, epoch_a = m, T_m, epoch_m
        else:
            b, T_b, epoch_b = m, T_m, epoch_m

        i += 1

    # both loops stop at i == max_iter at the latest, so fail only if the
    # iteration budget ran out before the bracket contained and resolved the quantile
    if i >= max_iter and (self._cum(T_b) < q or self._cum(T_b) - self._cum(T_a) > precision):
        raise RuntimeError("Maximum number of iterations reached when determining quantile.")

    return (a + b) / 2
[docs] def pdf(self, t: float | Sequence[float], dx: float = None) -> float | np.ndarray: """ Density function. We use numerical differentiation of the CDF to calculate the density. This provides good results as the CDF is exact and continuous. :param t: Value or values to evaluate the density function at. :param dx: Step size for numerical differentiation. By default, the 99th percentile divided by 1e10. :return: Density """ if dx is None: dx = self.quantile(0.99) / 1e10 if isinstance(t, Iterable): t = np.array(t) # determine (non-negative) evaluation points x1 = np.max([t - dx / 2, np.zeros_like(t)], axis=0) x2 = x1 + dx return (self.cdf(x2) - self.cdf(x1)) / dx
@cached_property
def t_max(self) -> float:
    """
    Time until which computations are performed.

    This is either the end time specified when initializing the distribution or the time until almost sure
    absorption.

    :raises ValueError: If the determined time of almost sure absorption is smaller than the start time.
    """
    if self.end_time is not None:
        return self.end_time

    t_abs = self._get_absorption_time()

    if t_abs < self.start_time:
        raise ValueError(
            f"Determined time of almost sure absorption ({t_abs:.1f}) "
            f"is smaller than start time ({self.start_time:.1f}). "
            "The start time may be too large or the demography not well-defined."
        )

    return t_abs


def _get_absorption_time(self) -> float:
    """
    Get a time estimate for when we have reached absorption almost surely.

    We base this computation on the transition matrix rather than the moments. The transition matrix
    tells us how likely absorption is, so we can warn the user if necessary. Stopping the computation
    when no more rewards are accumulated is not a good idea, as this can happen before almost sure
    absorption (exponential runaway growth, temporary isolation in different demes).

    :return: Time at which the probability of absorption reached ``p_absorption``, or the last time tried
        when the iteration budget ran out.
    :raises ValueError: If population sizes and migration rates of the first epoch differ by more than 1e16.
    """
    i = 0
    T = np.eye(self.state_space.k)
    epoch = self.demography.get_epoch(0)

    # An extreme demography (population sizes or migration rates differing by more than ~double precision)
    # makes the absorption-time matrix exponential numerically unreliable. Scipy's ``expm`` estimates the
    # matrix one-norm with randomised power iteration, so it then becomes intermittently prohibitively slow
    # (appearing to hang). Detect this up front from the demography and fail fast with a clear error. The
    # check uses the demography rather than the rate matrix, because the coalescent model (e.g. multiple-merger
    # models) can also widen the range of the rate matrix.
    # coalescence rates scale as 1 / pop_size, migration enters at its own rate
    scales = [1 / v for v in epoch.pop_sizes.values() if v > 0]
    scales += [v for v in epoch.migration_rates.values() if v > 0]
    ratio = max(scales) / min(scales) if scales else 1

    if ratio > 1e16:
        raise ValueError(
            "The demography is too ill-conditioned to reliably compute the time of almost sure absorption: its "
            f"population sizes and migration rates differ by a factor of {ratio:.1e}. Use less extreme "
            "parameters, or set the end time manually (see ``Coalescent.end_time``)."
        )

    # initial guess: a power of two on the order of the mean population size of the first epoch
    t = 2 ** int(np.log2(np.mean(list(epoch.pop_sizes.values()))))
    expansion_factor = 2

    t, T, epoch = self._update(t, 0, T, epoch)
    p = self._cum(T)

    # multiply time by expansion_factor until we reach p_absorption
    while p < self.p_absorption and i < self.max_iter:
        t, T, epoch = self._update(t * expansion_factor, t, T, epoch)
        p = self._cum(T)

        if np.isnan(p):
            self._logger.critical(
                "Could not reliably find time of almost sure absorption "
                "as probability of absorption is NaN. "
                "This is likely due to an ill-conditioned rate matrix. "
                f"Using time {t:.1f}. "
            )

        i += 1

    # the loop stops at i == max_iter at the latest, so warn whenever the iteration
    # budget was exhausted without reaching the target probability of absorption
    if i >= self.max_iter and p < self.p_absorption:
        self._logger.warning(
            "Could not reliably find time of almost sure absorption after maximum number of iterations. "
            f"Using time {t:.1f} with probability of absorption 1 - {1 - p:.1e}. "
            "This could be due to numerical imprecision, unreachable states or very large or small "
            "absorption times. You can set the end time manually (see `Coalescent.end_time`) or increase "
            "the maximum number of iterations (`TreeHeightDistribution.max_iter`)."
        )

    return t


def _empirical_cdf(self, n_samples: int, reward: Reward = None, t: float | Sequence[float] = None) -> np.ndarray:
    """
    Evaluate an empirical cumulative distribution function (CDF) obtained by sampling from the distribution.

    :param n_samples: Number of samples to generate.
    :param reward: Reward function to use for sampling. If not specified, the default reward of the
        distribution is used.
    :param t: Values at which to evaluate the CDF. Default to 100 evenly spaced values between 0 and the
        99th percentile.
    :return: Empirical CDF evaluated at ``t``, linearly interpolated between the sorted samples.
    """
    if t is None:
        t = np.linspace(0, self.tree_height.quantile(0.99), 100)

    # the reshape guarantees a one-dimensional sample, so the CDF can be interpolated directly
    samples = self._sample(n_samples, reward).reshape(n_samples)

    x = np.sort(samples)
    y = np.arange(1, n_samples + 1) / n_samples

    return np.interp(t, x, y)


def _plot_empirical_cdf(
        self,
        n_samples: int = 1000,
        reward: Reward = None,
        t: float | Sequence[float] = None,
        ax: 'plt.Axes' = None,
        show: bool = True,
        file: str = None,
        clear: bool = True,
        label: str = None,
        title: str = 'Empirical CDF'
) -> 'plt.Axes':
    """
    Plot the empirical cumulative distribution function (CDF).

    :param n_samples: Number of samples to generate.
    :param reward: Reward function to use for sampling. If not specified, the default reward of the
        distribution is used.
    :param t: Values at which to evaluate the CDF. Default to 100 evenly spaced values between 0 and the
        99th percentile.
    :param ax: Axes to plot on.
    :param show: Whether to show the plot.
    :param file: File to save the plot to.
    :param clear: Whether to clear the plot before plotting.
    :param label: Label for the plot.
    :param title: Title of the plot.
    :return: Axes.
    """
    from .visualization import Visualization

    if t is None:
        t = np.linspace(0, self.tree_height.quantile(0.99), 100)

    y = self._empirical_cdf(n_samples, reward, t)

    return Visualization.plot(
        ax=ax,
        x=t,
        y=y,
        xlabel='t',
        ylabel='F(t)',
        label=label,
        file=file,
        show=show,
        clear=clear,
        title=title
    )
class SFSDistribution(PhaseTypeDistribution, ABC):
    """
    Base class for site-frequency spectrum distributions.

    Subclasses define the frequency bins (:meth:`_get_indices`), the reward of each bin
    (:meth:`_get_sfs_reward`) and the enumeration of mutational configurations (:meth:`_get_configs`).
    Bin moments are computed with the distribution's reward combined with the bin's SFS reward.
    """

    def __init__(
            self,
            state_space: BlockCountingStateSpace,
            tree_height: TreeHeightDistribution,
            demography: Demography,
            reward: Reward = None
    ):
        """
        Initialize the distribution.

        :param state_space: Block-counting state space.
        :param tree_height: The tree height distribution.
        :param demography: The demography.
        :param reward: The reward to multiply the SFS reward with. By default, the unit reward is used,
            which has no effect.
        """
        if reward is None:
            reward = UnitReward()

        super().__init__(
            state_space=state_space,
            tree_height=tree_height,
            demography=demography,
            reward=reward
        )

        #: Generated probability mass by iterator returned from :meth:`get_mutation_configs`.
        self.generated_mass = 0

    @abstractmethod
    def _get_sfs_reward(self, i: int) -> SFSReward:
        """
        Get the reward for the ith site-frequency count.

        :param i: The ith site-frequency count.
        :return: The reward.
        """
        pass

    @abstractmethod
    def _get_indices(self) -> np.ndarray:
        """
        Get the indices for the site-frequency spectrum.

        :return: The indices.
        """
        pass

    @staticmethod
    @abstractmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Get all possible mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: An iterator over all possible mutational configurations.
        """
        pass

    @_make_hashable
    @cache
    def moment(
            self,
            k: int,
            rewards: Sequence[SFSReward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> SFS:
        """
        Get the kth moments of the site-frequency spectrum.

        :param k: The order of the moment
        :param rewards: Sequence of k rewards
        :param start_time: Time when to start accumulation of moments. By default, the start time specified
            when initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified
            when initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :return: A site-frequency spectrum of kth order moments.
        """
        if rewards is None:
            rewards = (self.reward,) * k

        # optionally parallelize the moment computation of the SFS bins
        moments = parallelize(
            func=lambda x: self._moment(*x),
            data=[[k, i, rewards, start_time, end_time, center, permute] for i in self._get_indices()],
            desc=f"Calculating {k}-moments",
            pbar=Settings.use_pbar,
            parallelize=Settings.parallelize
        )

        # bin 0 is zero; pad the tail with zeros so that the spectrum has n + 1 entries
        return SFS([0] + list(moments) + [0] * (self.lineage_config.n - len(moments)))

    def _moment(
            self,
            k: int,
            i: int,
            rewards: Sequence[SFSReward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> float:
        """
        Get the kth moment for the ith site-frequency count.

        :param k: The order of the moment
        :param i: The ith site-frequency count
        :param rewards: Sequence of k rewards
        :param start_time: Time when to start accumulation of moments. By default, the start time specified
            when initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified
            when initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :return: The kth SFS (cross)-moment at the ith site-frequency count
        """
        return PhaseTypeDistribution.moment(
            self,
            k=k,
            rewards=tuple([CombinedReward([r, self._get_sfs_reward(i)]) for r in rewards]),
            start_time=start_time,
            end_time=end_time,
            center=center,
            permute=permute
        )

    def accumulate(
            self,
            k: int,
            end_times: Iterable[float],
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray:
        """
        Evaluate the kth moments for the site-frequency spectrum at different end times.

        :param k: The order of the moment.
        :param end_times: Times or time when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :return: Array of moments accumulated at the specified times, one row for each site-frequency count
            (``n + 1`` rows, with zero rows for bins outside :meth:`_get_indices`).
        """
        k = int(k)
        indices = self._get_indices()
        end_times = np.array(list(end_times))

        # forward center and permute so that they take effect for every frequency bin
        accumulation = parallelize(
            func=lambda x: self.get_accumulation(*x),
            data=[[k, i, end_times, rewards, center, permute] for i in indices],
            desc=f"Calculating accumulation of {k}-moments",
            pbar=Settings.use_pbar,
            parallelize=Settings.parallelize
        )

        # pad with zeros
        return np.concatenate([
            np.zeros((1, len(end_times))),
            accumulation,
            np.zeros((self.lineage_config.n - len(indices), len(end_times)))
        ])

    def plot_accumulation(
            self,
            k: int = 1,
            end_times: Iterable[float] = None,
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = None
    ) -> 'plt.Axes':
        """
        Plot accumulation of (non-central) SFS moments at different times.

        .. note::
            This is different from a CDF, as it shows the accumulation of moments rather than the probability
            of having reached absorption at a certain time.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0 and
            the 99th percentile.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :param ax: The axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :return: Axes.
        """
        import matplotlib.pyplot as plt
        from .visualization import Visualization

        k = int(k)

        if ax is None:
            ax = plt.gca()

        if end_times is None:
            end_times = np.linspace(0, self.tree_height.quantile(0.99), 200)

        if rewards is None:
            rewards = (self.reward,) * k

        if title is None:
            title = (f"SFS Moment accumulation "
                     f"({', '.join(r.__class__.__name__.replace('Reward', '') for r in rewards)})")

        # get accumulation of moments
        accumulation = self.accumulate(k, end_times, rewards, center, permute)

        # skip the leading zero row; zip truncates the trailing zero-padding rows
        for i, acc in zip(self._get_indices(), accumulation[1: -1]):
            Visualization.plot(
                ax=ax,
                x=end_times,
                y=acc,
                xlabel='t',
                ylabel='moment',
                label=f'{i}',
                file=file,
                show=i == self._get_indices()[-1] and show,
                clear=clear,
                title=title
            )

        return ax

    def get_accumulation(
            self,
            k: int,
            i: int,
            end_times: Iterable[float] | float,
            rewards: Sequence[SFSReward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray | float:
        """
        Get accumulation of moments for the ith site-frequency count.

        :param k: The order of the moment
        :param i: The ith site-frequency count.
        :param end_times: Times or time when to evaluate the moment.
        :param rewards: Sequence of k rewards.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :return: The kth SFS (cross)-moment accumulations at the ith site-frequency count
        """
        if rewards is None:
            rewards = [self.reward] * k

        return super().accumulate(
            k=k,
            end_times=end_times,
            rewards=tuple([CombinedReward([r, self._get_sfs_reward(i)]) for r in rewards]),
            center=center,
            permute=permute
        )

    @cached_property
    def cov(self) -> SFS2:
        """
        Covariance matrix across site-frequency counts.
        """
        # create list of arguments for each combination of i, j
        indices = [(i, j) for i in self._get_indices() for j in self._get_indices()]

        # get non-central, order-dependent cross-moments E[X_i X_j] using parallelized function
        sfs_results = parallelize(
            func=lambda x: (
                PhaseTypeDistribution.moment(self, k=2, permute=False, center=False, rewards=(
                    CombinedReward([self.reward, self._get_sfs_reward(x[0])]),
                    CombinedReward([self.reward, self._get_sfs_reward(x[1])])
                ))
            ),
            data=indices,
            desc="Calculating covariance",
            pbar=Settings.use_pbar,
            parallelize=Settings.parallelize
        )

        # re-structure the results to a matrix form
        sfs = np.zeros((self.lineage_config.n + 1, self.lineage_config.n + 1))
        for ((i, j), result) in zip(indices, sfs_results):
            sfs[i, j] = result

        # get matrix of marginal moments
        m2 = np.outer(self.mean.data, self.mean.data)

        # symmetrize the order-dependent cross-moments and subtract the product of the means
        cov = (sfs + sfs.T) / 2 - m2

        return SFS2(cov)

    def get_cov(self, i: int, j: int) -> float:
        """
        Get the covariance between the ith and jth site-frequency.

        :param i: The ith frequency count
        :param j: The jth frequency count
        :return: covariance (0 if either index is a monomorphic bin, i.e. 0 or n)
        """
        if i in (0, self.lineage_config.n) or j in (0, self.lineage_config.n):
            return 0

        return super().moment(
            k=2,
            rewards=(
                CombinedReward([self.reward, self._get_sfs_reward(i)]),
                CombinedReward([self.reward, self._get_sfs_reward(j)])
            ),
            center=True
        )

    @cached_property
    def corr(self) -> SFS2:
        """
        Correlation matrix across site-frequency counts.
        """
        # get standard deviations
        std = np.sqrt(self.var.data)

        sfs = SFS2(self.cov.data / np.outer(std, std))

        # replace NaNs (from zero-variance bins) with zeros
        sfs.data[np.isnan(sfs.data)] = 0

        return sfs

    def get_corr(self, i: int, j: int) -> float:
        """
        Get the correlation coefficient between the ith and jth site-frequency.

        :param i: The ith frequency count
        :param j: The jth frequency count
        :return: Correlation coefficient (0 if either index is a monomorphic bin, i.e. 0 or n)
        """
        if i in (0, self.lineage_config.n) or j in (0, self.lineage_config.n):
            return 0

        return self.get_cov(i, j) / (np.sqrt(self.get_cov(i, i)) * np.sqrt(self.get_cov(j, j)))

    @cache
    def _get_P(self, n: int, theta: float) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get transition matrix for mutational configuration probabilities.

        :param n: The number of frequency bins.
        :param theta: The mutation rate.
        :return: Per-bin transition matrices (stacked along the first axis) and exit vector.
        """
        # restrict everything to the non-absorbing states
        non_absorbing = TreeHeightReward()._get(self.state_space).astype(bool)
        e = self.state_space.e[non_absorbing]

        # SFS rewards of bins 1..n (one row per bin) and their total per state
        R = np.array([self._get_sfs_reward(i)._get(self.state_space) for i in range(1, n + 1)])[:, non_absorbing]
        r_total = R.T @ np.ones(n)

        S = self.state_space.S[non_absorbing, :][:, non_absorbing]
        I = np.eye(S.shape[0])

        # P_total = (I - diag(1 / r_total) S / theta)^-1 and p_total = (I - P_total) e
        P_total = np.linalg.inv(I - np.diag(1 / r_total) / theta @ S)
        p_total = (I - P_total) @ e

        # split P_total into one matrix per bin, weighted by the bin's share of the total reward
        P = np.array([P_total @ np.diag(R[i] / r_total) for i in range(n)])

        return P, p_total

    def get_mutation_config(self, config: Sequence[int], theta: float) -> float:
        """
        Get the probabilities of observing the given mutational configurations according to the infinite sites
        model.

        .. note::
            This currently only works for a single epoch, i.e. a time-homogeneous demography, and recombination
            is not supported.

        :param config: The mutational configuration. A sequence of integers of length n - 1 for unfolded
            configurations and n // 2 for folded configurations, where n is the number of lineages. Each element
            in the sequence is an integer representing the number of mutations at each frequency count starting
            from 1. For example, the unfolded configuration [2, 1, 0] represents two singleton, one doubleton
            and zero tripleton mutations for a sample size of 4 lineages. Similarly, the folded configuration
            [2, 1] represents two singleton or tripleton and one doubleton mutation for the same number of
            lineages.
        :param theta: The mutation rate.
        :return: The probability of observing the given mutational configuration.
        :raises NotImplementedError: If the demography has more than one epoch.
        :raises ValueError: If ``theta`` is negative or the configuration has the wrong length.
        """
        # raise not implemented error if more than one epoch
        if self.demography.has_n_epochs(2):
            raise NotImplementedError("Sampling not implemented for more than one epoch.")

        # make sure theta is non-negative
        if theta < 0:
            raise ValueError("Theta must be greater than or equal to 0.")

        # number of frequency bins
        n = len(self._get_configs(self.lineage_config.n, 0)[0])

        if len(config) != n:
            raise ValueError(
                "The length of the configuration must be equal to the number of frequency bins. "
                f"Expected {n}, got {len(config)}."
            )

        # explicitly convert to tuple of integers
        config = tuple(int(c) for c in config)

        # handle special case when theta = 0
        if theta == 0:
            if sum(config) == 0:
                return 1

            return 0

        # get non-absorbing states
        non_absorbing = TreeHeightReward()._get(self.state_space).astype(bool)

        # number of non-absorbing states
        k = non_absorbing.sum()

        alpha = self.state_space.alpha[non_absorbing]
        P, p_total = self._get_P(n, theta)

        # multiset of bin labels, one entry per mutation (bin i repeated config[i - 1] times)
        q = list(itertools.chain(*[[i + 1] * j for i, j in enumerate(config)]))

        # sum the matrix products over all distinct orderings of the mutations
        Q = np.zeros((k, k))
        for perm in multiset_permutations(q):
            U = np.eye(k)
            for i in perm:
                # explicit product instead of in-place ``@=``, which requires NumPy >= 1.25
                U = U @ P[i - 1]
            Q += U

        p = alpha @ Q @ p_total

        return p

    def get_mutation_configs(self, theta: float) -> Iterator[Tuple[Tuple[float, ...], float]]:
        """
        An iterator over the probabilities of observing mutational configurations according to the infinite
        sites model. The order of the mutational configurations generated ascends in the number of mutations
        observed. See :meth:`get_mutation_config` for more information on mutational configurations.

        .. note::
            This currently only works for a single epoch, i.e. a time-homogeneous demography, and recombination
            is not supported. Also note that the number of configurations is infinite, so this iterator will
            never stop. However, depending on the mutation rate, the probability of observing configurations of
            higher mutation counts will decrease over time. You can keep track of the generated probability mass
            by checking the :attr:`~.generated_mass` attribute, which is reset every time this method is called.
            A good approach is thus to keep generating configurations until the generated mass is above a
            certain threshold. More complex demographic models and larger sample sizes increase the number of
            configurations, and higher mutation rates increase the number of generated configurations necessary
            to reach a certain mass.

        Code example:

        ::

            coal = pg.Coalescent(n=5)

            it = coal.sfs.get_mutation_configs(theta=1)

            # continue until generated mass is above 0.8
            samples = list(pg.takewhile_inclusive(lambda _: coal.sfs.generated_mass < 0.8, it))

        :param theta: The mutation rate.
        :return: An iterator over the probabilities of observing mutational configurations.
        """
        # reset generated mass
        self.generated_mass = 0

        # iterate over number of mutations
        i = 0
        while True:

            # iterate over configurations
            for config in self._get_configs(self.lineage_config.n, i):
                p = self.get_mutation_config(config=config, theta=theta)
                self.generated_mass += p

                yield config, p

            # increase counter for number of mutations
            i += 1


class TajimaSFSMixin:
    """
    Mixin providing the branch-length diversity estimators and Tajima's :math:`D` from the site-frequency
    spectrum mean and covariance. Shared by the analytical :class:`UnfoldedSFSDistribution` and the
    simulation-based empirical SFS distribution, so the same statistics can be computed from either source.
    Subclasses supply :meth:`_tajima_n`, :meth:`_tajima_mean` and :meth:`_tajima_cov`.
    """

    def _tajima_n(self) -> int:
        """Number of lineages."""
        raise NotImplementedError

    def _tajima_mean(self) -> np.ndarray:
        """Mean branch length per polymorphic SFS bin (``i = 1 .. n-1``)."""
        raise NotImplementedError

    def _tajima_cov(self) -> np.ndarray:
        """Covariance of the polymorphic SFS bins (``i, j = 1 .. n-1``)."""
        raise NotImplementedError

    @cached_property
    def _tajima_weights(self) -> Tuple[np.ndarray, np.ndarray]:
        """Per-bin weights for the two diversity estimators: pairwise diversity ``pi`` and Watterson's ``theta_W``."""
        n = self._tajima_n()
        i = np.arange(1, n)
        w_pi = 2 * i * (n - i) / (n * (n - 1))
        w_w = np.full(n - 1, 1 / np.sum(1 / i))

        return w_pi, w_w

    @cached_property
    def theta_pi(self) -> float:
        r"""
        Mean pairwise diversity :math:`\pi = \sum_i \frac{2 i (n - i)}{n (n - 1)} \mathbb{E}[L_i]`, the
        branch-length estimator of :math:`\theta` based on the expected number of pairwise differences.
        """
        w_pi, _ = self._tajima_weights

        return float(w_pi @ self._tajima_mean())

    @cached_property
    def theta_w(self) -> float:
        r"""
        Watterson's estimator :math:`\theta_W = L_\text{total} / a_n` with :math:`a_n = \sum_{k=1}^{n-1} 1/k`,
        the branch-length estimator of :math:`\theta` based on the total branch length.
        """
        _, w_w = self._tajima_weights

        return float(w_w @ self._tajima_mean())

    @cached_property
    def tajimas_d(self) -> float:
        r"""
        Tajima's :math:`D` in branch form:
        :math:`D = (\pi - \theta_W) / \sqrt{c^\top \, \mathrm{Cov}[L] \, c}` with weights
        :math:`c_i = \frac{2 i (n - i)}{n (n - 1)} - 1/a_n`.

        It is ``0`` under the standard neutral constant-size model, negative under population growth (excess
        of low-frequency variants) and positive under contraction. The normalization uses the branch-length
        covariance rather than the mutation-based variance of the classical sample estimator. Returns ``0``
        when the variance term is not positive.
        """
        w_pi, w_w = self._tajima_weights
        c = w_pi - w_w

        num = c @ self._tajima_mean()
        var = c @ self._tajima_cov() @ c

        if var <= 0:
            return 0.0

        return float(num / np.sqrt(var))
class UnfoldedSFSDistribution(SFSDistribution, TajimaSFSMixin):
    """
    Unfolded site-frequency spectrum distribution.

    The polymorphic frequency bins are ``1, ..., n - 1`` for a sample of ``n`` lineages. The bin means
    and covariances also feed the diversity estimators and Tajima's D of :class:`TajimaSFSMixin`.
    """

    @staticmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Enumerate the mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: All mutational configurations, each with ``n - 1`` entries (one per unfolded bin).
        """
        return StateSpace._get_partitions(n=k, k=n - 1)

    def _get_indices(self) -> np.ndarray:
        """
        Indices of the unfolded site-frequency bins.

        :return: The integers ``1, ..., n - 1``.
        """
        return np.arange(1, self.lineage_config.n)

    def _get_sfs_reward(self, i: int) -> UnfoldedSFSReward:
        """
        Reward of the ith unfolded site-frequency count.

        :param i: The ith site-frequency count.
        :return: The reward.
        """
        return UnfoldedSFSReward(i)

    def _tajima_n(self) -> int:
        """Number of lineages."""
        return self.lineage_config.n

    def _tajima_mean(self) -> np.ndarray:
        """Mean of the polymorphic SFS bins ``1 .. n-1``."""
        polymorphic = slice(1, self.lineage_config.n)
        return np.asarray(self.mean.data)[polymorphic]

    def _tajima_cov(self) -> np.ndarray:
        """Covariance of the polymorphic SFS bins ``1 .. n-1``."""
        polymorphic = slice(1, self.lineage_config.n)
        return np.asarray(self.cov.data)[polymorphic, polymorphic]
class FoldedSFSDistribution(SFSDistribution):
    """
    Folded site-frequency spectrum distribution.

    The frequency bins are ``1, ..., n // 2`` for a sample of ``n`` lineages.
    """

    @staticmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Enumerate the mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: All mutational configurations, each with ``n // 2`` entries (one per folded bin).
        """
        return StateSpace._get_partitions(n=k, k=n // 2)

    def _get_indices(self) -> np.ndarray:
        """
        Indices of the folded site-frequency bins.

        :return: The integers ``1, ..., n // 2``.
        """
        return np.arange(1, self.lineage_config.n // 2 + 1)

    def _get_sfs_reward(self, i: int) -> FoldedSFSReward:
        """
        Reward of the ith folded site-frequency count.

        :param i: The ith site-frequency count.
        :return: The reward.
        """
        return FoldedSFSReward(i)

    def _unfold(self, config: Sequence[int]) -> Set[Tuple[int, ...]]:
        """
        Expand a folded configuration into every unfolded configuration that folds onto it.

        :param config: The folded configuration. A sequence of integers of length n // 2 where n is the
            number of lineages.
        :return: The unfolded configurations, each of length ``n - 1``.
        :raises ValueError: If ``config`` does not have length ``n // 2``.
        """
        n = self.lineage_config.n

        if len(config) != n // 2:
            raise ValueError("The length of the configuration must equal n // 2 where n is the number of lineages.")

        is_odd = n % 2 == 1

        # number of folded bins whose count is split between bin i and its mirror bin n - i;
        # for even n the central bin n / 2 has no mirror and is carried over unchanged
        n_mirrored = len(config) if is_odd else len(config) - 1

        lower_counts = [range(c + 1) for c in config[:n_mirrored]]
        if not is_odd:
            lower_counts.append([config[-1]])

        unfolded = set()
        for lower in itertools.product(*lower_counts):
            # the counts not assigned to the lower half go to the mirrored upper half
            higher = (np.array(config) - np.array(lower))[:n_mirrored][::-1]
            unfolded.add(tuple(list(lower) + list(higher)))

        return unfolded
[docs] class JointSFSDistribution(PhaseTypeDistribution): """ Joint (multi-population) site-frequency spectrum distribution. Moments are returned as a multi-dimensional array of shape ``(n_0 + 1, ..., n_{P-1} + 1)``, where ``n_p`` is the sample size of population ``p``. The entry at index ``(k_0, ..., k_{P-1})`` is the moment for branches subtending exactly ``k_p`` samples from population ``p``. The monomorphic bins (the all-zero and the full ``(n_0,...,n_{P-1})`` configuration) are zero by convention. """
def __init__(
        self,
        state_space: JointBlockCountingStateSpace,
        tree_height: 'TreeHeightDistribution',
        demography: Demography,
        reward: Reward = None
):
    """
    Set up the joint site-frequency spectrum distribution.

    :param state_space: Joint block-counting state space.
    :param tree_height: The tree height distribution.
    :param demography: The demography.
    :param reward: Reward the joint SFS reward is multiplied with. Defaults to the unit reward, which has
        no effect.
    """
    # the unit reward leaves the joint SFS rewards unchanged
    effective_reward = UnitReward() if reward is None else reward

    super().__init__(
        state_space=state_space,
        tree_height=tree_height,
        demography=demography,
        reward=effective_reward
    )
@cached_property def shape(self) -> Tuple[int, ...]: """ Shape of the joint SFS array, ``(n_0 + 1, ..., n_{P-1} + 1)``. """ return tuple(int(n_p) + 1 for n_p in self.lineage_config.lineages) def _get_configs(self) -> List[Tuple[int, ...]]: """ Get the descendant vectors corresponding to (polymorphic) joint SFS bins, i.e. all block configurations except the full-sample configuration (which corresponds to the monomorphic, fixed sites). :return: List of descendant vectors. """ full = tuple(int(n_p) for n_p in self.lineage_config.lineages) return [c for c in self.state_space.block_configs if c != full]
[docs] def moment( self, k: int, start_time: float = None, end_time: float = None, center: bool = True, permute: bool = True ) -> np.ndarray: """ Get the kth moments of the joint site-frequency spectrum. :param k: The order of the moment. :param start_time: Time when to start accumulation of moments. By default, the start time specified when initializing the distribution. :param end_time: Time when to end accumulation of moments. By default, either the end time specified when initializing the distribution or the time until almost sure absorption. :param center: Whether to center the moment around the mean. :param permute: For cross-moments, whether to average over all permutations of rewards. :return: An array of shape :attr:`shape` holding the kth moment of each joint SFS bin. """ # like the base distribution, a moment is the accumulation over the [start_time, end_time] window if start_time is None: start_time = self.tree_height.start_time if end_time is None: end_time = self.tree_height.t_max if start_time > 0: acc = self.accumulate(k, [start_time, end_time], center=center, permute=permute) out = acc[..., 1] - acc[..., 0] else: out = self.accumulate(k, [end_time], center=center, permute=permute)[..., 0] if np.isnan(out).any(): raise ValueError( "NaN value encountered when computing moment. " "This is likely due to an ill-conditioned rate matrix." ) return JointSFS(out)
[docs] def accumulate( self, k: int, end_times: Iterable[float], center: bool = True, permute: bool = True ) -> np.ndarray: """ Evaluate the kth moments of the joint site-frequency spectrum at different end times. :param k: The order of the moment. :param end_times: Times when to evaluate the moments. :param center: Whether to center the moment around the mean. :param permute: For cross-moments, whether to average over all permutations of rewards. :return: Array of shape :attr:`shape` ``+ (len(end_times),)`` with each bin's kth moment over time. """ k = int(k) configs = self._get_configs() end_times = np.array(list(end_times)) accumulation = parallelize( func=lambda config: PhaseTypeDistribution.accumulate( self, k=k, end_times=end_times, rewards=tuple(CombinedReward([self.reward, JointSFSReward(config)]) for _ in range(k)), center=center, permute=permute ), data=configs, desc=f"Calculating accumulation of {k}-moments", pbar=Settings.use_pbar, parallelize=Settings.parallelize ) out = np.zeros(self.shape + (len(end_times),)) for config, acc in zip(configs, accumulation): out[config] = acc return out
[docs] def plot_accumulation( self, k: int = 1, end_times: Iterable[float] = None, center: bool = True, permute: bool = True, ax: 'plt.Axes' = None, show: bool = True, file: str = None, clear: bool = True, title: str = None ) -> 'plt.Axes': """ Plot accumulation of joint SFS moments over time, one curve per (polymorphic) bin. :param k: The order of the moment. :param end_times: Times when to evaluate the moment. Defaults to 200 points up to the 99th percentile. :param center: Whether to center the moment around the mean. :param permute: For cross-moments, whether to average over all permutations of rewards. :param ax: The axes to plot on. :param show: Whether to show the plot. :param file: File to save the plot to. :param clear: Whether to clear the plot before plotting. :param title: Title of the plot. :return: Axes. """ import matplotlib.pyplot as plt from .visualization import Visualization k = int(k) if ax is None: ax = plt.gca() if end_times is None: end_times = np.linspace(0, self.tree_height.quantile(0.99), 200) end_times = np.asarray(list(end_times)) if title is None: title = f"Joint SFS moment accumulation (order {k})" configs = self._get_configs() accumulation = self.accumulate(k, end_times, center=center, permute=permute) for i, config in enumerate(configs): Visualization.plot( ax=ax, x=end_times, y=accumulation[config], xlabel='t', ylabel='moment', label=str(config), file=file, show=(i == len(configs) - 1) and show, clear=clear, title=title ) return ax
@cached_property def mean(self) -> JointSFS: """ Mean of the joint site-frequency spectrum, i.e. the expected branch length subtending each descendant configuration. """ return self.moment(k=1) @cached_property def var(self) -> JointSFS: """ Variance of the joint site-frequency spectrum. """ return self.moment(k=2, center=True)
[docs] def get_cov(self, config_a: Tuple[int, ...], config_b: Tuple[int, ...]) -> float: """ Get the covariance between the branch lengths subtending two descendant configurations. :param config_a: First descendant configuration. :param config_b: Second descendant configuration. :return: The covariance. """ return PhaseTypeDistribution.moment( self, k=2, center=True, rewards=tuple(CombinedReward([self.reward, JointSFSReward(c)]) for c in (config_a, config_b)) )
@cached_property def cov(self) -> np.ndarray: """ Covariance between the branch lengths of all pairs of (polymorphic) joint SFS bins. Returned as an array of shape :attr:`shape` ``+`` :attr:`shape`, where ``cov[a_0, ..., a_{P-1}, b_0, ..., b_{P-1}]`` is the covariance between bins ``(a_0, ..., a_{P-1})`` and ``(b_0, ..., b_{P-1})``. """ configs = self._get_configs() pairs = [(a, b) for a in configs for b in configs] results = parallelize( func=lambda ab: self.get_cov(*ab), data=pairs, desc="Calculating covariance", pbar=Settings.use_pbar, parallelize=Settings.parallelize ) out = np.zeros(self.shape + self.shape) for (a, b), result in zip(pairs, results): out[tuple(a) + tuple(b)] = result return out
class TwoLocusSFSDistribution(PhaseTypeDistribution):
    """
    Two-locus site-frequency spectrum under recombination.

    Entry ``(i, j)`` of the (symmetrized) mean is ``E[L^0_i · L^1_j]`` — the expected product of the
    branch length subtending ``i`` samples at locus 0 and ``j`` samples at locus 1 — computed as a second
    cross-moment of two per-locus SFS rewards on the two-locus block-counting state space. It reduces to
    ``Coalescent.sfs.cov`` (plus the outer product of the marginal means) as ``r → 0`` and to the outer
    product of the marginal SFS as ``r → ∞``.
    """

    def __init__(
            self,
            state_space: TwoLocusBlockCountingStateSpace,
            tree_height: 'TreeHeightDistribution',
            demography: Demography,
            reward: Reward = None
    ):
        """
        Initialize the distribution.

        :param state_space: Two-locus block-counting state space.
        :param tree_height: The (two-locus) tree height distribution, whose absorption time is when both
            loci have reached their MRCA.
        :param demography: The demography.
        :param reward: Optional reward multiplied into each per-locus SFS reward. Defaults to the unit reward.
        """
        super().__init__(
            state_space=state_space,
            tree_height=tree_height,
            demography=demography,
            reward=UnitReward() if reward is None else reward
        )

    @cached_property
    def shape(self) -> Tuple[int, ...]:
        """
        Shape of the two-locus SFS array, ``(n + 1, n + 1)`` (one axis per locus).
        """
        size = int(self.lineage_config.n) + 1

        return size, size

    def _get_indices(self) -> List[int]:
        """
        Polymorphic SFS bins ``1, ..., n - 1``; the monomorphic bins ``0`` and ``n`` are left out.
        """
        return list(range(1, self.lineage_config.n))

    @cached_property
    def mean(self) -> TwoLocusSFS:
        """
        Mean two-locus SFS, ``E[L^0_i · L^1_j]`` for all polymorphic bins, symmetrized over the two loci.
        """
        n = self.lineage_config.n
        bins = self._get_indices()

        # (i, j) with i varying slowest: bin i at locus 0, bin j at locus 1
        bin_pairs = list(itertools.product(bins, bins))

        moments = parallelize(
            func=lambda ij: PhaseTypeDistribution.moment(
                self,
                k=2,
                permute=False,
                center=False,
                rewards=(
                    CombinedReward([self.reward, TwoLocusSFSReward(0, ij[0])]),
                    CombinedReward([self.reward, TwoLocusSFSReward(1, ij[1])])
                )
            ),
            data=bin_pairs,
            desc="Calculating two-locus SFS",
            pbar=Settings.use_pbar,
            parallelize=Settings.parallelize
        )

        table = np.zeros((n + 1, n + 1))

        for (i, j), value in zip(bin_pairs, moments):
            table[i, j] = value

        # the two loci are exchangeable, so average the table with its transpose
        return TwoLocusSFS(0.5 * (table + table.T))
class EmpiricalJointSFSDistribution:  # pragma: no cover
    """
    Empirical (msprime-based) joint site-frequency spectrum, exposing the same ``mean``/``var``/``m2``/``m3``
    interface as :class:`JointSFSDistribution` so that the two can be compared by
    :class:`~phasegen.comparison.Comparison`. The moments are pre-computed arrays (so the object can be
    serialized as cached ground truth).
    """

    def __init__(self, moments: np.ndarray):
        """
        Initialize the distribution.

        :param moments: Per-configuration (non-central) moments of orders ``1, 2, ...``, stacked along the
            first axis, i.e. an array of shape ``(max_order, n_0 + 1, ..., n_{P-1} + 1)``.
        """
        #: Non-central moments per descendant configuration, indexed by order minus one.
        self._moments: np.ndarray = np.asarray(moments)

    @property
    def mean(self) -> JointSFS:
        """
        Mean of the joint site-frequency spectrum.
        """
        return JointSFS(self._moments[0])

    @property
    def m2(self) -> JointSFS:
        """
        Second (non-central) moment of the joint site-frequency spectrum.
        """
        return JointSFS(self._moments[1])

    @property
    def m3(self) -> JointSFS:
        """
        Third (non-central) moment of the joint site-frequency spectrum.
        """
        return JointSFS(self._moments[2])

    @property
    def var(self) -> JointSFS:
        """
        Variance of the joint site-frequency spectrum, ``m2 - mean ** 2``.
        """
        return JointSFS(self._moments[1] - self._moments[0] ** 2)

    @property
    def data(self) -> np.ndarray:
        """
        The mean joint site-frequency spectrum array.
        """
        return self._moments[0]


class EmpiricalDistribution(DensityAwareDistribution):  # pragma: no cover
    """
    Probability distribution based on realisations.
    """

    def __init__(self, samples: np.ndarray | list):
        """
        Create object.

        :param samples: Array of samples, either 1-D (one statistic) or 2-D (realisations x statistics).
        """
        super().__init__()

        self._cache = None

        #: Samples
        self.samples = np.array(samples, dtype=float)

    def touch(self, t: np.ndarray):
        """
        Touch all cached properties and cache the CDF and PDF at the given times.

        :param t: Times to cache properties for.
        """
        super().touch()

        self._cache = dict(
            t=t,
            cdf=self.cdf(t),
            pdf=self.pdf(t)
        )

    def drop(self):
        """
        Drop simulated samples.
        """
        self.samples = None

    @cached_property
    def mean(self) -> float | np.ndarray:
        """
        First moment / mean.
        """
        return np.mean(self.samples, axis=0)

    @cached_property
    def var(self) -> float | np.ndarray:
        """
        Second central moment / variance.
        """
        return np.var(self.samples, axis=0)

    @cached_property
    def m2(self) -> float | np.ndarray:
        """
        Second non-central moment.
        """
        return np.mean(self.samples ** 2, axis=0)

    @cached_property
    def m3(self) -> float | np.ndarray:
        """
        Third non-central moment.
        """
        return np.mean(self.samples ** 3, axis=0)

    @cached_property
    def m4(self) -> float | np.ndarray:
        """
        Fourth non-central moment.
        """
        return np.mean(self.samples ** 4, axis=0)

    @cached_property
    def cov(self) -> float | np.ndarray:
        """
        Covariance matrix (NaNs replaced by zero).
        """
        return np.nan_to_num(np.cov(self.samples, rowvar=False))

    @cached_property
    def corr(self) -> float | np.ndarray:
        """
        Correlation matrix (NaNs replaced by zero).
        """
        return np.nan_to_num(np.corrcoef(self.samples, rowvar=False))

    def moment(self, k: int) -> float | np.ndarray:
        """
        Get the kth (non-central) moment.

        :param k: The order of the moment
        :return: The kth moment
        """
        return np.mean(self.samples ** k, axis=0)

    def cdf(self, t: float | Sequence[float]) -> float | np.ndarray:
        """
        Empirical cumulative distribution function, linearly interpolated between sorted samples.

        :param t: Time.
        :return: Cumulative probability. For 2-D samples, one row per statistic (column of the samples).
        :raises ValueError: If the samples are neither 1- nor 2-dimensional.
        """
        # Sort each statistic (column) over the realisations. np.sort defaults to the last axis, which for
        # 2-D samples would sort across statistics within a single realisation.
        x = np.sort(self.samples, axis=0)
        y = np.arange(1, len(self.samples) + 1) / len(self.samples)

        if x.ndim == 1:
            return np.interp(t, x, y)

        if x.ndim == 2:
            return np.array([np.interp(t, x_, y) for x_ in x.T])

        raise ValueError("Samples must be 1 or 2 dimensional.")

    def quantile(self, q: float) -> float:
        """
        Get the qth quantile.

        :param q: Quantile.
        :return: Quantile.
        """
        return np.quantile(self.samples, q=q)

    def pdf(
            self,
            t: float | np.ndarray,
            n_bins: int = 10000,
            sigma: float = None,
            samples: np.ndarray = None,
            **kwargs: dict
    ) -> float | np.ndarray:
        """
        Histogram-based density function.

        :param t: Time(s). A scalar returns a scalar, an array returns an array.
        :param n_bins: Number of histogram bins over ``[0, max(samples)]``.
        :param sigma: Sigma for an optional Gaussian filter applied along ``t``.
        :param samples: Samples. Defaults to :attr:`samples`.
        :return: Density. Zero for times outside the histogram range.
        """
        samples = self.samples if samples is None else samples

        if samples.ndim == 2:
            return np.array([self.pdf(t, n_bins=n_bins, sigma=sigma, samples=s) for s in samples.T])

        # accept scalar times as well as arrays
        t_arr = np.atleast_1d(np.asarray(t, dtype=float))

        hist, bin_edges = np.histogram(samples, range=(0, np.max(samples)), bins=n_bins, density=True)

        # index of the bin containing each time; the right-most edge belongs to the last bin
        bins = np.minimum(np.sum(bin_edges <= t_arr[:, None], axis=1) - 1, n_bins - 1)

        # times outside the histogram range have zero empirical density (previously they picked up
        # the last bin, via clamping or negative indexing)
        inside = (t_arr >= bin_edges[0]) & (t_arr <= bin_edges[-1])
        y = np.where(inside, hist[np.clip(bins, 0, n_bins - 1)], 0.0)

        # smooth using gaussian filter
        if sigma is not None:
            y = gaussian_filter1d(y, sigma=sigma)

        return y if np.ndim(t) > 0 else y[0]


class EmpiricalSFSDistribution(EmpiricalDistribution):  # pragma: no cover
    """
    SFS probability distribution based on realisations.
    """

    def __init__(self, samples: np.ndarray | list):
        """
        Create object.

        :param samples: 2-D array of samples.
        """
        super().__init__(samples)

    @cached_property
    def mean(self) -> SFS:
        """
        First moment / mean.
        """
        return SFS(super().mean)

    @cached_property
    def var(self) -> SFS:
        """
        Second central moment / variance.
        """
        return SFS(super().var)

    @cached_property
    def m2(self) -> SFS:
        """
        Second non-central moment.
        """
        return SFS(super().m2)

    @cached_property
    def cov(self) -> SFS2:
        """
        Covariance matrix.
        """
        return SFS2(np.nan_to_num(np.cov(self.samples, rowvar=False)))

    @cached_property
    def corr(self) -> SFS2:
        """
        Correlation matrix.
        """
        return SFS2(np.nan_to_num(np.corrcoef(self.samples, rowvar=False)))


class DictContainer(dict):  # pragma: no cover
    """
    Dictionary container (a ``dict`` that also accepts extra attributes such as ``cov``/``corr``).
    """
    pass


class EmpiricalPhaseTypeDistribution(EmpiricalDistribution):  # pragma: no cover
    """
    Phase-type distribution based on realisations.
    """

    def __init__(
            self,
            samples: np.ndarray | list,
            pops: List[str],
            locus_agg: Callable = lambda x: x.sum(axis=0)
    ):
        """
        Create object.

        :param samples: 3-D array of samples, indexed ``[locus, deme, realisation]`` as used by
            :attr:`demes` and :attr:`loci`.
        :param pops: List of population names.
        :param locus_agg: Aggregation function for loci.
        """
        # aggregated over loci -> indexed by deme; summed over demes (axis 1) -> indexed by locus
        over_loci = locus_agg(samples).astype(float)
        over_demes = samples.sum(axis=1).astype(float)

        super().__init__(over_loci.sum(axis=0))

        #: Population names
        self.pops = pops

        #: Samples by deme and locus
        self._samples = samples

        #: Covariance matrix for the demes
        self.pops_cov: np.ndarray = np.cov(over_loci)

        #: Correlation matrix for the demes
        self.pops_corr: np.ndarray = np.corrcoef(over_loci)

        #: Correlation matrix for the loci
        self.loci_corr: np.ndarray = np.corrcoef(over_demes)

        #: Covariance matrix for the loci
        self.loci_cov: np.ndarray = np.cov(over_demes)

    def touch(self, t: np.ndarray):
        """
        Touch all cached properties, including those of the per-deme and per-locus distributions.

        :param t: Times to cache properties for.
        """
        super().touch(t)

        for deme in self.demes.values():
            deme.touch(t)

        for locus in self.loci.values():
            locus.touch(t)

    def drop(self):
        """
        Drop simulated samples, including those held by already-built per-deme and per-locus distributions.
        """
        super().drop()

        # Only drop sub-distributions that were already built (cached_property stores them in __dict__).
        # Building them here would read ``self._samples``, which is cleared below.
        for name in ('demes', 'loci'):
            for dist in self.__dict__.get(name, {}).values():
                dist.drop()

        self._samples = None

    @cached_property
    def demes(self) -> Dict[str, EmpiricalDistribution]:
        """
        Get the distribution for each deme.

        :return: Dictionary of distributions.
        """
        demes = DictContainer(
            {pop: EmpiricalDistribution(self._samples.sum(axis=0)[i]) for i, pop in enumerate(self.pops)}
        )

        # TODO this is the covariance in the tree height but phasegen
        #  provides the covariance in the number of lineages per deme
        demes.cov = self.pops_cov
        demes.corr = self.pops_corr

        return demes

    @cached_property
    def loci(self) -> Dict[int, EmpiricalDistribution]:
        """
        Get the distribution for each locus.

        :return: Dictionary of distributions.
        """
        loci = DictContainer(
            {i: EmpiricalDistribution(self._samples[i].sum(axis=0)) for i in range(self._samples.shape[0])}
        )

        loci.cov = self.loci_cov
        loci.corr = self.loci_corr

        return loci


class EmpiricalPhaseTypeSFSDistribution(EmpiricalPhaseTypeDistribution, TajimaSFSMixin):  # pragma: no cover
    """
    SFS phase-type distribution based on realisations.
    """

    def _tajima_n(self) -> int:
        # derive n from the (serialized) mean vector so this works on fixtures restored without ``n``
        return len(np.asarray(self.mean)) - 1

    def _tajima_mean(self) -> np.ndarray:
        n = self._tajima_n()
        return np.asarray(self.mean)[1:n]

    def _tajima_cov(self) -> np.ndarray:
        n = self._tajima_n()
        return np.asarray(self.cov)[1:n, 1:n]

    def __init__(
            self,
            branch_lengths: np.ndarray,
            mutations: np.ndarray,
            pops: List[str],
            sfs_dist: Type[SFSDistribution],
            locus_agg: Callable = lambda x: x.sum(axis=0),
    ):
        """
        Create object.

        :param branch_lengths: 4-D array of branch length samples.
        :param mutations: 4-D array of mutation counts.
        :param pops: List of population names.
        :param sfs_dist: SFS distribution class.
        :param locus_agg: Aggregation function for loci.
        """
        over_loci = locus_agg(branch_lengths).astype(float)

        # deliberately bypasses EmpiricalPhaseTypeDistribution.__init__ (the SFS samples have an extra axis)
        EmpiricalDistribution.__init__(self, over_loci.sum(axis=0))

        #: Population names
        self.pops = pops

        #: Number of lineages
        self.n = branch_lengths.shape[-1] - 1

        #: SFS distribution class
        self._sfs_dist = sfs_dist

        #: Branch length samples by deme and locus
        self._samples = branch_lengths

        #: Mutation counts by deme and locus
        self._mutations = mutations

        #: Correlation matrix for the demes
        self.pops_corr = self._get_stat_pops(over_loci, np.corrcoef)

        #: Covariance matrix for the demes
        self.pops_cov: np.ndarray = self._get_stat_pops(over_loci, np.cov)

        #: Correlation matrix for the loci
        self.loci_corr: np.ndarray = None

        #: Covariance matrix for the loci
        self.loci_cov: np.ndarray = None

        #: Generated probability mass by iterator returned from :meth:`get_mutation_configs`.
        self.generated_mass = 0

    def drop(self):
        """
        Drop simulated samples and mutation counts.
        """
        super().drop()

        self._mutations = None

    @staticmethod
    def _get_stat_pops(samples: np.ndarray, callback: Callable) -> np.ndarray:
        """
        Compute a deme-by-deme statistic for each polymorphic SFS bin.

        :param samples: 3-D array indexed ``[deme, realisation, bin]``.
        :param callback: Statistic to apply to the ``[deme, realisation]`` slice of each bin, e.g.
            ``np.cov`` or ``np.corrcoef``.
        :return: Array of shape ``(demes, demes, bins, bins)``. Entry ``[:, :, i, j]`` holds the statistic of
            bin ``i`` for every polymorphic ``j``; monomorphic bins are zero.
        """
        n_bins = samples.shape[2]
        stats = np.zeros((samples.shape[0], samples.shape[0], n_bins, n_bins))

        for i in range(1, n_bins - 1):
            # the statistic depends only on bin i, so compute it once and broadcast over j
            stats[:, :, i, 1:n_bins - 1] = np.asarray(callback(samples[:, :, i]))[..., np.newaxis]

        return stats

    @cached_property
    def demes(self) -> Dict[str, EmpiricalDistribution]:
        """
        Get the distribution for each deme.

        :return: Dictionary of distributions.
        """
        return {pop: EmpiricalSFSDistribution(self._samples.sum(axis=0)[i]) for i, pop in enumerate(self.pops)}

    @cached_property
    def mutation_configs(self) -> Dict[Tuple[float, ...], float]:
        """
        Empirical probability of each observed mutational configuration.

        :return: Dictionary mapping configurations to their empirical probabilities.
        """
        configs = defaultdict(lambda: 0)

        # NOTE(review): only the first locus and first deme are used (``[0, 0]``) — confirm this is intended
        for config in self._mutations[0, 0]:
            configs[tuple(config)] += 1 / self._mutations.shape[2]

        return configs

    def get_mutation_config(self, config: Sequence[int]) -> float:
        """
        Get the probability of observing the given mutational configuration.

        :param config: The mutational configuration.
        :return: The probability of observing the given mutational configuration (0 if never observed).
        """
        # ``.get`` avoids inserting a zero entry into the cached dictionary on every miss
        return self.mutation_configs.get(tuple(config), 0)

    def get_mutation_configs(self) -> Iterator[Tuple[Tuple[float, ...], float]]:
        """
        An iterator over the probabilities of observing mutational configurations according to the
        infinite sites model. The order of the mutational configurations generated ascends in the
        number of mutations observed. The iterator is infinite; :attr:`generated_mass` tracks the
        probability mass generated so far.

        :return: An iterator over the probabilities of observing mutational configurations.
        """
        # reset generated mass
        self.generated_mass = 0

        # iterate over number of mutations
        i = 0
        while True:

            # iterate over configurations
            for config in self._sfs_dist._get_configs(self.n, i):
                p = self.get_mutation_config(config=config)

                self.generated_mass += p

                yield config, p

            # increase counter for number of mutations
            i += 1


class AbstractCoalescent(ABC):
    """
    Abstract base class for coalescent distributions. This class provides probability distributions for the
    tree height, total branch length and site frequency spectrum.
    """

    def __init__(
            self,
            n: int | Dict[str, int] | List[int] | LineageConfig,
            model: CoalescentModel = None,
            demography: Demography = None,
            loci: int | LocusConfig = 1,
            recombination_rate: float = None,
            end_time: float = None
    ):
        """
        Create object.

        :param n: Number of lineages. Either a single integer if only one population, or a list of integers
            or a dictionary with population names as keys and number of lineages as values. Alternatively, a
            :class:`~phasegen.lineage.LineageConfig` object can be passed.
        :param model: Coalescent model. By default, the standard coalescent is used.
        :param loci: Number of loci (any int/float, including numpy scalars) or locus configuration.
        :param recombination_rate: Recombination rate.
        :param demography: Demography. Note that populations missing from it are added to this object in place.
        :param end_time: Time when to end the computation. If ``None``, the end time is taken to be the time
            of almost sure absorption. Note that unnecessarily large end times can lead to numerical errors.
        """
        self._logger = logger.getChild(self.__class__.__name__)

        # set up default coalescent model
        if model is None:
            model = StandardCoalescent()

        if not isinstance(n, LineageConfig):
            #: Population configuration
            self.lineage_config: LineageConfig = LineageConfig(n)
        else:
            #: Population configuration
            self.lineage_config: LineageConfig = n

        # set up demography
        if demography is None:
            demography = Demography(pop_sizes={p: 1 for p in self.lineage_config.pop_names})

        # set up locus configuration (accept a numeric number of loci, including the float that reticulate
        # passes from R and numpy scalars, or a LocusConfig)
        if isinstance(loci, (int, float, np.integer, np.floating)):
            #: Locus configuration
            self.locus_config: LocusConfig = LocusConfig(
                n=int(loci),
                recombination_rate=recombination_rate if recombination_rate is not None else 0
            )
        else:
            #: Locus configuration
            self.locus_config: LocusConfig = loci

            if recombination_rate is not None:
                self.locus_config.recombination_rate = recombination_rate

        # population names present in the population configuration but not in the demography
        initial_sizes = {p: {0: 1} for p in self.lineage_config.pop_names if p not in demography.pop_names}

        # add missing population sizes to demography (NOTE: this mutates the passed demography in place)
        if len(initial_sizes) > 0:
            demography.add_event(
                PopSizeChanges(initial_sizes)
            )

            # warn if population names are present in the population configuration but not in the demography
            self._logger.warning(
                f"The following population names are present in the population configuration but not "
                f"in the demography: {list(initial_sizes.keys())}. "
                f"Adding these populations with population size of 1."
            )

        # determine population names that are present in the demography but not in the population configuration
        unspecified_lineages = set(demography.pop_names) - set(self.lineage_config.pop_names)

        # warn if population names are present in the demography but not in the population configuration
        if len(unspecified_lineages) > 0:
            self._logger.warning(
                f"The following population names are present in the demography but not "
                f"in the population configuration: {list(unspecified_lineages)}. "
                f"Adding these populations with 0 lineages."
            )

            self.lineage_config = LineageConfig(self.lineage_config.lineage_dict | {p: 0 for p in unspecified_lineages})

        #: Coalescent model
        self.model: CoalescentModel = model

        #: Demography
        self.demography: Demography = demography

        #: End time
        self.end_time: float = end_time

    @property
    @abstractmethod
    def tree_height(self) -> DensityAwareDistribution:
        """
        Tree height distribution.
        """
        pass

    @property
    @abstractmethod
    def total_branch_length(self) -> MomentAwareDistribution:
        """
        Total branch length distribution.
        """
        pass

    @property
    @abstractmethod
    def sfs(self) -> MomentAwareDistribution:
        """
        Unfolded site-frequency spectrum distribution.
        """
        pass

    @property
    @abstractmethod
    def fsfs(self) -> MomentAwareDistribution:
        """
        Folded site-frequency spectrum distribution.
        """
        pass
[docs] class Coalescent(AbstractCoalescent, Serializable): """ Coalescent distribution. """
[docs] def __init__( self, n: int | Dict[str, int] | List[int] | LineageConfig, model: CoalescentModel = None, demography: Demography = None, loci: int | LocusConfig = 1, recombination_rate: float = None, start_time: float = 0, end_time: float = None, ): """ Create object. :param n: Number of lineages. Either a single integer if only one population, or a list of integers or dictionary with population names as keys and number of lineages as values for multiple populations. Alternatively, a :class:`~phasegen.lineage.LineageConfig` object can be passed. :param model: Coalescent model. Default is the standard coalescent. :param demography: Demography. :param loci: Number of loci or locus configuration. :param recombination_rate: Recombination rate. :param start_time: Time when to start accumulating moments. By default, this is 0. :param end_time: Time when to end the accumulating moments. If ``None``, the end time is taken to be the time of almost sure absorption. Note that unnecessarily long end times can lead to numerical errors. """ super().__init__( n=n, model=model, loci=loci, recombination_rate=recombination_rate, demography=demography, end_time=end_time ) #: Time when to start accumulating moments self.start_time: float = start_time
@cached_property def lineage_counting_state_space(self) -> LineageCountingStateSpace: """ The lineage-counting state space. """ return LineageCountingStateSpace( lineage_config=self.lineage_config, locus_config=self.locus_config, model=self.model, epoch=self.demography.get_epoch(0) ) @cached_property def block_counting_state_space(self) -> BlockCountingStateSpace: """ The block-counting state space. """ return BlockCountingStateSpace( lineage_config=self.lineage_config, locus_config=self.locus_config, model=self.model, epoch=self.demography.get_epoch(0) ) @cached_property def joint_block_counting_state_space(self) -> JointBlockCountingStateSpace: """ The joint block-counting state space (tracks the deme-of-origin composition of each lineage). """ return JointBlockCountingStateSpace( lineage_config=self.lineage_config, locus_config=self.locus_config, model=self.model, epoch=self.demography.get_epoch(0) ) @cached_property def tree_height(self) -> TreeHeightDistribution: """ Tree height distribution. """ return TreeHeightDistribution( state_space=self.lineage_counting_state_space, demography=self.demography, start_time=self.start_time, end_time=self.end_time ) @cached_property def total_branch_length(self) -> PhaseTypeDistribution: """ Total branch length distribution. """ return PhaseTypeDistribution( reward=TotalBranchLengthReward(), tree_height=self.tree_height, state_space=self.lineage_counting_state_space, demography=self.demography ) def _require_single_locus(self, name: str): """ Raise a clear error if more than one locus is configured for a single-locus SFS statistic. :param name: Name of the statistic, used in the error message. :raises ValueError: if more than one locus is configured. """ if self.locus_config.n != 1: raise ValueError( f"`{name}` is the single-locus site-frequency spectrum and is defined for one locus only " f"(got {self.locus_config.n}). 
For two loci under recombination use `sfs2` (the two-locus SFS); " f"the single-locus marginal is recombination-invariant, so drop the extra locus to obtain it." ) @cached_property def sfs(self) -> UnfoldedSFSDistribution: """ Unfolded site-frequency spectrum distribution. Defined for a single locus; for two loci under recombination use :meth:`sfs2`. """ self._require_single_locus('sfs') return UnfoldedSFSDistribution( state_space=self.block_counting_state_space, tree_height=self.tree_height, demography=self.demography ) @cached_property def fsfs(self) -> FoldedSFSDistribution: """ Folded site-frequency spectrum distribution. Defined for a single locus; for two loci under recombination use :meth:`sfs2`. """ self._require_single_locus('fsfs') return FoldedSFSDistribution( state_space=self.block_counting_state_space, tree_height=self.tree_height, demography=self.demography ) @cached_property def jsfs(self) -> JointSFSDistribution: """ Joint (multi-population) site-frequency spectrum distribution. Moments are returned as a multi-dimensional array of shape ``(n_0 + 1, ..., n_{P-1} + 1)``. .. note:: The joint state space grows combinatorially with the per-population sample sizes, so this is only practical for small samples. """ return JointSFSDistribution( state_space=self.joint_block_counting_state_space, tree_height=self.tree_height, demography=self.demography ) @cached_property def two_locus_block_counting_state_space(self) -> TwoLocusBlockCountingStateSpace: """ The two-locus block-counting state space (tracks each lineage's descendant counts at both loci and the recombination/linkage history). Requires exactly two loci and a single population. 
""" return TwoLocusBlockCountingStateSpace( lineage_config=self.lineage_config, locus_config=self.locus_config, model=self.model, epoch=self.demography.get_epoch(0) ) @cached_property def _two_locus_tree_height(self) -> TreeHeightDistribution: """ Tree height of the two-locus process, absorbed once *both* loci have reached their MRCA. """ return TreeHeightDistribution( state_space=self.two_locus_block_counting_state_space, demography=self.demography, start_time=self.start_time, end_time=self.end_time ) @cached_property def sfs2(self) -> TwoLocusSFSDistribution: """ Two-locus site-frequency spectrum under recombination, returned as a :class:`~phasegen.spectrum.TwoLocusSFS`. Requires exactly two loci (``loci=2``) and a single population. .. note:: The two-locus state space grows quickly with the sample size, so this is only practical for small ``n``. """ return TwoLocusSFSDistribution( state_space=self.two_locus_block_counting_state_space, tree_height=self._two_locus_tree_height, demography=self.demography ) @cached_property def fst(self) -> float: r""" Hudson's fixation index :math:`F_{ST} = 1 - \mathbb{E}[T_S] / \mathbb{E}[T_B]`, based on pairwise coalescence times: :math:`T_S` is the coalescence time of two lineages sampled within the same population (averaged over populations) and :math:`T_B` of two lineages from different populations (averaged over population pairs). Requires at least two populations. Since :math:`F_{ST}` is a pairwise, single-locus quantity, it is computed from two-lineage sub-coalescents under the same (possibly time-varying, migrating) demography and coalescent model, and so does not depend on the configured sample sizes or number of loci. :return: Hudson's :math:`F_{ST}`. :raises ValueError: if fewer than two populations are configured. 
""" pops = self.demography.pop_names if len(pops) < 2: raise ValueError(f"F_ST requires at least two populations (got {len(pops)}).") # within-population pairwise times (both lineages in the same population) t_within = [self._pairwise_coalescence_time(q, q) for q in pops] # between-population pairwise times (one lineage in each of two distinct populations) t_between = [ self._pairwise_coalescence_time(a, b) for i, a in enumerate(pops) for b in pops[i + 1:] ] return float(1 - np.mean(t_within) / np.mean(t_between)) def _pairwise_coalescence_time(self, pop_i: str, pop_j: str) -> float: """ Expected coalescence time of two lineages, one sampled in ``pop_i`` and one in ``pop_j`` (or both in the same population when ``pop_i == pop_j``), under this demography and coalescent model. Computed from a two-lineage sub-coalescent, so it is independent of the configured sample sizes and number of loci. :param pop_i: Name of the first population. :param pop_j: Name of the second population. :return: Expected pairwise coalescence time ``T_{ij}``. """ pops = self.demography.pop_names for p in (pop_i, pop_j): if p not in pops: raise ValueError(f"Unknown population {p!r}; available: {pops}.") if pop_i == pop_j: counts = {p: (2 if p == pop_i else 0) for p in pops} else: counts = {p: (1 if p in (pop_i, pop_j) else 0) for p in pops} return Coalescent( n=counts, demography=self.demography, model=self.model, end_time=self.end_time ).tree_height.mean
[docs] def f2(self, pop_0: str, pop_1: str) -> float: r""" Patterson's :math:`f_2(A, B) = \mathbb{E}[(p_A - p_B)^2]`, the branch (coalescence-time) version :math:`f_2 = 2 T_{AB} - T_{AA} - T_{BB}` in terms of pairwise coalescence times (matching ``tskit``'s branch-mode ``f2``). Measures the amount of drift separating the two populations. :param pop_0: Name of population ``A``. :param pop_1: Name of population ``B``. :return: :math:`f_2(A, B)`. """ t = self._pairwise_coalescence_time return float(2 * t(pop_0, pop_1) - t(pop_0, pop_0) - t(pop_1, pop_1))
[docs] def f3(self, pop_target: str, pop_0: str, pop_1: str) -> float: r""" Patterson's :math:`f_3(C; A, B) = \mathbb{E}[(p_C - p_A)(p_C - p_B)]`, in branch (coalescence-time) form :math:`f_3 = T_{CA} + T_{CB} - T_{AB} - T_{CC}` (matching ``tskit``'s branch-mode ``f3``). A significantly negative value is evidence that the target population ``C`` is admixed between ``A`` and ``B``. :param pop_target: Name of the (potentially admixed) target population ``C``. :param pop_0: Name of source population ``A``. :param pop_1: Name of source population ``B``. :return: :math:`f_3(C; A, B)`. """ t = self._pairwise_coalescence_time return float(t(pop_target, pop_0) + t(pop_target, pop_1) - t(pop_0, pop_1) - t(pop_target, pop_target))
[docs] def f4(self, pop_0: str, pop_1: str, pop_2: str, pop_3: str) -> float: r""" Patterson's :math:`f_4(A, B; C, D) = \mathbb{E}[(p_A - p_B)(p_C - p_D)]`, in branch (coalescence-time) form :math:`f_4 = T_{AD} + T_{BC} - T_{AC} - T_{BD}` (matching ``tskit``'s branch-mode ``f4``). Used to test treeness and detect gene flow between the two population pairs. :param pop_0: Name of population ``A``. :param pop_1: Name of population ``B``. :param pop_2: Name of population ``C``. :param pop_3: Name of population ``D``. :return: :math:`f_4(A, B; C, D)`. """ t = self._pairwise_coalescence_time return float(t(pop_0, pop_3) + t(pop_1, pop_2) - t(pop_0, pop_2) - t(pop_1, pop_3))
def _get_dist(self, k: int, rewards: Iterable[Reward] = None) -> PhaseTypeDistribution:
    """
    Build the k-th order phase-type distribution whose state space is the cheapest one supporting ``rewards``.

    The returned distribution always carries the unit reward; the actual rewards are passed separately to
    whichever method (``moment``, ``accumulate``, ...) is called on it.

    :param k: Order of the moment.
    :param rewards: Sequence of k rewards. Defaults to k tree height rewards.
    :return: Distribution.
    :raises ValueError: If a reward needs the joint state space but not all rewards support it.
    """
    if rewards is None:
        rewards = [TreeHeightReward()] * k

    if not Reward.requires_joint_state_space(rewards):
        # no joint reward involved: prefer the lineage-counting state space, else fall back to block-counting
        if Reward.support(LineageCountingStateSpace, rewards):
            state_space = self.lineage_counting_state_space
        else:
            state_space = self.block_counting_state_space
    elif Reward.support(JointBlockCountingStateSpace, rewards):
        # the (expensive) joint state space is only used when some reward actually requires it
        state_space = self.joint_block_counting_state_space
    else:
        raise ValueError(
            "The given rewards are not jointly compatible with any single state space: "
            f"{[r.__class__.__name__ for r in rewards]}. A joint-SFS reward can only be combined with "
            "rewards that also support the joint state space."
        )

    return PhaseTypeDistribution(
        reward=UnitReward(),
        tree_height=self.tree_height,
        state_space=state_space,
        demography=self.demography
    )
@_make_hashable
@cache
def moment(
        self,
        k: int = 1,
        rewards: Sequence[Reward] = None,
        start_time: float = None,
        end_time: float = None,
        center: bool = True,
        permute: bool = True
) -> float:
    """
    Compute the k-th moment for the given rewards on the state space those rewards require.

    Results are memoized per argument combination; list/array arguments are made hashable first.

    :param k: Order of the moment.
    :param rewards: Sequence of k rewards. Defaults to tree height rewards.
    :param start_time: Time at which accumulation of moments starts. Defaults to the start time given when
        initializing the distribution.
    :param end_time: Time at which accumulation of moments ends. Defaults to the end time given when
        initializing the distribution or the time until almost sure absorption.
    :param center: Whether to center the moment.
    :param permute: For cross-moments, whether to average over all permutations of rewards. ``True`` (the
        default) yields the correct cross-moment; ``False`` conditions the cross-moment on the reward order.
    :return: The k-th moment.
    """
    distribution = self._get_dist(k, rewards)

    return distribution.moment(
        k=k,
        rewards=rewards,
        start_time=start_time,
        end_time=end_time,
        center=center,
        permute=permute
    )
def _sample(
        self,
        n_samples: int,
        rewards: Sequence[Reward] = None,
        record_visits: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Draw samples of the accumulated rewards by simulating trajectories of the first-order distribution.

    :param n_samples: Number of trajectories to simulate.
    :param rewards: Rewards to sample. Defaults to the tree height reward.
    :param record_visits: Whether to also record which states were visited while sampling.
    :return: Array of sampled rewards of shape (n_samples, len(rewards)), optionally paired with an array of
        probabilities of visiting each state.
    """
    first_order = self._get_dist(k=1, rewards=rewards)

    return first_order._sample(n_samples=n_samples, rewards=rewards, record_visits=record_visits)


def _raw_moment(
        self,
        k: int,
        rewards: Sequence[Reward] = None,
        start_time: float = None,
        end_time: float = None
) -> float:
    """
    Compute the k-th raw moment, i.e. :meth:`moment` without centering and without permuting the rewards.

    :param k: Order of the moment.
    :param rewards: Sequence of k rewards. Defaults to tree height rewards.
    :param start_time: Time at which accumulation of moments starts. Defaults to the start time given when
        initializing the distribution.
    :param end_time: Time at which accumulation of moments ends. Defaults to the end time given when
        initializing the distribution or the time until almost sure absorption.
    :return: The k-th raw moment.
    """
    raw_options = dict(center=False, permute=False)

    return self.moment(k=k, rewards=rewards, start_time=start_time, end_time=end_time, **raw_options)
def accumulate(
        self,
        k: int,
        end_times: Iterable[float],
        rewards: Sequence[Reward] = None,
        center: bool = True,
        permute: bool = True
) -> np.ndarray:
    """
    Evaluate the k-th accumulated moment at each of the given end times.

    :param k: Order of the moment.
    :param end_times: Times at which to evaluate the accumulated moment. This argument is required.
    :param rewards: Sequence of k rewards. Defaults to tree height rewards.
    :param center: Whether to center the moment around the mean.
    :param permute: For cross-moments, whether to average over all permutations of rewards. ``True`` (the
        default) yields the correct cross-moment; ``False`` conditions the cross-moment on the reward order.
    :return: Accumulated moments, one per end time.
    """
    distribution = self._get_dist(k, rewards)
    options = dict(k=k, end_times=end_times, rewards=rewards, center=center, permute=permute)

    return distribution.accumulate(**options)
def plot_accumulation(
        self,
        k: int = 1,
        end_times: Iterable[float] = None,
        rewards: Sequence[Reward] = None,
        center: bool = True,
        permute: bool = True,
        ax: 'plt.Axes' = None,
        show: bool = True,
        file: str = None,
        clear: bool = False,
        label: str = None,
        title: str = None
) -> 'plt.Axes':
    """
    Plot the accumulation of moments.

    :param k: The order of the moment.
    :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0 and the
        99th percentile.
    :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
    :param center: Whether to center the moment around the mean.
    :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
        which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
        the order of rewards.
    :param ax: Axes to plot on.
    :param show: Whether to show the plot.
    :param file: File to save the plot to.
    :param clear: Whether to clear the plot before plotting.
    :param label: Label for the plot.
    :param title: Title of the plot.
    :return: Axes returned by the underlying distribution's ``plot_accumulation``.
    """
    # return the delegate's axes, as promised by the signature and docstring (previously dropped, yielding None)
    return self._get_dist(k, rewards).plot_accumulation(
        k=k,
        end_times=end_times,
        rewards=rewards,
        center=center,
        permute=permute,
        ax=ax,
        show=show,
        file=file,
        clear=clear,
        label=label,
        title=title
    )
def drop_cache(self):
    """
    Release the cached computations of the lineage-counting and the block-counting state spaces.
    """
    for attribute in ('lineage_counting_state_space', 'block_counting_state_space'):
        getattr(self, attribute).drop_cache()
def __setstate__(self, state: dict):
    """
    Restore the instance from a serialized state by merging it into the instance dictionary.

    :param state: State.
    """
    vars(self).update(state)


def __getstate__(self) -> dict:
    """
    Produce a serializable state: a deep copy of the instance dictionary whose already-built state spaces have
    their caches dropped.

    :return: State.
    """
    # copy the attribute dict rather than the object itself so that the copy protocol does not call
    # __getstate__ again
    snapshot = copy.deepcopy(vars(self))

    for attribute in ('lineage_counting_state_space', 'block_counting_state_space'):
        if attribute in snapshot:
            snapshot[attribute].drop_cache()

    return snapshot
def to_json(self) -> str:
    """
    Serialize to JSON after dropping the state space caches.

    The caches are dropped on a deep copy, so this instance keeps its caches.

    :return: JSON string.
    """
    # copy object to avoid modifying the original
    other = copy.deepcopy(self)

    # drop cache
    other.drop_cache()

    # ``__class__`` is the class defining this method (implicit closure cell), not the runtime class: using
    # ``self.__class__`` would re-enter this very method for subclasses and recurse infinitely
    return super(__class__, other).to_json()
def to_msprime(
        self,
        num_replicates: int = 10000,
        n_threads: int = 10,
        parallelize: bool = True,
        record_migration: bool = False,
        simulate_mutations: bool = False,
        mutation_rate: float = None,
        seed: int = None
) -> 'MsprimeCoalescent':
    """
    Build an msprime-based empirical counterpart of this coalescent with the same lineages, demography, model,
    loci and end time.

    :param num_replicates: Number of replicates.
    :param n_threads: Number of threads.
    :param parallelize: Whether to parallelize.
    :param record_migration: Whether to record migrations, which is necessary to calculate statistics per deme.
    :param simulate_mutations: Whether to simulate mutations.
    :param mutation_rate: Mutation rate.
    :param seed: Random seed.
    :return: msprime coalescent.
    """
    if self.start_time != 0:
        # the start time cannot be forwarded, so the empirical version always starts at time zero
        self._logger.warning("Non-zero start times are not supported by MsprimeCoalescent.")

    settings = dict(
        n=self.lineage_config,
        demography=self.demography,
        model=self.model,
        loci=self.locus_config,
        recombination_rate=self.locus_config.recombination_rate,
        mutation_rate=mutation_rate,
        end_time=self.end_time,
        num_replicates=num_replicates,
        n_threads=n_threads,
        parallelize=parallelize,
        record_migration=record_migration,
        simulate_mutations=simulate_mutations,
        seed=seed
    )

    return MsprimeCoalescent(**settings)
class EmpiricalTwoLocusSFSDistribution:  # pragma: no cover
    """
    Empirical (msprime-based) two-locus SFS, exposing the same ``mean`` interface as
    :class:`TwoLocusSFSDistribution` (a :class:`~phasegen.spectrum.TwoLocusSFS`) so the two can be compared by
    :class:`~phasegen.comparison.Comparison`.
    """

    def __init__(self, mean: np.ndarray):
        """
        :param mean: The simulated mean two-locus SFS array.
        """
        self._mean = np.asarray(mean)

    @property
    def mean(self) -> TwoLocusSFS:
        """Mean two-locus SFS."""
        return TwoLocusSFS(self._mean)


class MsprimeCoalescent(AbstractCoalescent):
    """
    Empirical coalescent distribution based on `msprime` simulations.
    This is used for testing purposes.
    Note that the results are stochastic.
    """

    def __init__(
            self,
            n: int | Dict[str, int] | List[int] | LineageConfig,
            demography: Demography = None,
            model: CoalescentModel = StandardCoalescent(),
            loci: int | LocusConfig = 1,
            recombination_rate: float = None,
            mutation_rate: float = None,
            end_time: float = None,
            num_replicates: int = 10000,
            n_threads: int = 100,
            parallelize: bool = True,
            record_migration: bool = False,
            simulate_mutations: bool = False,
            seed: int = None
    ):
        """
        Simulate data using msprime.

        :param n: Number of Lineages.
        :param demography: Demography.
        :param model: Coalescent model.
        :param loci: Number of loci or locus configuration.
        :param recombination_rate: Recombination rate.
        :param mutation_rate: Mutation rate.
        :param end_time: Time when to end the simulation.
        :param num_replicates: Number of replicates.
        :param n_threads: Number of threads.
        :param parallelize: Whether to parallelize.
        :param record_migration: Whether to record migrations which is necessary to calculate statistics per deme.
        :param simulate_mutations: Whether to simulate mutations.
        :param seed: Random seed.
        """
        super().__init__(
            n=n,
            model=model,
            loci=loci,
            recombination_rate=recombination_rate,
            demography=demography,
            end_time=end_time
        )

        if mutation_rate is not None and not simulate_mutations:
            self._logger.warning("Mutation rate is set but mutations are not simulated.")

        #: Site frequency spectrum counts per locus, deme and replicate.
        self.sfs_lengths: np.ndarray | None = None

        #: Total branch lengths per locus, deme and replicate.
        self.total_branch_lengths: np.ndarray | None = None

        #: Tree heights per locus, deme and replicate.
        self.heights: np.ndarray | None = None

        #: Mutations per locus, deme and replicate.
        self.mutations: np.ndarray | None = None

        #: Joint SFS (non-central) moments per descendant configuration, of orders 1, ..., ``_jsfs_max_order``.
        self.jsfs_moments: np.ndarray | None = None

        #: Number of replicates.
        self.num_replicates: int = num_replicates

        #: Mutation rate.
        self.mutation_rate: float = mutation_rate

        #: Number of threads.
        self.n_threads: int = n_threads

        #: Whether to parallelize computations.
        self.parallelize: bool = parallelize

        #: Whether to record migrations.
        self.record_migration: bool = record_migration

        #: Whether to simulate mutations.
        self.simulate_mutations: bool = simulate_mutations

        #: Random seed.
        self.seed: int = seed

    def get_coalescent_model(self) -> 'msprime.AncestryModel':
        """
        Get the msprime ancestry model corresponding to :attr:`model`.

        :return: msprime coalescent model.
        :raises NotImplementedError: If the coalescent model has no msprime counterpart here.
        """
        import msprime as ms

        if isinstance(self.model, StandardCoalescent):
            return ms.StandardCoalescent()

        if isinstance(self.model, BetaCoalescent):
            return ms.BetaCoalescent(alpha=self.model.alpha)

        if isinstance(self.model, DiracCoalescent):
            return ms.DiracCoalescent(psi=self.model.psi, c=self.model.c)

        # returning None would make msprime silently fall back to its default model, i.e. a ground truth
        # simulated under the wrong model
        raise NotImplementedError(
            f"Coalescent model {type(self.model).__name__} cannot be converted to an msprime ancestry model."
        )

    @cache
    def simulate(self):
        """
        Simulate data using msprime.

        Fills :attr:`heights`, :attr:`total_branch_lengths`, :attr:`sfs_lengths`, :attr:`mutations` and (for
        multi-population, single-locus scenarios) :attr:`jsfs_moments`. Replicates are split evenly across
        ``n_threads`` batches; a remainder of ``num_replicates % n_threads`` replicates is not simulated.

        :raises ValueError: If there are fewer replicates than threads (no thread would simulate anything).
        """
        # number of replicates for one thread
        num_replicates = self.num_replicates // self.n_threads

        if num_replicates == 0:
            raise ValueError(
                f"num_replicates ({self.num_replicates}) must be at least n_threads ({self.n_threads}) so that "
                "every thread simulates at least one replicate."
            )

        samples = self.lineage_config.lineage_dict
        demography = self.demography.to_msprime()
        model = self.get_coalescent_model()
        end_time = self.end_time
        n_pops = self.demography.n_pops
        sample_size = self.lineage_config.n

        # joint SFS is accumulated from the same trees, but only for multi-population, single-locus scenarios where
        # it is meaningful (the descendant configuration is by deme of origin)
        compute_jsfs = self.lineage_config.n_pops > 1 and self.locus_config.n == 1
        jsfs_max_order = self._jsfs_max_order
        jsfs_shape = tuple(int(s) + 1 for s in self.lineage_config.lineages)
        name_to_index = {name: i for i, name in enumerate(self.demography.pop_names)}
        n_total = num_replicates * self.n_threads

        def simulate_batch(seed: Optional[int]) -> dict:
            """
            Simulate statistics.

            :param seed: Random seed.
            :return: Statistics.
            """
            import msprime as ms
            import tskit

            # simulate trees
            g: Generator = ms.sim_ancestry(
                sequence_length=self.locus_config.n,
                recombination_rate=self.locus_config.recombination_rate,
                samples=samples,
                num_replicates=num_replicates,
                record_migrations=self.record_migration,
                demography=demography,
                model=model,
                ploidy=1,
                end_time=end_time,
                random_seed=seed
            )

            # per-replicate mutation seeds derived from the batch seed; reusing the batch seed for every replicate
            # would correlate the mutation processes across replicates (seed None -> fresh OS entropy)
            mutation_rng = np.random.default_rng(seed)

            # initialize variables
            heights = np.zeros((self.locus_config.n, n_pops, num_replicates), dtype=float)
            total_branch_lengths = np.zeros((self.locus_config.n, n_pops, num_replicates), dtype=float)
            sfs = np.zeros((self.locus_config.n, n_pops, num_replicates, sample_size + 1), dtype=float)
            mutations = np.zeros((self.locus_config.n, n_pops, num_replicates, sample_size + 1), dtype=int)

            # joint SFS moment accumulator (non-central moments of orders 1, ..., jsfs_max_order)
            jsfs_acc = np.zeros((jsfs_max_order,) + jsfs_shape)

            # iterate over trees and compute statistics
            ts: tskit.TreeSequence
            for i, ts in enumerate(g):

                # map each sample to the index of its sampling population (deme of origin) for the joint SFS
                if compute_jsfs:
                    pop_of_leaf = {
                        u: name_to_index[ts.population(ts.node(u).population).metadata['name']]
                        for u in ts.samples()
                    }

                tree: tskit.Tree
                for j, tree in enumerate(self._expand_trees(ts)):

                    # TODO record_migration only appears to work for relatively simple scenarios
                    if self.record_migration:
                        lineages = np.array(list(samples.values()))
                        t_coal = ts.tables.nodes.time[sample_size:]
                        node = sample_size - 1
                        t_migration = ts.migrations_time
                        i_migration = 0
                        time = 0

                        # population state per leave
                        pop_states = {n: tree.population(n) for n in range(sample_size)}

                        # iterate over coalescence events
                        for coal_time in t_coal:

                            # iterate over migration events within this coalescence event
                            while i_migration < len(t_migration) and time < t_migration[i_migration] <= coal_time:
                                delta = t_migration[i_migration] - time

                                # update statistics
                                heights[j, :, i] += delta * lineages / sum(lineages)
                                total_branch_lengths[j, :, i] += delta * lineages
                                for n, pop in pop_states.items():
                                    sfs[j, pop, i, tree.get_num_leaves(n)] += delta

                                # update lineages with migrations
                                lineages[ts.migrations_source[i_migration]] -= 1
                                lineages[ts.migrations_dest[i_migration]] += 1
                                pop_states[ts.migrations_node[i_migration]] = ts.migrations_dest[i_migration]

                                i_migration += 1
                                time += delta

                            # remaining time to next coalescence event
                            delta = coal_time - time

                            # update statistics
                            heights[j, :, i] += delta * lineages / sum(lineages)
                            total_branch_lengths[j, :, i] += delta * lineages
                            for n, pop in pop_states.items():
                                sfs[j, pop, i, tree.get_num_leaves(n)] += delta

                            # reduce by number of coalesced lineages
                            lineages[tree.population(node + 1)] -= len(tree.get_children(node + 1)) - 1

                            # delete children from pop_states
                            for child in tree.get_children(node + 1):
                                del pop_states[child]

                            # add parent to pop_states
                            pop_states[node + 1] = tree.population(node + 1)

                            time += delta
                            node += 1
                    else:
                        heights[j, 0, i] = tree.time(tree.roots[0])
                        total_branch_lengths[j, 0, i] = tree.total_branch_length

                        for node in tree.nodes():
                            t = tree.get_branch_length(node)
                            n = tree.get_num_leaves(node)

                            sfs[j, 0, i, n] += t

                    # accumulate the joint SFS from the same tree (single locus only)
                    if compute_jsfs and j == 0:
                        jsfs_rep = np.zeros(jsfs_shape)

                        for node in tree.nodes():

                            # the root subtends all samples (monomorphic) and is skipped
                            if tree.parent(node) == -1:
                                continue

                            # count descendant samples by population (deme of origin)
                            vec = [0] * len(jsfs_shape)
                            for leaf in tree.leaves(node):
                                vec[pop_of_leaf[leaf]] += 1

                            if sum(vec) > 0:
                                jsfs_rep[tuple(vec)] += tree.get_branch_length(node)

                        for order in range(jsfs_max_order):
                            jsfs_acc[order] += jsfs_rep ** (order + 1)

                # simulate mutations if specified (first tree, first locus and deme only)
                if self.simulate_mutations:
                    mts = ms.sim_mutations(
                        ts,
                        rate=self.mutation_rate,
                        random_seed=int(mutation_rng.integers(1, 2 ** 31))
                    )
                    mut_tree = next(mts.trees())

                    for node in mts.mutations_node:
                        mutations[0, 0, i, mut_tree.get_num_leaves(node)] += 1

            return dict(
                main=np.concatenate([[heights.T], [total_branch_lengths.T], sfs.T, mutations.T]),
                jsfs=jsfs_acc
            )

        # parallelize over threads
        batches = parallelize(
            func=simulate_batch,
            data=[self.seed + i if self.seed is not None else None for i in range(self.n_threads)],
            parallelize=self.parallelize,
            batch_size=num_replicates,
            desc="Simulating trees",
            dtype=object
        )

        # combine the per-replicate statistics across threads (axis 1 is the replicate axis after transposing)
        res = np.hstack([b['main'] for b in batches])

        # store results
        self.heights = res[0].T
        self.total_branch_lengths = res[1].T
        self.sfs_lengths = res[2:sample_size + 3].T
        self.mutations = res[sample_size + 3:].T.astype(int)

        # combine the joint SFS moments (summed over replicates) across threads and normalize to moments
        self.jsfs_moments = np.sum([b['jsfs'] for b in batches], axis=0) / n_total if compute_jsfs else None

    @staticmethod
    def _expand_trees(ts: 'tskit.TreeSequence') -> Iterator['tskit.Tree']:
        """
        Expand tree sequence to `n` trees where `n` is the number of loci.

        Each tree is yielded once per unit of genome it spans (``int(tree.length)`` times).

        :param ts: Tree sequence.
        :return: Iterator over trees.
        """
        for tree in ts.trees():
            for _ in range(int(tree.length)):
                yield tree

    def _get_cached_times(self) -> np.ndarray:
        """
        Get cached times: 100 evenly spaced values between 0 and the largest simulated (deme-summed) height.
        """
        t_max = self.heights.sum(axis=1).max()

        return np.linspace(0, t_max, 100)

    def touch(self, **kwargs: dict):
        """
        Touch cached properties.

        :param kwargs: Additional keyword arguments.
        """
        self.simulate()

        t = self._get_cached_times()

        self.tree_height.touch(t)
        self.total_tree_height.touch(t)
        self.total_branch_length.touch(t)
        self.sfs.touch(t)
        self.fsfs.touch(t)

        # cache the joint SFS distribution (its moments were already accumulated by simulate() above) for
        # multi-population, single-locus scenarios, so it is serialized along with the comparison
        if self.lineage_config.n_pops > 1 and self.locus_config.n == 1:
            # noinspection PyStatementEffect
            self.jsfs

    def drop(self):
        """
        Drop simulated data.
        """
        self.heights = None
        self.total_branch_lengths = None
        self.sfs_lengths = None
        self.mutations = None

        # the moments are retained by the cached jsfs distribution (referenced before drop), so this only removes
        # the duplicate reference held on the coalescent
        self.jsfs_moments = None

        self.tree_height.drop()
        self.total_tree_height.drop()
        self.total_branch_length.drop()
        self.sfs.drop()
        self.fsfs.drop()

        # caused problems when serializing
        self.demography = None

    @cached_property
    def tree_height(self) -> EmpiricalPhaseTypeDistribution:
        """
        Tree height distribution.
        """
        self.simulate()

        return EmpiricalPhaseTypeDistribution(
            self.heights,
            pops=self.demography.pop_names,
            locus_agg=lambda x: x.max(axis=0)
        )

    @cached_property
    def total_tree_height(self) -> EmpiricalPhaseTypeDistribution:
        """
        Total tree height distribution.
        """
        self.simulate()

        return EmpiricalPhaseTypeDistribution(self.heights, pops=self.demography.pop_names)

    @cached_property
    def total_branch_length(self) -> EmpiricalPhaseTypeDistribution:
        """
        Total branch length distribution.
        """
        self.simulate()

        return EmpiricalPhaseTypeDistribution(self.total_branch_lengths, pops=self.demography.pop_names)

    @cached_property
    def sfs(self) -> EmpiricalPhaseTypeSFSDistribution:
        """
        Unfolded site-frequency spectrum distribution.
        """
        self.simulate()

        return EmpiricalPhaseTypeSFSDistribution(
            branch_lengths=self.sfs_lengths,
            mutations=self.mutations.T[1:-1].T,
            pops=self.demography.pop_names,
            sfs_dist=UnfoldedSFSDistribution
        )

    @cached_property
    def fsfs(self) -> EmpiricalPhaseTypeSFSDistribution:
        """
        Folded site-frequency spectrum distribution.
        """
        self.simulate()

        mid = (self.lineage_config.n + 1) // 2

        # fold SFS branch lengths
        lengths = self.sfs_lengths.copy().T
        lengths[:mid] += lengths[-mid:][::-1]
        lengths[-mid:] = 0

        # fold SFS mutations
        mutations = self.mutations.copy().T
        mutations[:mid] += mutations[-mid:][::-1]
        mutations = mutations[1:self.lineage_config.n // 2 + 1]

        return EmpiricalPhaseTypeSFSDistribution(
            branch_lengths=lengths.T,
            mutations=mutations.T,
            pops=self.demography.pop_names,
            sfs_dist=FoldedSFSDistribution
        )

    #: Highest moment order computed for the empirical joint SFS ground truth.
    _jsfs_max_order: int = 3

    @cached_property
    def jsfs(self) -> 'EmpiricalJointSFSDistribution':
        """
        Joint (multi-population) site-frequency spectrum ground truth, accumulated from the same simulated trees as
        the other statistics (see :meth:`simulate`). Returns an :class:`EmpiricalJointSFSDistribution` exposing
        ``mean``, ``m2``, ``m3`` and ``var`` as arrays of shape ``(n_0 + 1, ..., n_{P-1} + 1)``, matching
        :class:`JointSFSDistribution`. The descendant configuration of a branch is the number of its sample
        descendants from each population (its deme of origin). Only available for multi-population, single-locus
        scenarios.
        """
        self.simulate()

        if self.jsfs_moments is None:
            raise NotImplementedError(
                "The joint SFS is only available for multi-population, single-locus scenarios."
            )

        return EmpiricalJointSFSDistribution(moments=self.jsfs_moments)

    @cached_property
    def sfs2(self) -> 'EmpiricalTwoLocusSFSDistribution':
        """
        Two-locus SFS ground truth, simulated with msprime: two sites at recombination distance ``r`` (the two
        loci), with the per-bin product of the two sites' branch-length spectra averaged over replicates. Only
        available for two-locus scenarios. Returns an :class:`EmpiricalTwoLocusSFSDistribution` exposing ``mean``
        as a :class:`~phasegen.spectrum.TwoLocusSFS`, matching :class:`TwoLocusSFSDistribution`.
        """
        import msprime as ms

        if self.locus_config.n != 2:
            raise NotImplementedError("The two-locus SFS is only available for two-locus scenarios.")

        n = self.lineage_config.n
        demography = self.demography.to_msprime()
        model = self.get_coalescent_model()
        out = np.zeros((n + 1, n + 1))

        for ts in ms.sim_ancestry(
                samples=self.lineage_config.lineage_dict,
                sequence_length=2,
                recombination_rate=self.locus_config.recombination_rate,
                demography=demography,
                model=model,
                ploidy=1,
                # truncate at the end time, consistent with simulate()
                end_time=self.end_time,
                num_replicates=self.num_replicates,
                random_seed=self.seed,
        ):
            t0, t1 = ts.at(0.5), ts.at(1.5)
            left = np.zeros(n + 1)
            right = np.zeros(n + 1)

            for nd in t0.nodes():
                if t0.parent(nd) != -1:
                    left[t0.num_samples(nd)] += t0.branch_length(nd)

            for nd in t1.nodes():
                if t1.parent(nd) != -1:
                    right[t1.num_samples(nd)] += t1.branch_length(nd)

            out += np.outer(left, right)

        return EmpiricalTwoLocusSFSDistribution(out / self.num_replicates)

    @cached_property
    def fst(self) -> float:
        r"""
        Hudson's :math:`F_{ST}` ground truth, simulated with msprime: ``1 - mean within-population branch
        diversity / mean between-population branch divergence``, averaged over replicate trees. Requires at least
        two populations, each with at least two sampled lineages. Matches :meth:`Coalescent.fst`.
        """
        import msprime as ms

        pops = self.demography.pop_names

        if len(pops) < 2:
            raise ValueError(f"F_ST requires at least two populations (got {len(pops)}).")

        within = np.zeros(self.num_replicates)
        between = np.zeros(self.num_replicates)

        for k, ts in enumerate(ms.sim_ancestry(
                samples=self.lineage_config.lineage_dict,
                sequence_length=1,
                demography=self.demography.to_msprime(),
                model=self.get_coalescent_model(),
                ploidy=1,
                # truncate at the end time, as the analytical Coalescent.fst does
                end_time=self.end_time,
                num_replicates=self.num_replicates,
                random_seed=self.seed,
        )):
            sample_sets = [ts.samples(population=i) for i in range(len(pops))]

            # within-population diversity (only populations with at least two samples are informative)
            w = [ts.diversity(s, mode='branch') for s in sample_sets if len(s) >= 2]

            # between-population divergence over distinct population pairs
            b = [
                ts.divergence([sample_sets[i], sample_sets[j]], mode='branch')
                for i in range(len(pops))
                for j in range(i + 1, len(pops))
                if len(sample_sets[i]) and len(sample_sets[j])
            ]

            within[k] = np.mean(w)
            between[k] = np.mean(b)

        return float(1 - within.mean() / between.mean())

    def _branch_f_statistic(self, kind: str, pops: List[str]) -> float:
        """
        msprime branch-mode Patterson f-statistic ground truth (``f2``/``f3``/``f4``) over the given populations,
        averaged over replicate trees (tskit branch mode uses the same 2x pairwise-coalescence convention as the
        analytical :class:`Coalescent` f-statistics).

        :param kind: Name of the tskit statistic method (``'f2'``, ``'f3'`` or ``'f4'``).
        :param pops: Population names, in the order expected by the statistic.
        :return: Mean statistic over replicates.
        :raises ValueError: If a population name is unknown.
        """
        import msprime as ms

        names = self.demography.pop_names

        for pop in pops:
            if pop not in names:
                raise ValueError(f"Unknown population '{pop}'. Available populations: {names}.")

        idx = [names.index(pop) for pop in pops]
        values = np.zeros(self.num_replicates)

        for k, ts in enumerate(ms.sim_ancestry(
                samples=self.lineage_config.lineage_dict,
                sequence_length=1,
                demography=self.demography.to_msprime(),
                model=self.get_coalescent_model(),
                ploidy=1,
                # truncate at the end time, as the analytical f-statistics do
                end_time=self.end_time,
                num_replicates=self.num_replicates,
                random_seed=self.seed,
        )):
            sample_sets = [ts.samples(population=i) for i in idx]
            values[k] = getattr(ts, kind)(sample_sets, mode='branch')

        return float(values.mean())

    def f2(self, pop_0: str, pop_1: str) -> float:
        """msprime branch-mode ``f2`` ground truth. Matches :meth:`Coalescent.f2`."""
        return self._branch_f_statistic('f2', [pop_0, pop_1])

    def f3(self, pop_target: str, pop_0: str, pop_1: str) -> float:
        """msprime branch-mode ``f3`` ground truth. Matches :meth:`Coalescent.f3`."""
        return self._branch_f_statistic('f3', [pop_target, pop_0, pop_1])

    def f4(self, pop_0: str, pop_1: str, pop_2: str, pop_3: str) -> float:
        """msprime branch-mode ``f4`` ground truth. Matches :meth:`Coalescent.f4`."""
        return self._branch_f_statistic('f4', [pop_0, pop_1, pop_2, pop_3])

    def to_phasegen(self) -> Coalescent:
        """
        Convert to native phasegen coalescent.

        :return: phasegen coalescent.
        """
        return Coalescent(
            n=self.lineage_config,
            model=self.model,
            demography=self.demography,
            loci=self.locus_config,
            recombination_rate=self.locus_config.recombination_rate,
            end_time=self.end_time
        )