# Source code for phasegen.distributions.spectra

"""Site-frequency-spectrum distributions (SFS, folded, joint, two-locus)."""

import itertools
import logging
from abc import ABC, abstractmethod
from ..caching import cached_property, cache
from typing import List, Tuple, Iterable, Iterator, Optional, Sequence, Set, TYPE_CHECKING
import numpy as np
from ..demography import Demography
from ..expm import Backend
from ..rewards import Reward, TreeHeightReward, UnfoldedSFSReward, UnitReward, CombinedReward, FoldedSFSReward, SFSReward, JointSFSReward, TwoLocusSFSReward
from ..settings import Settings
from ..spectrum import SFS, SFS2, JointSFS, TwoLocusSFS
from ..state_space import BlockCountingStateSpace, StateSpace, JointBlockCountingStateSpace, TwoLocusBlockCountingStateSpace
from ..utils import multiset_permutations

from ._common import _make_hashable
from .phase_type import PhaseTypeDistribution, TreeHeightDistribution

if TYPE_CHECKING:
    from matplotlib import pyplot as plt

# Module-level alias for ``Backend.expm`` (presumably the matrix exponential of the active backend --
# confirm in ``..expm``).
expm = Backend.expm
# Logger registered under the package name 'phasegen'.
logger = logging.getLogger('phasegen')


class SFSDistribution(PhaseTypeDistribution, ABC):
    """
    Base class for site-frequency spectrum distributions.
    """

    def __init__(
            self,
            state_space: BlockCountingStateSpace,
            tree_height: TreeHeightDistribution,
            demography: Demography,
            reward: Reward = None
    ):
        """
        Set up the site-frequency spectrum distribution.

        :param state_space: Block-counting state space.
        :param tree_height: The tree height distribution.
        :param demography: The demography.
        :param reward: The reward that is combined with the per-bin SFS reward. When omitted, a
            :class:`UnitReward` is used, which has no effect.
        """
        super().__init__(
            state_space=state_space,
            tree_height=tree_height,
            demography=demography,
            reward=UnitReward() if reward is None else reward
        )

        #: Probability mass generated so far by the iterator returned from :meth:`get_mutation_configs`.
        self.generated_mass = 0

    @abstractmethod
    def _get_sfs_reward(self, i: int) -> SFSReward:
        """
        Build the reward selecting the ith site-frequency count. To be implemented by subclasses.

        :param i: Index of the site-frequency count.
        :return: The reward for that count.
        """

    @abstractmethod
    def _get_indices(self) -> np.ndarray:
        """
        Return the site-frequency counts covered by this spectrum. To be implemented by subclasses.

        :return: Array of indices.
        """

    @staticmethod
    @abstractmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Enumerate the mutational configurations that carry a given number of mutations. To be implemented
        by subclasses.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: All possible mutational configurations, one tuple of per-bin counts each.
        """

    @_make_hashable
    @cache
    def moment(
            self,
            k: int,
            rewards: Sequence[SFSReward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> SFS:
        """
        Compute the kth moment of every site-frequency count.

        :param k: The order of the moment.
        :param rewards: Sequence of k rewards. Defaults to k copies of the distribution's own reward.
        :param start_time: Time when to start accumulation of moments. By default, the start time specified when
            initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified when
            initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :return: A site-frequency spectrum of kth order moments, zero-padded to ``n + 1`` entries.
        """
        if rewards is None:
            rewards = (self.reward,) * k

        per_bin = None

        # Batched mean: each bin's mean is the occupation-time vector contracted with that bin's reward, so the
        # whole spectrum is a single matrix product. Only attempted for the plain mean to absorption (k=1, default
        # reward, no custom accumulation window) and when flattening does not apply.
        batched_mean = (
                Settings.closed_form_last_epoch and
                not self._flattening_applies(k) and
                k == 1 and
                start_time is None and
                end_time is None and
                self.tree_height.end_time is None and
                rewards == (self.reward,)
        )

        if batched_mean:
            occupation = self._occupation_times()

            if occupation is not None:
                occupation_times, transient = occupation
                base_reward = np.asarray(self.reward._get(self.state_space), dtype=float)

                columns = []
                for i in self._get_indices():
                    bin_reward = np.asarray(self._get_sfs_reward(i)._get(self.state_space), dtype=float)
                    columns.append((base_reward * bin_reward)[transient])

                per_bin = occupation_times @ np.column_stack(columns)

        # general case: one (serial) moment evaluation per SFS bin
        if per_bin is None:
            per_bin = np.array([
                self._moment(k, i, rewards, start_time, end_time, center, permute)
                for i in self._get_indices()
            ])

        # leading zero for bin 0, trailing zeros for the bins not covered by the indices
        trailing = [0] * (self.lineage_config.n - len(per_bin))

        return SFS([0] + list(per_bin) + trailing)

    def _moment(
            self,
            k: int,
            i: int,
            rewards: Sequence[SFSReward] = None,
            start_time: float = None,
            end_time: float = None,
            center: bool = True,
            permute: bool = True
    ) -> float:
        """
        Get the kth moment for the ith site-frequency count.

        :param k: The order of the moment
        :param i: The ith site-frequency count
        :param rewards: Sequence of k rewards. By default, ``k`` copies of the reward of the distribution.
        :param start_time: Time when to start accumulation of moments. By default, the start time specified when
            initializing the distribution.
        :param end_time: Time when to end accumulation of moments. By default, either the end time specified when
            initializing the distribution or the time until almost sure absorption.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :return: The kth SFS (cross)-moment at the ith site-frequency count
        """
        # honour the documented ``None`` default instead of failing when iterating over it
        if rewards is None:
            rewards = (self.reward,) * k

        # each reward is combined with the reward of the requested frequency bin
        combined = tuple(CombinedReward([r, self._get_sfs_reward(i)]) for r in rewards)

        return PhaseTypeDistribution.moment(
            self,
            k=k,
            rewards=combined,
            start_time=start_time,
            end_time=end_time,
            center=center,
            permute=permute
        )

    def accumulate(
            self,
            k: int,
            end_times: Iterable[float],
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray:
        """
        Evaluate the kth moments of the site-frequency spectrum at different end times.

        :param k: The order of the moment.
        :param end_times: Times or time when to evaluate the moment.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :return: Array of moments accumulated at the specified times, one row for each site-frequency count
            (``n + 1`` rows, the first and the trailing rows being zero padding).
        """
        k = int(k)
        indices = self._get_indices()
        end_times = np.array(list(end_times))

        # ``center`` and ``permute`` have to be forwarded explicitly, otherwise the per-bin accumulation
        # silently falls back to its own defaults
        accumulation = np.array([
            self.get_accumulation(k, i, end_times, rewards, center=center, permute=permute)
            for i in indices
        ])

        # pad with zeros
        return np.concatenate([
            np.zeros((1, len(end_times))),
            accumulation,
            np.zeros((self.lineage_config.n - len(indices), len(end_times)))
        ])

    def plot_accumulation(
            self,
            k: int = 1,
            end_times: Iterable[float] = None,
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = None
    ) -> 'plt.Axes':
        """
        Plot accumulation of SFS moments at different times, one curve per site-frequency count.

        .. note:: This is different from a CDF, as it shows the accumulation of moments rather than the probability
            of having reached absorption at a certain time.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0 and
            the 99th percentile.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :param ax: The axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot. Currently not used: each curve is labelled with its frequency count.
        :param title: Title of the plot.
        :return: Axes.
        """
        import matplotlib.pyplot as plt
        from ..visualization import Visualization

        k = int(k)

        if ax is None:
            ax = plt.gca()

        if end_times is None:
            end_times = np.linspace(0, self.tree_height.quantile(0.99), 200)

        # materialize once: a one-shot iterable would otherwise be consumed by ``accumulate`` and leave
        # nothing to plot on the x-axis
        end_times = np.asarray(list(end_times))

        if rewards is None:
            rewards = (self.reward,) * k

        if title is None:
            title = (f"SFS Moment accumulation "
                     f"({', '.join(r.__class__.__name__.replace('Reward', '') for r in rewards)})")

        indices = self._get_indices()

        # get accumulation of moments; row ``i`` holds the curve of frequency count ``i``
        accumulation = self.accumulate(k, end_times, rewards, center, permute)

        last = len(indices) - 1
        for pos, i in enumerate(indices):
            Visualization.plot(
                ax=ax,
                x=end_times,
                y=accumulation[i],
                xlabel='t',
                ylabel='moment',
                label=f'{i}',
                file=file,
                # only show once all curves have been drawn
                show=show and pos == last,
                clear=clear,
                title=title
            )

        return ax

    def get_accumulation(
            self,
            k: int,
            i: int,
            end_times: Iterable[float] | float,
            rewards: Sequence[SFSReward] = None,
            center: bool = True,
            permute: bool = True
    ) -> np.ndarray | float:
        """
        Accumulate the kth moment of the ith site-frequency count over time.

        :param k: The order of the moment
        :param i: The ith site-frequency count.
        :param end_times: Times or time when to evaluate the moment.
        :param rewards: Sequence of k rewards. Defaults to k copies of the distribution's own reward.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is ``True``,
            which will provide the correct cross-moment. If set to ``False``, the cross-moment will be conditioned on
            the order of rewards.
        :return: The kth SFS (cross)-moment accumulations at the ith site-frequency count
        """
        base_rewards = [self.reward] * k if rewards is None else rewards

        # restrict every reward to the requested frequency bin
        bin_rewards = tuple(
            CombinedReward([base, self._get_sfs_reward(i)])
            for base in base_rewards
        )

        return super().accumulate(
            k=k,
            end_times=end_times,
            rewards=bin_rewards,
            center=center,
            permute=permute
        )

    def _cov_batched(self) -> Optional[SFS2]:
        """
        Batched 2-SFS. All bin pairs share the two-point occupation operator ``K`` returned by
        :meth:`_two_point_occupation`, so the covariance is obtained as ``R^T (K + K^T) R - outer(mean)`` from
        one contraction over the stacked bin rewards ``R`` rather than from one cross-moment per pair.

        :return: The covariance, or ``None`` if the closed form is disabled or no two-point occupation operator
            is available, in which case the caller falls back to the per-pair path.
        """
        operator = self._two_point_occupation() if Settings.closed_form_last_epoch else None
        if operator is None:
            return None

        K, transient = operator
        state_space = self.state_space
        base_reward = np.asarray(self.reward._get(state_space), dtype=float)
        bins = self._get_indices()

        # one column per frequency bin, restricted to the states indexed by ``transient``
        R = np.column_stack([
            (base_reward * np.asarray(self._get_sfs_reward(i)._get(state_space), dtype=float))[transient]
            for i in bins
        ])

        # R^T K R covers one ordering of each pair; adding the transpose covers the other
        one_ordering = R.T @ K @ R
        self._logger.debug("sfs.cov: centering with the outer product of bin means")
        bin_means = np.asarray(self.mean.data)[bins]
        centered = (one_ordering + one_ordering.T) - np.outer(bin_means, bin_means)

        # embed into the full (n + 1) x (n + 1) matrix, leaving uncovered bins at zero
        size = self.lineage_config.n + 1
        padded = np.zeros((size, size))
        padded[np.ix_(bins, bins)] = centered

        return SFS2(padded)

    @cached_property
    def cov(self) -> SFS2:
        """
        Covariance matrix across site-frequency counts. Uses the batched closed form when available and
        otherwise evaluates one second-order cross-moment per pair of bins.
        """
        batched = self._cov_batched()
        if batched is not None:
            self._logger.debug("sfs.cov: batched (shared two-point occupation)")
            return batched

        # all ordered pairs (i, j) of frequency bins
        pairs = list(itertools.product(self._get_indices(), self._get_indices()))

        self._logger.debug("sfs.cov: per-pair matrix exponential over %d bin pairs", len(pairs))

        # uncentered, order-conditioned cross-moment of each bin pair (serial), stored in matrix form
        size = self.lineage_config.n + 1
        cross = np.zeros((size, size))
        for i, j in pairs:
            cross[i, j] = PhaseTypeDistribution.moment(self, k=2, permute=False, center=False, rewards=(
                CombinedReward([self.reward, self._get_sfs_reward(i)]),
                CombinedReward([self.reward, self._get_sfs_reward(j)])
            ))

        # product of the marginal means
        mean_products = np.outer(self.mean.data, self.mean.data)

        # average both orderings and subtract the product of means
        return SFS2((cross + cross.T) / 2 - mean_products)

    def get_cov(self, i: int, j: int) -> float:
        """
        Covariance between the ith and the jth site-frequency count.

        :param i: The ith frequency count
        :param j: The jth frequency count
        :return: The covariance; ``0`` if either count is ``0`` or ``n``.
        """
        n = self.lineage_config.n
        if i in (0, n) or j in (0, n):
            return 0

        reward_i = CombinedReward([self.reward, self._get_sfs_reward(i)])
        reward_j = CombinedReward([self.reward, self._get_sfs_reward(j)])

        return super().moment(k=2, rewards=(reward_i, reward_j), center=True)

    @cached_property
    def corr(self) -> SFS2:
        """
        Correlation matrix across site-frequency counts, with undefined (NaN) entries set to zero.
        """
        # per-bin standard deviations
        sd = np.sqrt(self.var.data)

        # bins without variance produce 0 / 0 below; those NaNs are expected and zeroed afterwards, so the
        # corresponding NumPy warnings are suppressed for the division only
        with np.errstate(divide='ignore', invalid='ignore'):
            correlation = SFS2(self.cov.data / np.outer(sd, sd))

        undefined = np.isnan(correlation.data)
        correlation.data[undefined] = 0

        return correlation

    def get_corr(self, i: int, j: int) -> float:
        """
        Get the correlation coefficient between the ith and jth site-frequency.

        :param i: The ith frequency count
        :param j: The jth frequency count
        :return: Correlation coefficient. ``0`` if either count is ``0`` or ``n``, or if either bin has no
            (positive) variance, mirroring :attr:`corr`, which reports undefined correlations as zero.
        """
        n = self.lineage_config.n
        if i in (0, n) or j in (0, n):
            return 0

        var_i = self.get_cov(i, i)
        var_j = self.get_cov(j, j)

        # a bin without positive variance (including NaN) has no defined correlation; avoid dividing by zero
        if not (var_i > 0 and var_j > 0):
            return 0

        return self.get_cov(i, j) / (np.sqrt(var_i) * np.sqrt(var_j))

    @cache
    def _get_P(self, n: int, theta: float) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build the matrices used to compute mutational configuration probabilities.

        :param n: The number of frequency bins.
        :param theta: The mutation rate.
        :return: Stack of ``n`` per-bin transition matrices and the exit vector.
        """
        # states with a non-zero tree height reward, i.e. the non-absorbing ones
        transient = TreeHeightReward()._get(self.state_space).astype(bool)

        ones_transient = self.state_space.e[transient]

        # per-bin rewards (rows) over the non-absorbing states (columns), and their sum over bins
        bin_rewards = np.array(
            [self._get_sfs_reward(i)._get(self.state_space) for i in range(1, n + 1)]
        )[:, transient]
        total_reward = bin_rewards.T @ np.ones(n)

        # rate matrix restricted to the non-absorbing states
        S = self.state_space.S[transient, :][:, transient]
        identity = np.eye(S.shape[0])

        scaled_rates = np.diag(1 / total_reward) / theta @ S
        P_total = np.linalg.inv(identity - scaled_rates)
        p_total = (identity - P_total) @ ones_transient

        # split the total matrix by the share each bin has in the total reward of a state
        P = np.array([P_total @ np.diag(row / total_reward) for row in bin_rewards])

        return P, p_total

    def get_mutation_config(self, config: Sequence[int], theta: float) -> float:
        """
        Get the probability of observing the given mutational configuration according to the infinite sites model.

        .. note::
            This currently only works for a single epoch, i.e. a time-homogeneous demography, and recombination is not
            supported.

        :param config: The mutational configuration. A sequence of integers of length n - 1 for unfolded configurations
            and n // 2 for folded configurations, where n is the number of
            lineages. Each element in the sequence is an integer representing the number of mutations
            at each frequency count starting from 1. For example, the unfolded configuration [2, 1, 0] represents two
            singleton, one doubleton and zero tripleton mutations for a sample size of 4 lineages. Similarly, the
            folded configuration [2, 1] represents two singleton or tripleton and one doubleton mutation for the same
            number of lineages.
        :param theta: The mutation rate.
        :return: The probability of observing the given mutational configuration.
        :raises NotImplementedError: If the demography has more than one epoch.
        :raises ValueError: If ``theta`` is negative, if ``config`` has the wrong length, or if it contains
            negative counts.
        """
        # raise not implemented error if more than one epoch
        if self.demography.has_n_epochs(2):
            raise NotImplementedError("Sampling not implemented for more than one epoch.")

        # make sure theta is non-negative
        if theta < 0:
            raise ValueError("Theta must be greater than or equal to 0.")

        # number of frequency bins
        n = len(self._get_configs(self.lineage_config.n, 0)[0])

        if len(config) != n:
            raise ValueError(
                "The length of the configuration must be equal to the number of frequency bins. "
                f"Expected {n}, got {len(config)}."
            )

        # explicitly convert to tuple of integers
        config = tuple(int(c) for c in config)

        # a negative count would otherwise be silently treated like a count of zero
        if any(c < 0 for c in config):
            raise ValueError("The mutation counts of the configuration must be non-negative.")

        # handle special case when theta = 0
        if theta == 0:
            if sum(config) == 0:
                return 1

            return 0

        # get non-absorbing states
        non_absorbing = TreeHeightReward()._get(self.state_space).astype(bool)

        # number of non-absorbing states
        k = non_absorbing.sum()

        alpha = self.state_space.alpha[non_absorbing]

        P, p_total = self._get_P(n, theta)

        # multiset of mutated frequency bins, e.g. (2, 1, 0) -> [1, 1, 2]
        q = list(itertools.chain.from_iterable([i + 1] * count for i, count in enumerate(config)))

        # sum the matrix products over all distinct orderings in which the mutations can occur
        Q = np.zeros((k, k))
        for ordering in multiset_permutations(q):
            U = np.eye(k)

            for bin_index in ordering:
                U = U @ P[bin_index - 1]

            Q += U

        return alpha @ Q @ p_total

    def get_mutation_configs(self, theta: float) -> Iterator[Tuple[Tuple[float, ...], float]]:
        """
        Iterate over mutational configurations and their probabilities under the infinite sites model.
        Configurations are generated in ascending order of the number of mutations observed.
        See :meth:`get_mutation_config` for more information on mutational configurations.

        .. note::
            This currently only works for a single epoch, i.e. a time-homogeneous demography, and recombination is not
            supported. Also note that the number of configurations is infinite, so this iterator will never stop.
            The probability mass generated so far is tracked in the :attr:`~.generated_mass` attribute, which is
            reset when iteration over the returned generator starts. A good approach is thus to keep generating
            configurations until the generated mass is above a certain threshold. More complex demographic models,
            larger sample sizes, and higher mutation rates all increase the number of generated configurations
            necessary to reach a certain mass.

        Code example:

        ::

            coal = pg.Coalescent(n=5)

            it = coal.sfs.get_mutation_configs(theta=1)

            # continue until generated mass is above 0.8
            samples = list(pg.takewhile_inclusive(lambda _: coal.sfs.generated_mass < 0.8, it))

        :param theta: The mutation rate.
        :return: An iterator over pairs of mutational configuration and its probability.
        """
        # start counting the generated mass from scratch
        self.generated_mass = 0

        # 0, 1, 2, ... mutations, without upper bound
        for n_mutations in itertools.count():
            for config in self._get_configs(self.lineage_config.n, n_mutations):
                probability = self.get_mutation_config(config=config, theta=theta)
                self.generated_mass += probability
                yield config, probability


class TajimaSFSMixin:
    """
    Mixin that derives the branch-length diversity estimators and Tajima's :math:`D` from the mean and the
    covariance of a site-frequency spectrum. The statistics only rely on the three hooks :meth:`_tajima_n`,
    :meth:`_tajima_mean` and :meth:`_tajima_cov`, which subclasses have to provide, so any source of SFS
    moments can reuse them.
    """

    def _tajima_n(self) -> int:
        """Return the number of lineages ``n``."""
        raise NotImplementedError

    def _tajima_mean(self) -> np.ndarray:
        """Return the mean branch length of each polymorphic SFS bin (``i = 1 .. n-1``)."""
        raise NotImplementedError

    def _tajima_cov(self) -> np.ndarray:
        """Return the covariance matrix of the polymorphic SFS bins (``i, j = 1 .. n-1``)."""
        raise NotImplementedError

    @cached_property
    def _tajima_weights(self) -> Tuple[np.ndarray, np.ndarray]:
        """Per-bin weights of the pairwise diversity ``pi`` and of Watterson's ``theta_W``, in that order."""
        n = self._tajima_n()
        freq = np.arange(1, n)

        pi_weights = 2 * freq * (n - freq) / (n * (n - 1))

        # every bin gets the same weight 1 / a_n with a_n = sum_{i=1}^{n-1} 1 / i
        harmonic = np.sum(1 / freq)
        watterson_weights = np.full(n - 1, 1 / harmonic)

        return pi_weights, watterson_weights

    @cached_property
    def theta_pi(self) -> float:
        r"""
        Mean pairwise diversity :math:`\pi = \sum_i \frac{2 i (n - i)}{n (n - 1)} \mathbb{E}[L_i]`, i.e. the
        branch-length estimator of :math:`\theta` built on the expected number of pairwise differences.
        """
        pi_weights = self._tajima_weights[0]

        return float(pi_weights @ self._tajima_mean())

    @cached_property
    def theta_w(self) -> float:
        r"""
        Watterson's estimator :math:`\theta_W = L_\text{total} / a_n` with :math:`a_n = \sum_{k=1}^{n-1} 1/k`,
        i.e. the branch-length estimator of :math:`\theta` built on the total branch length.
        """
        watterson_weights = self._tajima_weights[1]

        return float(watterson_weights @ self._tajima_mean())

    @cached_property
    def tajimas_d(self) -> float:
        r"""
        Tajima's :math:`D` in branch form: :math:`D = (\pi - \theta_W) / \sqrt{c^\top \, \mathrm{Cov}[L] \, c}`
        with weights :math:`c_i = \frac{2 i (n - i)}{n (n - 1)} - 1/a_n`. The normalization uses the branch-length
        covariance rather than the mutation-based variance of the classical sample estimator. ``0.0`` is
        returned when the normalizing variance is not positive.
        """
        pi_weights, watterson_weights = self._tajima_weights
        contrast = pi_weights - watterson_weights

        numerator = contrast @ self._tajima_mean()
        variance = contrast @ self._tajima_cov() @ contrast

        # no spread to normalize by
        return 0.0 if variance <= 0 else float(numerator / np.sqrt(variance))


class UnfoldedSFSDistribution(SFSDistribution, TajimaSFSMixin):
    """
    Unfolded site-frequency spectrum distribution.
    """

    def _get_sfs_reward(self, i: int) -> UnfoldedSFSReward:
        """
        Get the reward for the ith site-frequency count.

        :param i: The ith site-frequency count.
        :return: The reward.
        """
        return UnfoldedSFSReward(i)

    def _get_indices(self) -> np.ndarray:
        """
        Get the indices for the site-frequency spectrum.

        :return: The indices ``1 .. n - 1``, where ``n`` is the number of lineages.
        """
        return np.arange(1, self.lineage_config.n)

    def _tajima_n(self) -> int:
        """Number of lineages."""
        return self.lineage_config.n

    def _tajima_mean(self) -> np.ndarray:
        """Mean of the SFS bins ``1 .. n - 1``."""
        n = self.lineage_config.n

        return np.asarray(self.mean.data)[1:n]

    def _tajima_cov(self) -> np.ndarray:
        """Covariance of the SFS bins ``1 .. n - 1``."""
        n = self.lineage_config.n

        return np.asarray(self.cov.data)[1:n, 1:n]

    @staticmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Get all possible mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: All possible mutational configurations, as returned by
            ``StateSpace._get_partitions(n=k, k=n - 1)``.
        """
        return StateSpace._get_partitions(n=k, k=n - 1)
class FoldedSFSDistribution(SFSDistribution):
    """
    Folded site-frequency spectrum distribution.
    """

    def _get_sfs_reward(self, i: int) -> FoldedSFSReward:
        """
        Get the reward for the ith site-frequency count.

        :param i: The ith site-frequency count.
        :return: The reward.
        """
        return FoldedSFSReward(i)

    def _get_indices(self) -> np.ndarray:
        """
        Get the indices for the site-frequency spectrum.

        :return: The indices ``1 .. n // 2``, where ``n`` is the number of lineages.
        """
        return np.arange(1, self.lineage_config.n // 2 + 1)

    @staticmethod
    def _get_configs(n: int, k: int) -> List[Tuple[int, ...]]:
        """
        Get all possible mutational configurations for a given number of mutations.

        :param n: The number of lineages.
        :param k: The number of mutations.
        :return: All possible mutational configurations, as returned by
            ``StateSpace._get_partitions(n=k, k=n // 2)``.
        """
        return StateSpace._get_partitions(n=k, k=n // 2)

    def _unfold(self, config: Sequence[int]) -> Set[Tuple[int, ...]]:
        """
        Unfold a folded configuration into all possible unfolded configurations.

        :param config: The folded configuration. A sequence of integers of length n // 2 where n is the number of
            lineages.
        :return: The unfolded configurations, each a tuple of ``n - 1`` plain integers.
        :raises ValueError: If the length of ``config`` is not ``n // 2``.
        """
        n = self.lineage_config.n

        if n // 2 != len(config):
            raise ValueError("The length of the configuration must equal n // 2 where n is the number of lineages.")

        # for even n the last folded bin (n / 2) is its own mirror image and is therefore never split
        n_paired = len(config) if n % 2 == 1 else len(config) - 1
        paired = list(config[:n_paired])
        center = list(config[n_paired:])

        unfolded = set()

        # distribute the count of every paired bin between its lower and its mirrored higher frequency
        for lower in itertools.product(*[range(c + 1) for c in paired]):
            higher = [c - low for c, low in zip(paired, lower)][::-1]

            # plain ints keep NumPy scalar types out of the result
            unfolded.add(tuple(int(c) for c in (*lower, *center, *higher)))

        return unfolded
[docs] class JointSFSDistribution(PhaseTypeDistribution): """ Joint (multi-population) site-frequency spectrum distribution. Moments are returned as a multi-dimensional array of shape ``(n_0 + 1, ..., n_{P-1} + 1)``, where ``n_p`` is the sample size of population ``p``. The entry at index ``(k_0, ..., k_{P-1})`` is the moment for branches subtending exactly ``k_p`` samples from population ``p``. The monomorphic bins (the all-zero and the full ``(n_0,...,n_{P-1})`` configuration) are zero by convention. """
[docs] def __init__( self, state_space: JointBlockCountingStateSpace, tree_height: 'TreeHeightDistribution', demography: Demography, reward: Reward = None ): """ Initialize the distribution. :param state_space: Joint block-counting state space. :param tree_height: The tree height distribution. :param demography: The demography. :param reward: The reward to multiply the joint SFS reward with. By default, the unit reward is used, which has no effect. """ if reward is None: reward = UnitReward() super().__init__( state_space=state_space, tree_height=tree_height, demography=demography, reward=reward )
@cached_property def shape(self) -> Tuple[int, ...]: """ Shape of the joint SFS array, ``(n_0 + 1, ..., n_{P-1} + 1)``. """ return tuple(int(n_p) + 1 for n_p in self.lineage_config.lineages) def _get_configs(self) -> List[Tuple[int, ...]]: """ Get the descendant vectors corresponding to (polymorphic) joint SFS bins, i.e. all block configurations except the full-sample configuration (which corresponds to the monomorphic, fixed sites). :return: List of descendant vectors. """ full = tuple(int(n_p) for n_p in self.lineage_config.lineages) return [c for c in self.state_space.block_configs if c != full]
[docs] def moment( self, k: int, start_time: float = None, end_time: float = None, center: bool = True, permute: bool = True ) -> np.ndarray: """ Get the kth moments of the joint site-frequency spectrum. :param k: The order of the moment. :param start_time: Time when to start accumulation of moments. By default, the start time specified when initializing the distribution. :param end_time: Time when to end accumulation of moments. By default, either the end time specified when initializing the distribution or the time until almost sure absorption. :param center: Whether to center the moment around the mean. :param permute: For cross-moments, whether to average over all permutations of rewards. :return: An array of shape :attr:`shape` holding the kth moment of each joint SFS bin. """ # batched mean: all joint bins share one occupation-time vector, so the whole joint SFS mean is a single # contraction over the stacked bin rewards (closed form's spectrum path). Only for the plain mean to # absorption; other cases fall through to the per-bin accumulation. 
if ( Settings.closed_form_last_epoch and int(k) == 1 and start_time is None and end_time is None and self.tree_height.end_time is None ): occupation = self._occupation_times() if occupation is not None: m, idx_t = occupation base = np.asarray(self.reward._get(self.state_space), dtype=float) configs = self._get_configs() R = np.column_stack([ (base * np.asarray(JointSFSReward(config)._get(self.state_space), dtype=float))[idx_t] for config in configs ]) values = m @ R out = np.zeros(self.shape) for config, value in zip(configs, values): out[config] = value return JointSFS(out, pop_names=self.lineage_config.pop_names) # like the base distribution, a moment is the accumulation over the [start_time, end_time] window if start_time is None: start_time = self.tree_height.start_time if end_time is None: # evaluate the moment to absorption: signal the closed-form path with an infinite end time when it # applies (no explicit end time, accumulation from 0, and absorption certain in the last epoch), but not # when flattening applies (which takes precedence and delegates to the smaller lineage-counting space), # otherwise use the estimated absorption time if ( Settings.closed_form_last_epoch and not self._flattening_applies(k) and start_time == 0 and self.tree_height.end_time is None and self._absorption_certain_in_last_epoch() ): end_time = np.inf else: end_time = self.tree_height.t_max if start_time > 0: acc = self.accumulate(k, [start_time, end_time], center=center, permute=permute) out = acc[..., 1] - acc[..., 0] else: out = self.accumulate(k, [end_time], center=center, permute=permute)[..., 0] if np.isnan(out).any(): raise ValueError( "NaN value encountered when computing moment. " "This is likely due to an ill-conditioned rate matrix." ) return JointSFS(out, pop_names=self.lineage_config.pop_names)
def accumulate(
        self,
        k: int,
        end_times: Iterable[float],
        center: bool = True,
        permute: bool = True
) -> np.ndarray:
    """
    Evaluate the kth moments of the joint site-frequency spectrum at a series of end times.

    :param k: The order of the moment.
    :param end_times: Times at which the moments are evaluated.
    :param center: Whether to center the moment around the mean.
    :param permute: For cross-moments, whether to average over all permutations of rewards.
    :return: Array of shape :attr:`shape` ``+ (len(end_times),)``; the trailing axis runs over the end times,
        and bins not returned by :meth:`_get_configs` stay zero.
    """
    order = int(k)
    times = np.array(list(end_times))
    bins = self._get_configs()

    # one accumulation curve per bin, using the bin's combined reward repeated ``order`` times
    curves = []
    for config in bins:
        rewards = tuple(
            CombinedReward([self.reward, JointSFSReward(config)])
            for _ in range(order)
        )
        curves.append(
            PhaseTypeDistribution.accumulate(
                self,
                k=order,
                end_times=times,
                rewards=rewards,
                center=center,
                permute=permute
            )
        )

    result = np.zeros(self.shape + (len(times),))
    for config, curve in zip(bins, curves):
        result[config] = curve

    return result
def plot_accumulation(
        self,
        k: int = 1,
        end_times: Iterable[float] = None,
        center: bool = True,
        permute: bool = True,
        ax: 'plt.Axes' = None,
        show: bool = True,
        file: str = None,
        clear: bool = True,
        title: str = None
) -> 'plt.Axes':
    """
    Plot how the joint SFS moments accumulate over time, drawing one curve per bin returned by
    :meth:`_get_configs`.

    :param k: The order of the moment.
    :param end_times: Times at which to evaluate the moment. Defaults to 200 evenly spaced points from 0 to
        the 0.99 quantile of the tree height.
    :param center: Whether to center the moment around the mean.
    :param permute: For cross-moments, whether to average over all permutations of rewards.
    :param ax: The axes to plot on. Defaults to the current matplotlib axes.
    :param show: Whether to show the plot. Only forwarded as such for the last curve.
    :param file: File to save the plot to.
    :param clear: Whether to clear the plot before plotting.
    :param title: Title of the plot. Defaults to a title naming the moment order.
    :return: Axes.
    """
    import matplotlib.pyplot as plt
    from ..visualization import Visualization

    order = int(k)
    axes = plt.gca() if ax is None else ax

    if end_times is None:
        end_times = np.linspace(0, self.tree_height.quantile(0.99), 200)
    times = np.asarray(list(end_times))

    heading = f"Joint SFS moment accumulation (order {order})" if title is None else title

    bins = self._get_configs()
    curves = self.accumulate(order, times, center=center, permute=permute)

    # every curve is drawn with show disabled except the last one, which receives the caller's flag
    last = len(bins) - 1
    for position, config in enumerate(bins):
        Visualization.plot(
            ax=axes,
            x=times,
            y=curves[config],
            xlabel='t',
            ylabel='moment',
            label=str(config),
            file=file,
            show=show if position == last else False,
            clear=clear,
            title=heading
        )

    return axes
@cached_property
def mean(self) -> JointSFS:
    """
    First moment of the joint site-frequency spectrum, i.e. the expected branch length subtending each
    descendant configuration.
    """
    return self.moment(k=1)

@cached_property
def var(self) -> JointSFS:
    """
    Variance of the joint site-frequency spectrum.

    Read off the diagonal of the batched covariance when :attr:`_cov_batched` provides one, otherwise
    computed as the centered second moment.
    """
    batched = self._cov_batched

    # no batched covariance available: fall back to the per-bin centered second moment
    if batched is None:
        return self.moment(k=2, center=True)

    bins, covariance = batched
    variances = np.zeros(self.shape)
    for position, config in enumerate(bins):
        variances[config] = covariance[position, position]

    return JointSFS(variances, pop_names=self.lineage_config.pop_names)
def get_cov(self, config_a: Tuple[int, ...], config_b: Tuple[int, ...]) -> float:
    """
    Covariance between the branch lengths subtending two descendant configurations.

    Evaluated as the centered second cross-moment of the two bins' combined rewards.

    :param config_a: First descendant configuration.
    :param config_b: Second descendant configuration.
    :return: The covariance.
    """
    reward_a = CombinedReward([self.reward, JointSFSReward(config_a)])
    reward_b = CombinedReward([self.reward, JointSFSReward(config_b)])

    return PhaseTypeDistribution.moment(
        self,
        k=2,
        center=True,
        rewards=(reward_a, reward_b)
    )
@cached_property
def _cov_batched(self) -> Optional[Tuple[List[Tuple[int, ...]], np.ndarray]]:
    """
    Batched joint-SFS covariance.

    All bin pairs share the single two-point occupation operator ``K`` returned by
    :meth:`_two_point_occupation`, so the covariance of every pair follows from one contraction over the
    stacked bin rewards ``R``: ``cov = R^T K R + (R^T K R)^T - outer(mean, mean)``. Being a cached property,
    :attr:`cov` and :attr:`var` reuse the same result.

    :return: ``(configs, cov)`` with ``cov`` the bins-by-bins covariance over ``configs``, or ``None`` when
        the closed form is disabled in the settings or :meth:`_two_point_occupation` returns ``None``, so
        that callers fall back.
    """
    if not Settings.closed_form_last_epoch:
        return None

    two_point = self._two_point_occupation()
    if two_point is None:
        return None

    operator, state_idx = two_point
    space = self.state_space
    scale = np.asarray(self.reward._get(space), dtype=float)
    bins = self._get_configs()

    # one column per bin: the bin reward scaled by the base reward, restricted to ``state_idx``
    columns = []
    for config in bins:
        bin_reward = np.asarray(JointSFSReward(config)._get(space), dtype=float)
        columns.append((scale * bin_reward)[state_idx])
    stacked = np.column_stack(columns)

    # R^T K R covers one ordering of each pair; its transpose covers the other
    one_ordering = stacked.T @ operator @ stacked

    self._logger.debug("jsfs.cov: centering with the outer product of bin means")
    mean_data = self.mean.data
    means = np.array([mean_data[config] for config in bins])

    covariance = (one_ordering + one_ordering.T) - np.outer(means, means)

    return bins, covariance

@cached_property
def cov(self) -> np.ndarray:
    """
    Covariance between the branch lengths of all pairs of joint SFS bins.

    The result has shape :attr:`shape` ``+`` :attr:`shape`; entry
    ``cov[a_0, ..., a_{P-1}, b_0, ..., b_{P-1}]`` is the covariance between bins ``(a_0, ..., a_{P-1})``
    and ``(b_0, ..., b_{P-1})``. Uses :attr:`_cov_batched` when available and otherwise evaluates
    :meth:`get_cov` for every pair of bins.
    """
    result = np.zeros(self.shape + self.shape)
    batched = self._cov_batched

    if batched is not None:
        self._logger.debug("jsfs.cov: batched (shared two-point occupation)")
        bins, covariance = batched

        # scatter the dense bins-by-bins matrix into the full array
        for (a, config_a), (b, config_b) in itertools.product(enumerate(bins), repeat=2):
            result[tuple(config_a) + tuple(config_b)] = covariance[a, b]

        return result

    bins = self._get_configs()
    pairs = list(itertools.product(bins, repeat=2))
    self._logger.debug("jsfs.cov: per-pair matrix exponential over %d config pairs", len(pairs))

    for config_a, config_b in pairs:
        result[tuple(config_a) + tuple(config_b)] = self.get_cov(config_a, config_b)

    return result
class TwoLocusSFSDistribution(PhaseTypeDistribution):
    """
    Site-frequency spectrum across two loci under recombination.

    Entry ``(i, j)`` of the (symmetrized) mean is ``E[L^0_i · L^1_j]``: the expected product of the branch
    length subtending ``i`` samples at locus 0 and the branch length subtending ``j`` samples at locus 1. It is
    obtained as a second cross-moment of two per-locus SFS rewards on the two-locus block-counting state space.
    It reduces to ``Coalescent.sfs.cov`` (plus the outer product of the marginal means) as ``r → 0`` and to the
    outer product of the marginal SFS as ``r → ∞``.
    """

    def __init__(
            self,
            state_space: TwoLocusBlockCountingStateSpace,
            tree_height: 'TreeHeightDistribution',
            demography: Demography,
            reward: Reward = None
    ):
        """
        Set up the distribution.

        :param state_space: Two-locus block-counting state space.
        :param tree_height: The (two-locus) tree height distribution, whose absorption time is when both loci
            have reached their MRCA.
        :param demography: The demography.
        :param reward: Optional reward multiplied with the per-locus SFS rewards. Defaults to the unit reward.
        """
        super().__init__(
            state_space=state_space,
            tree_height=tree_height,
            demography=demography,
            reward=UnitReward() if reward is None else reward
        )

    @cached_property
    def shape(self) -> Tuple[int, ...]:
        """
        Shape of the two-locus SFS array, ``(n + 1, n + 1)``, with one axis per locus.
        """
        size = int(self.lineage_config.n) + 1

        return size, size

    def _get_indices(self) -> List[int]:
        """
        Polymorphic SFS bins ``1, ..., n - 1``; the monomorphic bins ``0`` and ``n`` are left out.
        """
        return list(range(1, self.lineage_config.n))

    @cached_property
    def mean(self) -> TwoLocusSFS:
        """
        Mean two-locus SFS: ``E[L^0_i · L^1_j]`` for every pair of polymorphic bins, symmetrized over the
        two loci.
        """
        n = self.lineage_config.n
        bins = self._get_indices()

        cross = np.zeros((n + 1, n + 1))
        for i, j in itertools.product(bins, repeat=2):
            # uncentered, unpermuted second cross-moment of the locus-0 bin i and locus-1 bin j rewards
            cross[i, j] = PhaseTypeDistribution.moment(
                self,
                k=2,
                permute=False,
                center=False,
                rewards=(
                    CombinedReward([self.reward, TwoLocusSFSReward(0, i)]),
                    CombinedReward([self.reward, TwoLocusSFSReward(1, j)])
                )
            )

        # symmetrize over the two (exchangeable) loci, as for the single-locus SFS covariance
        return TwoLocusSFS((cross + cross.T) / 2)

    @cached_property
    def corr(self) -> TwoLocusSFS:
        """
        Pearson correlation between the locus-0 and locus-1 branch lengths,
        ``Corr(L^0_i, L^1_j) = (E[L^0_i L^1_j] - E[L^0_i] E[L^1_j]) / (sd(L^0_i) sd(L^1_j))``, for all
        polymorphic bins ``(i, j)``.

        This is the centered, scale-free counterpart of :attr:`mean`, which holds the *uncentered*
        cross-moment ``E[L^0_i L^1_j]`` and therefore tends to the outer product of the marginal SFS means as
        the loci decouple. It is ``0`` as ``r → ∞`` (independent loci) and reduces to the single-locus SFS
        correlation as ``r → 0`` (fully linked). Means and variances are taken from the locus-0 rewards only
        and reused for locus 1, relying on the exchangeability of the two loci. Entries whose denominator is
        not positive are left at ``0``.
        """
        bins = self._get_indices()
        n = self.lineage_config.n

        def locus0_reward(i: int) -> Reward:
            # combined reward of bin ``i`` at locus 0
            return CombinedReward([self.reward, TwoLocusSFSReward(0, i)])

        # marginal locus-0 mean per bin (reused for locus 1 by exchangeability)
        marginal_mean = {}
        for i in bins:
            marginal_mean[i] = PhaseTypeDistribution.moment(
                self,
                k=1,
                center=False,
                rewards=(locus0_reward(i),)
            )

        # marginal locus-0 variance per bin
        marginal_var = {}
        for i in bins:
            marginal_var[i] = PhaseTypeDistribution.moment(
                self,
                k=2,
                center=True,
                rewards=(locus0_reward(i),) * 2
            )

        cross = self.mean.data

        correlation = np.zeros((n + 1, n + 1))
        for i, j in itertools.product(bins, repeat=2):
            denom = np.sqrt(marginal_var[i] * marginal_var[j])

            # skip bins without a positive denominator; they keep their zero entry
            if denom > 0:
                correlation[i, j] = (cross[i, j] - marginal_mean[i] * marginal_mean[j]) / denom

        return TwoLocusSFS(correlation)