"""Distribution base classes and marginal (per-deme / per-locus) views."""
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from ..caching import cached_property
from typing import Iterator, Sequence, TYPE_CHECKING
import numpy as np
from ..expm import Backend
from ..rewards import DemeReward, LocusReward, CombinedReward
if TYPE_CHECKING:
from matplotlib import pyplot as plt
from .phase_type import PhaseTypeDistribution
expm = Backend.expm
logger = logging.getLogger('phasegen')
class ProbabilityDistribution(ABC):
"""
Abstract base class for probability distributions for which moments can be calculated.
"""
def __init__(self):
"""
Create object.
"""
#: Logger
self._logger = logger.getChild(self.__class__.__name__)
def touch(self, **kwargs: dict):
"""
Touch all cached properties.
:param kwargs: Additional keyword arguments.
"""
for cls in self.__class__.__mro__:
for attr, value in cls.__dict__.items():
if isinstance(value, cached_property):
getattr(self, attr)
class MomentAwareDistribution(ProbabilityDistribution, ABC):
"""
Abstract base class for probability distributions for which moments can be calculated.
"""
@abstractmethod
@cached_property
def mean(self) -> float:
"""
First moment / mean.
"""
pass
@abstractmethod
@cached_property
def var(self) -> float:
"""
Second central moment / variance.
"""
pass
@abstractmethod
@cached_property
def m2(self) -> float:
"""
Second (non-central) moment.
"""
pass
class MarginalDistributions(Mapping, ABC):
"""
Base class for marginal distributions.
"""
@abstractmethod
@cached_property
def cov(self) -> np.ndarray:
"""
Covariance matrix.
"""
pass
@abstractmethod
@cached_property
def corr(self) -> np.ndarray:
"""
Correlation matrix.
"""
pass
@abstractmethod
def get_cov(self, d1, d2) -> float:
"""
Get the covariance between two marginal distributions.
:param d1: The index of the first marginal distribution.
:param d2: The index of the second marginal distribution.
:return: covariance
"""
pass
@abstractmethod
def get_corr(self, d1, d2) -> float:
"""
Get the correlation coefficient between two marginal distributions.
:param d1: The index of the first marginal distribution.
:param d2: The index of the second marginal distribution.
:return: correlation coefficient
"""
pass
[docs]
class MarginalLocusDistributions(MarginalDistributions):
"""
Marginal locus distributions.
"""
[docs]
def __init__(self, dist: 'PhaseTypeDistribution'):
"""
Initialize the distributions.
:param dist: The distribution.
"""
self.dist = dist
def __getitem__(self, item):
"""
Get the distribution for the given locus.
:param item: Deme name.
:return: Distribution.
"""
return self.loci[item]
def __iter__(self) -> Iterator:
"""
Iterate over distributions.
:return: Iterator.
"""
return iter(self.loci)
def __len__(self) -> int:
"""
Get the number of distributions.
:return: Number of distributions.
"""
return len(self.loci)
@cached_property
def loci(self) -> 'MarginalLocusDistributions':
"""
Distributions marginalized over loci.
"""
# get class of distribution but use PhaseTypeDistribution
# if this is a TreeHeightDistribution as TreeHeightDistribution
# only works with default rewards
from .phase_type import PhaseTypeDistribution, TreeHeightDistribution
cls = self.dist.__class__ if not isinstance(self.dist, TreeHeightDistribution) else PhaseTypeDistribution
loci = {}
for locus in range(self.dist.locus_config.n):
loci[locus] = cls(
state_space=self.dist.state_space,
tree_height=self.dist.tree_height,
demography=self.dist.demography,
reward=CombinedReward([self.dist.reward, LocusReward(locus)])
)
return loci
[docs]
def get_cov(self, locus1: int, locus2: int) -> float:
"""
Get the covariance between two loci.
:param locus1: The first locus.
:param locus2: The second locus.
:return: The covariance.
"""
locus1 = int(locus1)
locus2 = int(locus2)
if locus1 not in range(self.dist.locus_config.n) or locus2 not in range(self.dist.locus_config.n):
raise ValueError(f"Locus {locus1} or {locus2} does not exist.")
return self.dist.moment(
k=2,
rewards=(
CombinedReward([self.dist.reward, LocusReward(locus1)]),
CombinedReward([self.dist.reward, LocusReward(locus2)])
),
center=True
)
@cached_property
def cov(self) -> np.ndarray:
"""
Covariance matrix across loci.
"""
n_loci = self.dist.locus_config.n
return np.array([[self.get_cov(i, j) for i in range(n_loci)] for j in range(n_loci)])
[docs]
def get_corr(self, locus1: int, locus2: int) -> float:
"""
Get the correlation coefficient between two loci.
:param locus1: The first locus.
:param locus2: The second locus.
:return: The correlation coefficient.
"""
locus1 = int(locus1)
locus2 = int(locus2)
return self.get_cov(locus1, locus2) / (self.loci[locus1].std * self.loci[locus2].std)
@cached_property
def corr(self) -> np.ndarray:
"""
Correlation matrix across loci.
"""
n_loci = self.dist.locus_config.n
return np.array([[self.get_corr(i, j) for i in range(n_loci)] for j in range(n_loci)])
[docs]
class MarginalDemeDistributions(MarginalDistributions):
"""
Marginal deme distributions.
"""
[docs]
def __init__(self, dist: 'PhaseTypeDistribution'):
"""
Initialize the distributions.
:param dist: The distribution.
"""
self.dist = dist
def __getitem__(self, item):
"""
Get the distribution for the given deme.
:param item: Deme name.
:return: Distribution.
"""
return self.demes[item]
def __iter__(self) -> Iterator:
"""
Iterate over distributions.
:return: Iterator.
"""
return iter(self.demes)
def __len__(self) -> int:
"""
Get the number of distributions.
:return: Number of distributions.
"""
return len(self.demes)
@cached_property
def demes(self) -> 'MarginalDemeDistributions':
"""
Distributions marginalized over demes.
"""
# get class of distribution but use PhaseTypeDistribution
# if this is a TreeHeightDistribution as TreeHeightDistribution
# only works with default rewards
from .phase_type import PhaseTypeDistribution, TreeHeightDistribution
cls = self.dist.__class__ if not isinstance(self.dist, TreeHeightDistribution) else PhaseTypeDistribution
demes = {}
for pop in self.dist.lineage_config.pop_names:
demes[pop] = cls(
state_space=self.dist.state_space,
tree_height=self.dist.tree_height,
demography=self.dist.demography,
reward=CombinedReward([self.dist.reward, DemeReward(pop)])
)
return demes
[docs]
def get_cov(self, pop1: str, pop2: str) -> float:
"""
Get the covariance between two demes.
:param pop1: The first deme.
:param pop2: The second deme.
:return: The covariance.
"""
if pop1 not in self.dist.lineage_config.pop_names or pop2 not in self.dist.lineage_config.pop_names:
raise ValueError(f"Population {pop1} or {pop2} does not exist.")
return self.dist.moment(
k=2,
rewards=(
CombinedReward([self.dist.reward, DemeReward(pop1)]),
CombinedReward([self.dist.reward, DemeReward(pop2)])
),
center=True
)
@cached_property
def cov(self) -> np.ndarray:
"""
Covariance matrix across demes.
"""
pops = self.dist.lineage_config.pop_names
return np.array([[self.get_cov(p1, p2) for p1 in pops] for p2 in pops])
[docs]
def get_corr(self, pop1: str, pop2: str) -> float:
"""
Get the correlation coefficient between two demes.
:param pop1: The first deme.
:param pop2: The second deme.
:return: The correlation coefficient.
"""
return self.get_cov(pop1, pop2) / (self.demes[pop1].std * self.demes[pop2].std)
@cached_property
def corr(self) -> np.ndarray:
"""
Correlation matrix across demes.
"""
pops = self.dist.lineage_config.pop_names
return np.array([[self.get_corr(p1, p2) for p1 in pops] for p2 in pops])
class DensityAwareDistribution(MomentAwareDistribution, ABC):
"""
Abstract base class for probability distributions for which moments and densities can be calculated.
"""
@abstractmethod
def cdf(self, t: float | Sequence[float]) -> float | np.ndarray:
"""
Cumulative distribution function.
:param t: Value or values to evaluate the CDF at.
:return: CDF.
"""
pass
@abstractmethod
def quantile(self, q: float) -> float:
"""
Get the qth quantile.
"""
pass
@abstractmethod
def pdf(self, t: float | Sequence[float], **kwargs: dict) -> float | np.ndarray:
"""
Density function.
:param t: Value or values to evaluate the density function at.
:param kwargs: Additional keyword arguments.
:return: Density.
"""
pass
def plot_cdf(
self,
ax: 'plt.Axes' = None,
t: np.ndarray = None,
show: bool = True,
file: str = None,
clear: bool = True,
label: str = None,
title: str = 'Tree height CDF'
) -> 'plt.Axes':
"""
Plot cumulative distribution function.
:param ax: Axes to plot on.
:param t: Values to evaluate the CDF at. By default, 200 evenly spaced values between 0 and the 99th percentile.
:param show: Whether to show the plot.
:param file: File to save the plot to.
:param clear: Whether to clear the plot before plotting.
:param label: Label for the plot.
:param title: Title of the plot.
:return: Axes.
"""
from ..visualization import Visualization
if t is None:
t = np.linspace(0, self.quantile(0.99), 200)
return Visualization.plot(
ax=ax,
x=t,
y=self.cdf(t),
xlabel='t',
ylabel='F(t)',
label=label,
file=file,
show=show,
clear=clear,
title=title
)
def plot_pdf(
self,
ax: 'plt.Axes' = None,
t: np.ndarray = None,
show: bool = True,
file: str = None,
clear: bool = True,
label: str = None,
title: str = 'Tree height PDF',
dx: float = None
) -> 'plt.Axes':
"""
Plot density function.
:param ax: The axes to plot on.
:param t: Values to evaluate the density function at.
By default, 200 evenly spaced values between 0 and the 99th percentile.
:param show: Whether to show the plot.
:param file: File to save the plot to.
:param clear: Whether to clear the plot before plotting.
:param label: Label for the plot.
:param title: Title of the plot.
:param dx: Step size for numerical differentiation. By default, the 99th percentile divided by 1e10.
:return: Axes.
"""
from ..visualization import Visualization
if dx is None:
dx = self.quantile(0.99) / 1e10
if t is None:
t = np.linspace(0, self.quantile(0.99), 200)
return Visualization.plot(
ax=ax,
x=t,
y=self.pdf(t, dx=dx),
xlabel='t',
ylabel='f(t)',
label=label,
file=file,
show=show,
clear=clear,
title=title
)