Source code for phasegen.distributions.phase_type

"""Phase-type distribution (moment engine) and the tree-height distribution."""

import logging
from ..caching import cached_property, cache
from typing import Tuple, Collection, Iterable, Sequence, Union, TYPE_CHECKING
import numpy as np
from tqdm import tqdm
from ..demography import Demography, Epoch
from ..expm import Backend
from ..lineage import LineageConfig
from ..locus import LocusConfig
from ..rewards import Reward, TreeHeightReward
from ..settings import Settings
from ..spectrum import SFS
from ..state_space import LineageCountingStateSpace, StateSpace

from .base import DensityAwareDistribution, MarginalDemeDistributions, MarginalLocusDistributions, MomentAwareDistribution
from ._moments import MomentEvaluator

if TYPE_CHECKING:
    from matplotlib import pyplot as plt

# package-wide logger shared by the distribution classes
logger = logging.getLogger(name='phasegen')
# matrix-exponential function taken from ``Backend``; used below to propagate transition matrices
expm = Backend.expm


class PhaseTypeDistribution(MomentEvaluator, MomentAwareDistribution):
    """
    Phase-type distribution for a piecewise time-homogeneous process.
    """

    def __init__(
            self,
            state_space: StateSpace,
            tree_height: 'TreeHeightDistribution',
            demography: Demography = None,
            reward: Reward = None
    ):
        """
        Initialize the distribution.

        :param state_space: The state space.
        :param tree_height: The tree height distribution.
        :param demography: The demography. By default, ``Demography()``.
        :param reward: The reward. By default, the tree height reward.
        """
        if demography is None:
            demography = Demography()

        if reward is None:
            reward = TreeHeightReward()

        super().__init__()

        #: Population configuration
        self.lineage_config: LineageConfig = state_space.lineage_config

        #: Locus configuration
        self.locus_config: LocusConfig = state_space.locus_config

        #: Reward
        self.reward: Reward = reward

        #: State space
        self.state_space: StateSpace = state_space

        #: Demography
        self.demography: Demography = demography

        #: Tree height distribution
        self.tree_height: TreeHeightDistribution = tree_height

    @cached_property
    def mean(self) -> float | SFS:
        """
        First moment / mean.
        """
        return self.moment(k=1)

    @cached_property
    def var(self) -> float | SFS:
        """
        Second central moment / variance.
        """
        return self.moment(k=2, center=True)

    @cached_property
    def std(self) -> float | SFS:
        """
        Standard deviation.
        """
        return self.var ** 0.5

    @cached_property
    def m2(self) -> float | SFS:
        """
        Second (non-central) moment.
        """
        return self.moment(k=2, center=False)

    @cached_property
    def demes(self) -> MarginalDemeDistributions:
        """
        Marginal distributions over each deme.
        """
        return MarginalDemeDistributions(self)

    @cached_property
    def loci(self) -> MarginalLocusDistributions:
        """
        Marginal distributions over each locus.
        """
        return MarginalLocusDistributions(self)

    def _sample(
            self,
            n_samples: int,
            rewards: Sequence[Reward] = None,
            record_visits: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Generate samples from the mean reward distribution by simulating trajectories.

        Holding times are drawn from the exit rate of the current state. When a holding time crosses an epoch
        boundary, the remaining time is rescaled to the exit rate of the next epoch. Epochs in which the current
        state has no outgoing transitions (exit rate zero, e.g. lineages temporarily isolated in different demes)
        are traversed in full: reward is accumulated for their whole duration but no holding time is consumed.

        :param n_samples: Number of trajectories to simulate.
        :param rewards: Rewards to sample from. Default is the reward of this distribution.
        :param record_visits: Whether to record which states were visited during the sampling.
        :return: Array of sampled rewards of size (n_samples, len(rewards)), and optionally an array holding, for
            each state, the average number of times it was entered by a jump per trajectory.
        """
        if rewards is None:
            rewards = [self.reward]

        n_rewards = len(rewards)
        samples = np.zeros((n_samples, n_rewards))
        absorbing = np.array([s.is_absorbing() for s in self.state_space.states])
        alpha = self.state_space.alpha
        states_visited = np.zeros_like(alpha)

        # reward matrix: one row per reward, one column per state
        R = np.array([r._get(self.state_space) for r in rewards])

        # iterate over samples
        for i in tqdm(range(n_samples), disable=not Settings.use_pbar):
            samples[i] = self._simulate_trajectory(R, alpha, absorbing, states_visited)

        # normalize states visited
        states_visited /= n_samples

        return (samples, states_visited) if record_visits else samples

    def _simulate_trajectory(
            self,
            R: np.ndarray,
            alpha: np.ndarray,
            absorbing: np.ndarray,
            states_visited: np.ndarray
    ) -> np.ndarray:
        """
        Simulate a single trajectory until absorption (or until the demography runs out of epochs) and return the
        reward it accumulated.

        :param R: Reward matrix with one row per reward and one column per state.
        :param alpha: Initial state distribution.
        :param absorbing: Boolean mask of absorbing states.
        :param states_visited: Visit counts; incremented in place for every state entered by a jump.
        :return: Accumulated reward, one entry per row of ``R``.
        """
        mass = np.zeros(R.shape[0])
        t = 0
        state = np.random.choice(len(alpha), p=alpha)

        # a fresh epoch iterator is requested for every trajectory
        epochs = self.demography.epochs

        try:
            epoch = next(epochs)
            self.state_space.update_epoch(epoch)

            while not absorbing[state]:
                rate = -self.state_space.S[state, state]

                if rate == 0:
                    # no transition is possible in this epoch: collect reward until its end and move on
                    mass += R[:, state] * (epoch.end_time - t)
                    t = epoch.end_time
                    epoch = next(epochs)
                    self.state_space.update_epoch(epoch)
                    continue

                # sample holding time (in units of the current exit rate)
                dt = np.random.exponential(1 / rate)

                # carry the holding time across epoch boundaries
                while t + dt >= epoch.end_time:
                    # reward until epoch boundary
                    mass += R[:, state] * (epoch.end_time - t)
                    dt -= epoch.end_time - t
                    t = epoch.end_time

                    # advance epoch
                    epoch = next(epochs)
                    self.state_space.update_epoch(epoch)
                    new_rate = -self.state_space.S[state, state]

                    # the clock is paused in epochs without outgoing transitions; only reward accumulates
                    while new_rate == 0:
                        mass += R[:, state] * (epoch.end_time - t)
                        t = epoch.end_time
                        epoch = next(epochs)
                        self.state_space.update_epoch(epoch)
                        new_rate = -self.state_space.S[state, state]

                    # rescale remaining time to the new exit rate
                    dt *= rate / new_rate
                    rate = new_rate

                # the jump completes within the current epoch
                mass += R[:, state] * dt
                t += dt

                # sample next state proportionally to the off-diagonal rates
                probs = self.state_space.S[state].copy()
                probs[state] = 0
                state = np.random.choice(len(probs), p=probs / rate)
                states_visited[state] += 1

        except StopIteration:
            # the demography has no further epochs
            pass

        return mass

    def plot_accumulation(
            self,
            k: int = 1,
            end_times: Iterable[float] = None,
            rewards: Sequence[Reward] = None,
            center: bool = True,
            permute: bool = True,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = None
    ) -> 'plt.Axes':
        """
        Plot accumulation of (non-central) moments at different times.

        .. note::
            This is different from a CDF, as it shows the accumulation of moments rather than the probability of
            having reached absorption at a certain time.

        :param k: The order of the moment.
        :param end_times: Times when to evaluate the moment. By default, 200 evenly spaced values between 0 and
            the 99th percentile.
        :param rewards: Sequence of k rewards. By default, the reward of the underlying distribution.
        :param center: Whether to center the moment around the mean.
        :param permute: For cross-moments, whether to average over all permutations of rewards. Default is
            ``True``, which will provide the correct cross-moment. If set to ``False``, the cross-moment will be
            conditioned on the order of rewards.
        :param ax: The axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :return: Axes.
        """
        k = int(k)

        from ..visualization import Visualization

        if end_times is None:
            end_times = np.linspace(0, self.tree_height.quantile(0.99), 200)

        if rewards is None:
            rewards = (self.reward,) * k

        if title is None:
            title = f"Moment accumulation ({', '.join(r.__class__.__name__.replace('Reward', '') for r in rewards)})"

        y = self.accumulate(k, end_times, rewards, center, permute)

        # return the axes as documented (previously the result was discarded)
        return Visualization.plot(
            ax=ax,
            x=end_times,
            y=y,
            xlabel='t',
            ylabel='moment',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )
class TreeHeightDistribution(PhaseTypeDistribution, DensityAwareDistribution):
    """
    Phase-type distribution for a piecewise time-homogeneous process that allows the computation of the density
    function. This is currently only possible with default rewards.
    """
    #: Maximum number of epochs to consider when determining time to almost sure absorption.
    max_epochs: int = 10000

    #: Maximum number of times we double the end time when determining time to almost sure absorption.
    max_iter: int = 20

    #: Probability of almost sure absorption.
    p_absorption: float = 1 - 1e-15

    def __init__(
            self,
            state_space: LineageCountingStateSpace,
            demography: Demography = None,
            start_time: float = 0,
            end_time: float = None
    ):
        """
        Initialize the distribution.

        :param state_space: The state space.
        :param demography: The demography.
        :param start_time: Time when to start accumulating moments.
        :param end_time: Time when to end accumulation of moments. By default, the time until almost sure
            absorption.
        :raises ValueError: If the start or end time is negative, or the end time precedes the start time.
        """
        if start_time < 0:
            raise ValueError("Start time must be greater than or equal to 0.")

        if end_time is not None and end_time < 0:
            raise ValueError("End time must be greater than or equal to 0.")

        if end_time is not None and end_time < start_time:
            raise ValueError("End time must be greater than equal start time.")

        super().__init__(
            state_space=state_space,
            tree_height=self,
            demography=demography,
            reward=TreeHeightReward()
        )

        #: State space
        self.state_space: LineageCountingStateSpace = state_space

        #: Start time
        self.start_time: float = start_time

        #: End time
        self.end_time: float | None = end_time

    def cdf(self, t: float | Sequence[float]) -> float | np.ndarray:
        """
        Cumulative distribution function.

        :param t: Value or values to evaluate the CDF at. Any array-like (list, tuple, NumPy array) is accepted;
            the result has the same shape as the input.
        :return: Cumulative probability
        :raises NotImplementedError: if rewards are not default
        :raises ValueError: if any value is negative
        """
        # raise error if rewards are not default
        if not isinstance(self.reward, TreeHeightReward):
            raise NotImplementedError("PDF not implemented for non-default rewards.")

        # assume scalar if not array
        if not isinstance(t, Iterable):
            return self.cdf(np.array([t]))[0]

        # convert to a flat float array so comparisons and fancy indexing also work for lists and tuples
        t = np.asarray(t, dtype=float)
        shape = t.shape
        t = t.ravel()

        # check for negative values
        if np.any(t < 0):
            raise ValueError("Negative values are not allowed.")

        # evaluate in ascending order so the transition matrix can be propagated incrementally;
        # ``order[i]`` is the original index of the i-th smallest value
        order = np.argsort(t)
        t_sorted = t[order]

        epochs = enumerate(self.demography.epochs)
        i_epoch, epoch = next(epochs)

        # get the transition matrix for the first epoch
        self.state_space.update_epoch(epoch)

        # initialize transition matrix
        T = np.eye(self.state_space.k)
        u_prev = 0

        # initialize probabilities
        probs_sorted = np.zeros_like(t_sorted)

        # take reward vector as exit vector
        e = self.reward._get(self.state_space)

        # iterate through sorted values
        for i, u in enumerate(t_sorted):

            # iterate over epochs between u_prev and u
            while u > epoch.end_time:
                self._check_numerical_stability(self.state_space.S, i_epoch)

                # update transition matrix with remaining time in current epoch
                T = T @ expm(self._dense_rate_matrix() * (epoch.end_time - u_prev))

                # fetch and update for next epoch
                u_prev = epoch.end_time
                i_epoch, epoch = next(epochs)
                self.state_space.update_epoch(epoch)

            self._check_numerical_stability(self.state_space.S, i_epoch)

            # update transition matrix with remaining time in current epoch
            T = T @ expm(self._dense_rate_matrix() * (u - u_prev))

            probs_sorted[i] = 1 - self.state_space.alpha @ T @ e
            u_prev = u

        # scatter back to the original positions (inverse of the sorting permutation)
        probs = np.empty_like(probs_sorted)
        probs[order] = probs_sorted

        if np.isnan(probs).any():
            self._logger.critical(
                "NaN values in CDF. This is likely due to an ill-conditioned rate matrix."
            )

        return probs.reshape(shape)

    def _update(
            self,
            u: float,
            u_prev: float,
            T: np.ndarray,
            epoch: 'Epoch'
    ) -> Tuple[float, np.ndarray, 'Epoch']:
        """
        Propagate the transition matrix from time ``u_prev`` to time ``u``, crossing epochs as necessary.

        :param u: Time to update to.
        :param u_prev: Previous time.
        :param T: Transition matrix at ``u_prev``.
        :param epoch: Epoch containing ``u_prev``.
        :return: Updated time, transition matrix, and epoch.
        """
        self.state_space.update_epoch(epoch)

        while u > epoch.end_time:
            # update transition matrix with remaining time in current epoch
            tau = epoch.end_time - u_prev
            T = T @ expm(self._dense_rate_matrix() * tau)
            u_prev = epoch.end_time

            # fetch and update for next epoch
            epoch = self.demography.get_epoch(epoch.end_time)
            self.state_space.update_epoch(epoch)

        # update transition matrix with the time remaining in the final epoch
        T = T @ expm(self._dense_rate_matrix() * (u - u_prev))

        return u, T, epoch

    @cached_property
    def _e(self) -> np.ndarray:
        """
        Exit vector.
        """
        return self.reward._get(self.state_space)

    def _cum(self, T: np.ndarray) -> float:
        """
        Get cumulative probability for given transition matrix.

        :param T: Transition matrix.
        :return: Cumulative probability.
        """
        return float(1 - self.state_space.alpha @ T @ self._e)

    @cache
    def quantile(
            self,
            q: float,
            expansion_factor: float = 2,
            precision: float = 1e-5,
            max_iter: int = 1000
    ):
        """
        Find the specified quantile of a CDF using an adaptive bisection method.

        :param q: The desired quantile (between 0 and 1).
        :param expansion_factor: Factor by which to multiply the upper bound that does not yet contain the
            quantile.
        :param precision: Precision for quantile, i.e. ``F(b) - F(a) < precision``.
        :param max_iter: Maximum number of iterations.
        :return: The quantile.
        :raises ValueError: If ``q`` is outside [0, 1] or ``expansion_factor`` is not greater than 1.
        :raises RuntimeError: If ``max_iter`` is exhausted before the quantile is bracketed and resolved.
        """
        if q < 0 or q > 1:
            raise ValueError("Specified quantile must be between 0 and 1.")

        if expansion_factor <= 1:
            raise ValueError("Expansion factor must be greater than 1.")

        # finite upper bound for the search: the time of almost-sure absorption (any quantile q < 1 lies below it).
        # This also guards against a demography that never absorbs — ``_get_absorption_time`` raises in that case —
        # and keeps the expansion below from doubling ``b`` to an overflow-inducing ceiling. A user-supplied end
        # time bounds the (necessarily proper) distribution instead.
        b_max = self.end_time if self.end_time is not None else self._get_absorption_time()

        # initialize bounds
        a, b = 0, 1
        T_a = np.eye(self.state_space.k)
        epoch_a, epoch_b = self.demography.get_epoch(0), self.demography.get_epoch(0)
        b, T_b, epoch_b = self._update(min(b, b_max), a, T_a, epoch_b)

        i = 0

        # expand the upper bound until its CDF reaches q (bounded by the absorption time, so it always terminates)
        while self._cum(T_b) < q and b < b_max and i < max_iter:
            b, T_b, epoch_b = self._update(min(b * expansion_factor, b_max), b, T_b, epoch_b)
            i += 1

        # use bisection method within the determined bounds
        while self._cum(T_b) - self._cum(T_a) > precision and i < max_iter:
            m, T_m, epoch_m = self._update((a + b) / 2, a, T_a, epoch_a)

            if self._cum(T_m) < q:
                a, T_a, epoch_a = m, T_m, epoch_m
            else:
                b, T_b, epoch_b = m, T_m, epoch_m

            i += 1

        # the iteration budget only matters if it cut the search short, i.e. the quantile is either not
        # bracketed yet or not resolved to the requested precision
        if i >= max_iter:
            bracketed = self._cum(T_b) >= q or b >= b_max
            resolved = self._cum(T_b) - self._cum(T_a) <= precision

            if not (bracketed and resolved):
                raise RuntimeError("Maximum number of iterations reached when determining quantile.")

        return (a + b) / 2

    def pdf(self, t: float | Sequence[float], dx: float = None) -> float | np.ndarray:
        """
        Density function. We use numerical differentiation of the CDF to calculate the density. This provides
        good results as the CDF is exact and continuous.

        :param t: Value or values to evaluate the density function at.
        :param dx: Step size for numerical differentiation. By default, the 99th percentile divided by 1e10.
        :return: Density
        """
        if dx is None:
            dx = self.quantile(0.99) / 1e10

        if isinstance(t, Iterable):
            t = np.array(t)

        # determine (non-negative) evaluation points; near 0 this becomes a one-sided difference
        x1 = np.max([t - dx / 2, np.zeros_like(t)], axis=0)
        x2 = x1 + dx

        return (self.cdf(x2) - self.cdf(x1)) / dx

    @cached_property
    def t_max(self) -> float:
        """
        Time until which computations are performed. This is either the end time specified when initializing the
        distribution or the time until almost sure absorption.

        :raises ValueError: If the determined absorption time is smaller than the start time.
        """
        if self.end_time is not None:
            return self.end_time

        t_abs = self._get_absorption_time()

        if t_abs < self.start_time:
            raise ValueError(
                f"Determined time of almost sure absorption ({t_abs:.1f}) "
                f"is smaller than start time ({self.start_time:.1f}). "
                "The start time may be too large or the demography not well-defined."
            )

        return t_abs

    def _get_absorption_time(self) -> float:
        """
        Get a time estimate for when we have reached absorption almost surely. We base this computation on the
        transition matrix rather than the moments, because here we have a good idea about how likely absorption
        is, and can warn the user if necessary. Stopping the computation when no more rewards are accumulated is
        not a good idea, as this can happen before almost sure absorption (exponential runaway growth, temporary
        isolation in different demes).

        :return: Time at which the probability of absorption reached ``p_absorption`` (or the last time tried).
        """
        i = 0
        T = np.eye(self.state_space.k)
        epoch = self.demography.get_epoch(0)

        self._check_demography_conditioning()

        # start at a power of two close to the mean population size of the first epoch
        t = 2 ** int(np.log2(np.mean(list(epoch.pop_sizes.values()))))
        expansion_factor = 2

        t, T, epoch = self._update(t, 0, T, epoch)
        p = self._cum(T)

        # multiply time by expansion_factor until we reach p_absorption
        while p < self.p_absorption and i < self.max_iter:
            t, T, epoch = self._update(t * expansion_factor, t, T, epoch)
            p = self._cum(T)

            if np.isnan(p):
                self._logger.critical(
                    "Could not reliably find time of almost sure absorption "
                    "as probability of absorption is NaN. "
                    "This is likely due to an ill-conditioned rate matrix. "
                    f"Using time {t:.1f}. "
                )

            i += 1

        # a finite p below the target means the loop ran out of iterations: fail loudly for a demography that
        # *never* absorbs (see :meth:`_assert_absorbs`), and otherwise warn that the estimate is unreliable
        if p < self.p_absorption and not np.isnan(p):
            self._assert_absorbs(T)

            self._logger.warning(
                "Could not reliably find time of almost sure absorption after maximum number of iterations. "
                f"Using time {t:.1f} with probability of absorption 1 - {1 - p:.1e}. "
                "This could be due to numerical imprecision, unreachable states or very large or small "
                "absorption times. You can set the end time manually (see `Coalescent.end_time`) or increase "
                "the maximum number of iterations (`TreeHeightDistribution.max_iter`)."
            )

        return t

    def _empirical_cdf(self, n_samples: int, reward: Reward = None, t: float | Sequence[float] = None) -> np.ndarray:
        """
        Estimate the cumulative distribution function (CDF) by sampling from the distribution.

        :param n_samples: Number of samples to generate.
        :param reward: Reward function to use for sampling. If not specified, the default reward of the
            distribution is used.
        :param t: Values at which to evaluate the CDF. Default to 100 evenly spaced values between 0 and the 99th
            percentile.
        :return: Empirical CDF of the sampled total rewards, linearly interpolated at ``t``.
        """
        if t is None:
            t = np.linspace(0, self.tree_height.quantile(0.99), 100)

        # ``_sample`` expects a sequence of rewards, not a single reward
        rewards = None if reward is None else [reward]
        samples = self._sample(n_samples, rewards).reshape(n_samples)

        x = np.sort(samples)
        y = np.arange(1, n_samples + 1) / n_samples

        return np.interp(t, x, y)

    def _plot_empirical_cdf(
            self,
            n_samples: int = 1000,
            reward: Reward = None,
            t: float | Sequence[float] = None,
            ax: 'plt.Axes' = None,
            show: bool = True,
            file: str = None,
            clear: bool = True,
            label: str = None,
            title: str = 'Empirical CDF'
    ) -> 'plt.Axes':
        """
        Plot the empirical cumulative distribution function (CDF).

        :param n_samples: Number of samples to generate.
        :param reward: Reward function to use for sampling. If not specified, the default reward of the
            distribution is used.
        :param t: Values at which to evaluate the CDF. Default to 100 evenly spaced values between 0 and the 99th
            percentile.
        :param ax: Axes to plot on.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param clear: Whether to clear the plot before plotting.
        :param label: Label for the plot.
        :param title: Title of the plot.
        :return: Axes.
        """
        from ..visualization import Visualization

        if t is None:
            t = np.linspace(0, self.tree_height.quantile(0.99), 100)

        y = self._empirical_cdf(n_samples, reward, t)

        return Visualization.plot(
            ax=ax,
            x=t,
            y=y,
            xlabel='t',
            ylabel='F(t)',
            label=label,
            file=file,
            show=show,
            clear=clear,
            title=title
        )