# Source code for phasegen.demography

"""
Demographic events and demography class.
"""

import itertools
import logging
from abc import abstractmethod, ABC
from collections import defaultdict
from functools import cached_property
from typing import List, Callable, Dict, Iterable, Tuple, Any, Iterator, Sequence

import numpy as np

logger = logging.getLogger('phasegen')


class Demography:
    """
    Class storing full demographic information.

    A demography is a list of :class:`DemographicEvent` objects. The epochs (time intervals with
    constant population sizes and migration rates) are derived lazily from these events.
    """
    #: Population names.
    pop_names: List[str]

    #: Number of populations.
    n_pops: int

    def __init__(
            self,
            events: List['DemographicEvent'] = None,
            pop_sizes: Dict[str, Dict[float, float]] | Dict[str, float] | float = None,
            migration_rates: Dict[Tuple[str, str], Dict[float, float]] | Dict[Tuple[str, str], float] = None,
            warn_n_epochs: int = 20
    ):
        """
        Initialize the demography.

        :param events: List of demographic events.
        :param pop_sizes: Population sizes. Either a dictionary of the form
            ``{pop_i: {time1: size1, time2: size2}}``, indexed by population name and time at which the
            population size changes, or a dictionary of the form ``{pop_i: size}`` if the population size
            is constant, or a single float if there is only one population and the population size is constant.
        :param migration_rates: Migration rates. A dictionary of the form
            ``{(pop_i, pop_j): {time1: rate1, time2: rate2}}`` of migration from population ``pop_i`` to
            population ``pop_j`` at time ``time1`` etc. or alternatively a dictionary of the form
            ``{(pop_i, pop_j): rate}`` if the migration rate is constant over time.
        :param warn_n_epochs: Threshold for the number of epochs considered after which a warning is issued.
        """
        if events is None:
            events = []

        if pop_sizes is None:
            pop_sizes = {}

        # wrap population size in dictionary if it is a single float
        elif isinstance(pop_sizes, (float, int)):
            pop_sizes = {'pop_0': {0: pop_sizes}}

        # wrap population size in dictionary if only one time per population is given
        # (an explicitly passed empty dictionary is left untouched instead of raising an IndexError)
        elif isinstance(pop_sizes, dict) and pop_sizes and \
                isinstance(next(iter(pop_sizes.values())), (float, int)):
            pop_sizes = {p: {0: s} for p, s in pop_sizes.items()}

        if migration_rates is None:
            migration_rates = {}

        # wrap migration rate in dictionary if only one time per migration pair is given
        elif isinstance(migration_rates, dict) and migration_rates and \
                isinstance(next(iter(migration_rates.values())), (float, int)):
            migration_rates = {(p, q): {0: r} for (p, q), r in migration_rates.items()}

        #: The logger instance
        self._logger = logger.getChild(self.__class__.__name__)

        #: Threshold for the number of epochs considered after which a warning is issued.
        self.warn_n_epochs: int = int(warn_n_epochs)

        #: Whether a warning about the number of epochs has been already issued.
        self._issued_warning = False

        #: Array of demographic events.
        self.events: List[DemographicEvent] = list(events)

        # add population size and migration rate changes if specified
        if len(pop_sizes) or len(migration_rates):
            self.events += [DiscreteRateChanges(pop_sizes=pop_sizes, migration_rates=migration_rates)]

        # prepare events
        self._prepare_events()

        # issue warning if multiple populations are specified but no migration rates are given
        if self.n_pops > 1 and migration_rates == {} and len(events) == 0:
            self._logger.warning(
                'Multiple populations are specified, but no migration rates were given so far. '
                'Initializing with zero migration rates between all populations. '
                'Note that this may lead to infinite coalescence times if not changed later.'
            )

    def _prepare_events(self):
        """
        Sort events by start time and determine population names and number of populations.
        """
        # sort events by start time
        self.events = sorted(self.events, key=lambda e: e.start_time)

        # determine population names
        self.pop_names = sorted({p for e in self.events for p in e.pop_names})

        # determine number of populations
        self.n_pops = len(self.pop_names)

    def to_msprime(
            self,
            max_epochs: int = 1000
    ) -> 'msprime.Demography':
        """
        Convert to an Msprime demography object.

        :param max_epochs: Maximum number of epochs to use. Note that the number of epochs may be infinite.
        :return: msprime demography object.
        :raise ImportError: If Msprime is not installed.
        """
        try:
            import msprime as ms
        except ImportError:
            raise ImportError('Msprime must be installed to use this method.')

        self._prepare_events()

        first_epoch = next(self.epochs)

        # create demography object from the rates of the first epoch
        d = ms.Demography(
            populations=[
                ms.Population(name=pop, initial_size=first_epoch.pop_sizes[pop]) for pop in self.pop_names
            ],
            migration_matrix=np.array(
                [[first_epoch.migration_rates[(p, q)] for q in self.pop_names] for p in self.pop_names]
            )
        )

        # ``self.epochs`` creates a fresh generator, so the first epoch is skipped here
        for epoch in itertools.islice(self.epochs, 1, int(max_epochs) + 1):

            # add population size changes
            for pop in self.pop_names:
                # noinspection PyTypeChecker
                d.add_population_parameters_change(
                    time=epoch.start_time,
                    initial_size=epoch.pop_sizes[pop],
                    population=pop
                )

            # add migration rate changes
            for (p, q) in itertools.product(self.pop_names, repeat=2):
                if p != q:
                    # noinspection all
                    d.add_migration_rate_change(
                        time=epoch.start_time,
                        rate=epoch.migration_rates[(p, q)],
                        source=p,
                        dest=q
                    )

        # sort events by time
        d.sort_events()

        return d

    def _to_demes(self) -> 'demes.Graph':
        """
        Convert to demes object
        (see https://tskit.dev/msprime/docs/stable/api.html#msprime.Demography.to_demes).

        TODO: msprime raises an error when converting to demes (migration[0]: invalid migration)

        :return: Demes object.
        :raise ImportError: If msprime is not installed.
        """
        return self.to_msprime().to_demes()

    @property
    def epochs(self) -> Iterator['Epoch']:
        """
        Get a generator for the epochs.

        Every access creates a new generator starting at time 0. The generator may be infinite.
        """
        self._prepare_events()

        # zero-length dummy epoch providing default rates for the first real epoch
        prev = Epoch(
            start_time=0,
            end_time=0,
            pop_sizes={p: 1 for p in self.pop_names},
            migration_rates={k: 0 for k in itertools.product(self.pop_names, repeat=2)}
        )

        i = 0
        while True:

            # issue warning if number of epochs exceeds threshold
            if i == self.warn_n_epochs and not self._issued_warning:
                self._logger.warning(
                    f'Number of epochs considered exceeds {self.warn_n_epochs}. '
                    'Note that the runtime increases linearly with the number of epochs.'
                )

                self._issued_warning = True

            # potential next epoch which inherits the rates of the previous epoch
            epoch = Epoch(
                start_time=prev.end_time,
                end_time=np.inf,
                pop_sizes=prev.pop_sizes,
                migration_rates=prev.migration_rates
            )

            # let every event adjust the end time of the epoch
            for e in self.events:
                e._broadcast(epoch)

            # apply the events to the epoch
            for e in self.events:
                e._apply(epoch)

            yield epoch

            prev = epoch

            if epoch.end_time == np.inf:
                break

            i += 1

    def has_n_epochs(self, n: int) -> bool:
        """
        Check whether the demography has at least `n` epochs.

        :param n: Number of epochs.
        :return: Whether the demography has at least `n` epochs.
        """
        # get epoch iterator
        epochs = self.epochs

        for _ in range(int(n)):
            try:
                next(epochs)
            except StopIteration:
                return False

        return True

    def get_epochs(self, t: Iterable[float]) -> Sequence['Epoch']:
        """
        Get the epochs at the given times.

        :param t: Times.
        :return: Array of epochs, in the same order as the given times.
        :raise ValueError: If a time is negative or not finite, as no epoch encloses such a time.
        """
        times = np.asarray(list(t), dtype=float)

        # Negative, NaN or infinite times are not enclosed by any epoch, which would exhaust
        # the epoch iterator or never terminate if the number of epochs is infinite.
        if not np.all(np.isfinite(times) & (times >= 0)):
            raise ValueError('Times must be non-negative and finite.')

        # indices visiting the times in ascending order
        order = np.argsort(times, kind='stable')

        # get epoch iterator
        iterator: Iterator[Epoch] = self.epochs

        # get first epoch
        epoch = next(iterator)

        # initialize array of epochs
        epochs = np.empty(len(times), dtype=object)

        for idx in order:
            time = times[idx]

            # wind forward until we reach the epoch enclosing the current time
            while not epoch.start_time <= time < epoch.end_time:
                epoch = next(iterator)

            # Store at the original position. Indexing the sorted result with ``argsort(t)``
            # would apply the sorting permutation twice instead of inverting it.
            epochs[idx] = epoch

        return epochs

    def get_epoch(self, t: float = 0) -> 'Epoch':
        """
        Get the epoch at the given time.

        :param t: Time.
        :return: Epoch.
        """
        return self.get_epochs([t])[0]

    def add_events(self, events: List['DemographicEvent']):
        """
        Add demographic events.

        :param events: List of demographic events.
        """
        self.events += events

        self._prepare_events()

    def add_event(self, event: 'DemographicEvent'):
        """
        Add a demographic event.

        :param event: Demographic event.
        """
        self.add_events([event])

    def plot_pop_sizes(
            self,
            t: np.ndarray = None,
            show: bool = True,
            file: str = None,
            title: str = 'Population size trajectory',
            ylabel: str = '$N_e$',
            ax: 'plt.Axes' = None,
            kwargs: dict = None
    ) -> 'plt.Axes':
        """
        Plot the population size over time.

        :param t: Times at which to plot the population sizes. By default, we use 1000 time points
            between time 0 and 10.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param title: Title of the plot.
        :param ylabel: Label of the y-axis.
        :param ax: Axes object to plot to.
        :param kwargs: Keyword arguments to pass to the plotting function.
        :return: Axes object.
        """
        from .visualization import Visualization

        if t is None:
            t = np.linspace(0, 10, 1000)

        if kwargs is None:
            kwargs = {}

        # one row per population after transposing
        sizes = np.array([[e.pop_sizes[p] for p in self.pop_names] for e in self.get_epochs(t)]).T

        return Visualization.plot_rates(
            times=list(t),
            rates=dict(zip(self.pop_names, sizes)),
            show=show,
            file=file,
            title=title,
            ylabel=ylabel,
            kwargs=kwargs,
            ax=ax
        )

    def plot_migration(
            self,
            t: np.ndarray = None,
            show: bool = True,
            file: str = None,
            title: str = 'Migration rate trajectory',
            ylabel: str = '$m_{ij}$',
            ax: 'plt.Axes' = None,
            kwargs: dict = None
    ) -> 'plt.Axes':
        """
        Plot the migration over time.

        :param t: Times at which to plot the migration rates. By default, we use 1000 time points
            between time 0 and 10.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param title: Title of the plot.
        :param ylabel: Label of the y-axis.
        :param ax: Axes object to plot to.
        :param kwargs: Keyword arguments to pass to the plotting function.
        :return: Axes object.
        """
        from .visualization import Visualization

        if t is None:
            t = np.linspace(0, 10, 1000)

        if kwargs is None:
            kwargs = {}

        # get all pairs of populations
        pops = [(p, q) for p in self.pop_names for q in self.pop_names if p != q]

        # one row per pair of populations after transposing
        rates = np.array([[e.migration_rates[p] for p in pops] for e in self.get_epochs(t)]).T

        return Visualization.plot_rates(
            times=list(t),
            rates=dict(zip([f"{p[0]}->{p[1]}" for p in pops], rates)),
            show=show,
            file=file,
            title=title,
            ylabel=ylabel,
            kwargs=kwargs,
            ax=ax
        )

    def plot(
            self,
            t: np.ndarray = None,
            show: bool = True,
            file: str = None,
            ylabel: str = '$N_e, m_{ij}$',
            ax: 'plt.Axes' = None,
            title: str = 'Demography',
            kwargs: dict = None
    ) -> 'plt.Axes':
        """
        Plot the demographic scenario.

        :param t: Times at which to plot the population sizes and migration rates. By default, we use
            1000 time points between time 0 and 10.
        :param show: Whether to show the plot.
        :param file: File to save the plot to.
        :param ylabel: Label of the y-axis.
        :param ax: Axes object to plot to.
        :param title: Title of the plot.
        :param kwargs: Keyword arguments to pass to the plotting function.
        :return: Axes object.
        """
        from matplotlib import pyplot as plt

        if t is None:
            t = np.linspace(0, 10, 1000)

        if kwargs is None:
            kwargs = {}

        if ax is None:
            _, ax = plt.subplots()

        # only the second call shows and saves the plot so that both trajectories share the axes
        self.plot_pop_sizes(t=t, show=False, ax=ax, title=title, ylabel=ylabel, kwargs=kwargs)
        self.plot_migration(t=t, show=show, file=file, ax=ax, title=title, ylabel=ylabel, kwargs=kwargs)

        return ax
class Epoch:
    """
    Epoch of a demographic scenario with constant population sizes and migration rates.
    """
    #: Start time of the epoch.
    start_time: float

    #: End time of the epoch.
    end_time: float

    #: Population sizes.
    pop_sizes: Dict[str, float]

    #: Migration rates.
    migration_rates: Dict[Tuple[str, str], float]

    def __init__(
            self,
            start_time: float = 0,
            end_time: float = np.inf,
            pop_sizes: Dict[str, float] = None,
            migration_rates: Dict[Tuple[str, str], float] = None
    ):
        """
        Initialize the epoch.

        :param start_time: Start time of the epoch.
        :param end_time: End time of the epoch.
        :param pop_sizes: Population sizes. By default, we have ``{'pop_0': 1}``.
        :param migration_rates: Migration rates. By default, we have zero migration rates between
            all populations.
        """
        if pop_sizes is None:
            pop_sizes = {'pop_0': 1}

        if migration_rates is None:
            migration_rates = {}

        #: Start time of the epoch.
        self.start_time: float = start_time

        #: End time of the epoch.
        self.end_time: float = end_time

        #: Population sizes (copied so that the caller's dictionary is not modified later on).
        self.pop_sizes: Dict[str, float] = pop_sizes.copy()

        #: Population names.
        self.pop_names: List[str] = sorted(self.pop_sizes.keys())

        #: Number of populations.
        self.n_pops: int = len(self.pop_names)

        migration_rates = migration_rates.copy()

        # fill non-existing migration rates with zero
        for p in self.pop_sizes:
            for q in self.pop_sizes:
                if p != q and (p, q) not in migration_rates:
                    migration_rates[(p, q)] = 0

        #: Migration rates.
        self.migration_rates: Dict[Tuple[str, str], float] = migration_rates

    @property
    def tau(self) -> float:
        """
        Time interval of the epoch.

        Computed on every access as the end time of an epoch may be adjusted after its creation.
        """
        return self.end_time - self.start_time

    def __eq__(self, other):
        """
        Compare epochs by their population sizes and migration rates. Start and end time are
        not taken into account, in line with :meth:`__hash__`.

        :param other: The other epoch.
        :return: Whether the epochs are equal.
        """
        if not isinstance(other, Epoch):
            return NotImplemented

        # compare the rates directly rather than their hashes, which may collide
        return self.pop_sizes == other.pop_sizes and self.migration_rates == other.migration_rates

    def __hash__(self):
        """
        Hash the epoch. Note that we do not include the start and end time, since they are not
        relevant for the state space created from the epoch.

        :return: Hash of the epoch.
        """
        # frozensets make the hash independent of the insertion order of the dictionaries
        return hash((
            frozenset(self.pop_sizes.items()),
            frozenset(self.migration_rates.items())
        ))

    def __str__(self):
        """
        String representation of the epoch.

        :return: String representation.
        """
        pop_sizes = ', '.join(f'{p}={s:.4g}' for p, s in self.pop_sizes.items())

        string = (
            f"Epoch(start_time={self.start_time:.4g}, "
            f"end_time={self.end_time:.4g}, "
            f"pop_sizes=({pop_sizes})"
        )

        # migration rates are only informative if there is more than one population
        if self.n_pops > 1:
            rates = ', '.join(f'{p}->{q}={r:.4g}' for (p, q), r in self.migration_rates.items())

            string += f", migration_rates=({rates})"

        # close the parenthesis opened by ``Epoch(``
        return string + ")"

    def to_string(self):
        """
        Alias for :meth:`__str__`.

        :return: String representation.
        """
        return str(self)
class DemographicEvent(ABC):
    """
    Base class for (discrete) demographic events.
    """
    #: Start time of the event.
    start_time: float

    #: Population names.
    pop_names: List[str]

    @abstractmethod
    def _apply(self, epoch: 'Epoch'):
        """
        Apply the demographic event to the given epoch if applicable.

        :param epoch: Epoch.
        """
        pass

    @abstractmethod
    def _broadcast(self, epoch: 'Epoch'):
        """
        Adjust the end time of the epoch to the next time at which the rate changes due to this event.

        :param epoch: Epoch.
        """
        pass

    @staticmethod
    def _flatten(
            rates: Dict[Any, Dict[float, float]]
    ) -> Tuple[np.ndarray, Dict[float, Dict[Any, float]]]:
        """
        Regroup rates that are indexed by key and then time so that they are indexed by time and then key.

        :param rates: Dictionary mapping key to dictionary mapping times to rates.
        :return: Sorted array of all unique times and dictionary mapping each time to a dictionary
            mapping key to rate.
        """
        # collect every time mentioned for any key
        all_times = np.array([time for by_time in rates.values() for time in by_time], dtype=float)

        # unique times in ascending order
        times_all = np.sort(np.unique(all_times))

        # rates grouped by time
        grouped: Dict[float, Dict[Any, float]] = defaultdict(dict)

        for time in times_all:
            for key, by_time in rates.items():

                # only keys which specify a rate at this time are recorded
                if time in by_time:
                    grouped[time][key] = by_time[time]

        return times_all, dict(grouped)
class DiscreteDemographicEvent(DemographicEvent, ABC):
    """
    Base class for discrete demographic events.
    """
    #: Time at which the events occur in ascending order.
    times: np.ndarray

    def _broadcast(self, epoch: Epoch):
        """
        Shorten the epoch so that it ends at the next time at which this event changes a rate.

        :param epoch: Epoch.
        """
        event_times = self.times

        # positive event times lying in the interval (start_time, end_time] of the epoch
        in_epoch = (
                (event_times > epoch.start_time) &
                (event_times <= epoch.end_time) &
                (event_times > 0)
        )

        upcoming: np.ndarray = event_times[in_epoch]

        # as the times are in ascending order, the first element is the next change
        if len(upcoming):
            epoch.end_time = upcoming[0]
class DiscreteRateChanges(DiscreteDemographicEvent):
    """
    Demographic event for discrete changes in population sizes and migration rates.
    """

    def __init__(
            self,
            pop_sizes: Dict[str, Dict[float, float]] = None,
            migration_rates: Dict[Tuple[str, str], Dict[float, float]] = None
    ):
        """
        Initialize the rate changes.

        :param pop_sizes: Population sizes. A dictionary of the form
            `{pop_i: {time1: size1, time2: size2}}`, indexed by population name.
        :param migration_rates: Migration rates. A dictionary of the form
            `{(pop_i, pop_j): {time1: rate1, time2: rate2}}` of migration from population `pop_i` to
            population `pop_j` at time `time1` etc.
        :raise ValueError: If an argument is not a dictionary, if neither population sizes nor migration
            rates are given, if no time is given, if a time or migration rate is negative, or if a
            population size is not positive.
        """
        if pop_sizes is None:
            pop_sizes = {}

        if migration_rates is None:
            migration_rates = {}

        if not isinstance(pop_sizes, dict):
            raise ValueError('Population sizes must be a dictionary.')

        if not isinstance(migration_rates, dict):
            raise ValueError('Migration rates must be a dictionary.')

        if len(pop_sizes) == 0 and len(migration_rates) == 0:
            raise ValueError('Either one population size or migration rate must be specified.')

        # make sure population sizes are positive
        for sizes in pop_sizes.values():
            if any(s <= 0 for s in sizes.values()):
                raise ValueError('Population sizes must be positive at all times.')

        #: Population names.
        self.pop_names: List[str] = sorted(
            set(pop_sizes.keys()).union({p for k in migration_rates for p in k})
        )

        #: Number of populations / demes.
        self.n_pops: int = len(self.pop_names)

        # Regroup population sizes and migration rates by time. Population names (strings) and
        # migration pairs (tuples) cannot collide as keys of the merged dictionary.
        times: np.ndarray
        rates: Dict[float, Dict[Any, float]]
        times, rates = self._flatten(pop_sizes | migration_rates)

        # without any time there is no rate change and no start time
        if len(times) == 0:
            raise ValueError('At least one time must be specified.')

        # check that all times are non-negative
        if np.any(times < 0):
            raise ValueError('All times must not be negative.')

        # check that all migration rates are non-negative
        if any(r < 0 for by_time in migration_rates.values() for r in by_time.values()):
            raise ValueError('Migration rates must not be negative at all times.')

        #: Times at which the rate changes occur.
        self.times: np.ndarray = times

        #: Population sizes at each time.
        self.pop_sizes: Dict[float, Dict[str, float]] = {
            t: {x: pops[x] for x in self.pop_names if x in pops} for t, pops in rates.items()
        }

        #: Migration rates at each time.
        self.migration_rates: Dict[float, Dict[Tuple[str, str], float]] = {
            t: {
                (p, q): rates[t][(p, q)]
                for p in self.pop_names for q in self.pop_names if (p, q) in rates[t]
            } for t in rates
        }

        #: Start time of the event.
        self.start_time: float = self.times[0]

    def _apply(self, epoch: Epoch):
        """
        Apply the demographic event to the given epoch if applicable.

        :param epoch: Epoch.
        """
        # all changes falling into [start_time, end_time) of the epoch, in ascending order
        for t in self.times[(epoch.start_time <= self.times) & (self.times < epoch.end_time)]:
            epoch.pop_sizes |= self.pop_sizes[t]
            epoch.migration_rates |= self.migration_rates[t]
class PopSizeChanges(DiscreteRateChanges):
    """
    Demographic event for changes in population size.
    """

    def __init__(self, pop_sizes: Dict[str, Dict[float, float]]):
        """
        Create discrete rate changes which only affect population sizes.

        :param pop_sizes: Population sizes. A dictionary of the form
            `{pop_i: {time1: size1, time2: size2}}`.
        """
        # no migration rates are passed on
        super().__init__(
            pop_sizes=pop_sizes
        )
class PopSizeChange(PopSizeChanges):
    """
    Demographic event for a single change in population size.
    """

    def __init__(self, pop: str, time: float, size: float):
        """
        Create a size change of a single population at a single time.

        :param pop: Population name.
        :param time: Time at which the population size changes.
        :param size: Population size.
        """
        # express the single change in the nested form expected by the parent class
        single_change = {pop: {time: size}}

        super().__init__(single_change)
class MigrationRateChanges(DiscreteRateChanges):
    """
    Demographic event for changes in migration rates.
    """

    def __init__(self, rates: Dict[Tuple[str, str], Dict[float, float]]):
        """
        Create discrete rate changes which only affect (backwards-time) migration rates.

        :param rates: Migration rates. A dictionary of the form
            `{(pop_i, pop_j): {time1: rate1, time2: rate2}}` of migration from population `pop_i` to
            population `pop_j` at time `time1` etc.
        """
        # no population sizes are passed on
        super().__init__(
            migration_rates=rates
        )
class MigrationRateChange(MigrationRateChanges):
    """
    Demographic event for a single change in migration rate.
    """

    def __init__(self, source: str, dest: str, time: float, rate: float):
        """
        Create a (backwards-time) migration rate change for a single pair of populations at a single time.

        :param source: Source population name.
        :param dest: Destination population name.
        :param time: Time at which the migration rate changes.
        :param rate: Migration rate.
        """
        # express the single change in the nested form expected by the parent class
        single_change = {(source, dest): {time: rate}}

        super().__init__(single_change)
class SymmetricMigrationRateChanges(MigrationRateChanges):
    """
    Demographic event for changes in symmetric migration rates.
    """

    def __init__(self, pops: Iterable[str], rate: Dict[float, float] | float):
        """
        Initialize the (backwards-time) migration rate change.

        :param pops: Population names across which the migration rates change uniformly.
        :param rate: Migration rates. A dictionary of the form `{time1: rate1, time2: rate2}` which is
            applied to every ordered pair of distinct populations in ``pops``, or alternatively a single
            float if the migration rate is constant over time.
        """
        # Materialize the population names as they are iterated over in a nested fashion below,
        # which would exhaust a one-shot iterator after the first pass.
        pops = list(pops)

        # a constant rate takes effect at time 0
        if isinstance(rate, (float, int)):
            rate = {0: rate}

        # same rates for every ordered pair of distinct populations
        rate = {(p, q): rate for p in pops for q in pops if p != q}

        super().__init__(rates=rate)
class PopulationSplit(DiscreteDemographicEvent):
    """
    Demographic event for a population split (forward in time). This corresponds to population
    merger backwards in time.

    Since ``phasegen`` does not support deterministic lineage movement due to its inherent structure,
    we can model a population split by specifying a large unidirectional migration rate from the
    derived to the ancestral population.
    """

    def __init__(
            self,
            time: float,
            derived: str | List[str],
            ancestral: str,
            multiplier: float = 100
    ):
        """
        Initialize the population split.

        :param time: Time of the split.
        :param derived: Derived populations from which all lineages move to the ancestral population.
        :param ancestral: Ancestral population to which all lineages move.
        :param multiplier: Migration rate multiplier. The migration rate from the derived to the
            ancestral population is set to the population size of the derived population times this
            multiplier. This value should be chosen large enough to ensure that the lineages move to
            the ancestral population *fast enough*.
        """
        # a single derived population may be given by its name
        if isinstance(derived, str):
            derived = [derived]

        #: Derived populations.
        self.derived: List[str] = derived

        #: Ancestral population.
        self.ancestral: str = ancestral

        #: Migration rate multiplier.
        self.multiplier: float = multiplier

        #: Population names.
        self.pop_names: List[str] = sorted(derived + [ancestral])

        #: Time of the split.
        self.start_time: float = time

        #: Times at which the event occurs.
        self.times: np.ndarray = np.array([time])

    def _apply(self, epoch: Epoch):
        """
        Apply the demographic event to the given epoch if applicable.

        :param epoch: Epoch.
        """
        # nothing to do unless the split happens within [start_time, end_time) of the epoch
        if not (epoch.start_time <= self.start_time < epoch.end_time):
            return

        rates = epoch.migration_rates

        # high migration rate, proportional to the size of the derived population
        # NOTE(review): the key written here is (ancestral, derived); confirm that this matches
        # the direction convention of the migration rate keys used elsewhere.
        for pop in self.derived:
            rates[(self.ancestral, pop)] = epoch.pop_sizes[pop] * self.multiplier

        # zero out all rates whose first key component is a derived population
        for pop in self.derived:
            for other in epoch.pop_names:
                rates[(pop, other)] = 0
class DiscretizedDemographicEvent(DemographicEvent, ABC):
    """
    Base class for discretized demographic events.
    """
class DiscretizedRateChange(DiscretizedDemographicEvent):
    """
    Demographic event for discretized rate changes of a single population or migration rate.
    """

    def __init__(
            self,
            trajectory: Callable[[float], float],
            start_time: float,
            end_time: float = np.inf,
            pop: str | None = None,
            source: str | None = None,
            dest: str | None = None,
            step_size: float = 0.1
    ):
        """
        Initialize the discretized rate change.

        :param trajectory: Trajectory function taking the time as argument and returning the rate.
        :param start_time: Start time of the event.
        :param end_time: End time of the event.
        :param pop: Population name or None if no population size changes.
        :param source: Source population name or None if no migration rate changes.
        :param dest: Destination population name or None if no migration rate changes.
        :param step_size: Step size used for the discretization.
        :raise ValueError: If neither ``pop`` nor both ``source`` and ``dest`` are given, or if the
            step size is not positive.
        """
        if pop is None and (source is None or dest is None):
            raise ValueError('Either pop or source_pop and dest_pop must be specified.')

        # a non-positive step size would never advance the end time of an epoch
        if step_size <= 0:
            raise ValueError('Step size must be positive.')

        #: Population name.
        self.pop: str | None = pop

        #: Population names.
        self.pop_names: List[str] = sorted(p for p in {pop, source, dest} if p is not None)

        #: Start time of the event.
        self.start_time: float = start_time

        #: End time of the event.
        self.end_time: float = end_time

        #: Trajectory function.
        self.trajectory: Callable[[float], float] = trajectory

        #: Step size used for the discretization.
        self.step_size: float = step_size

        #: Source population name.
        self.source_pop: str | None = source

        #: Destination population name.
        self.dest_pop: str | None = dest

    def _broadcast(self, epoch: Epoch):
        """
        Adjust the end time of the epoch to the next time at which the rate changes due to this event.
        The end time is only ever moved forward in time, never back.

        :param epoch: Epoch.
        """
        # return if there is no overlap
        if epoch.end_time < self.start_time or epoch.start_time > self.end_time:
            return

        # if this event starts after the epoch, we take the start time
        if self.start_time > epoch.start_time:
            next_change = self.start_time
        else:
            # next point of the discretization grid strictly after the start of the epoch;
            # the small offset avoids obtaining the start of the epoch itself
            n_steps = np.ceil((epoch.start_time - self.start_time + 1e-10) / self.step_size)
            next_change = self.start_time + n_steps * self.step_size

        # Never extend the epoch: another event may already have requested an earlier end time,
        # which would otherwise be overwritten so that its change is applied too early.
        epoch.end_time = min(epoch.end_time, next_change)

    def _apply(self, epoch: Epoch):
        """
        Apply the demographic event to the given epoch if applicable.

        :param epoch: Epoch.
        """
        # if epoch is contained in the event
        if self.start_time <= epoch.start_time and epoch.end_time < self.end_time:

            # use the mean of the trajectory at both ends of the epoch as constant rate
            rate_start = self.trajectory(epoch.start_time)
            rate_end = self.trajectory(epoch.end_time)
            rate = (rate_start + rate_end) / 2

            if self.pop is None:
                epoch.migration_rates[(self.source_pop, self.dest_pop)] = rate
            else:
                epoch.pop_sizes[self.pop] = rate
class DiscretizedRateChanges(DiscretizedDemographicEvent):
    """
    Demographic event for discretized rate changes of multiple populations or migration rates.
    """

    def __init__(
            self,
            trajectory: Dict[Any, Callable[[float], float]],
            start_time: Dict[Any, float] | float,
            end_time: Dict[Any, float] | float = np.inf,
            step_size: float = 0.1
    ):
        """
        Create one :class:`DiscretizedRateChange` per trajectory.

        :param trajectory: Trajectory functions taking the time as argument and returning the rate.
            String keys denote population sizes, tuple keys denote migration rates.
        :param start_time: Start times of the events. A single value or a dictionary mapping keys to values.
        :param end_time: End times of the events. A single value or a dictionary mapping keys to values.
        :param step_size: Step size used for the discretization.
        """
        #: Discretized rate change events.
        self.events = {}

        for key, func in trajectory.items():
            is_pair = isinstance(key, tuple)

            self.events[key] = DiscretizedRateChange(
                trajectory=func,
                start_time=start_time[key] if isinstance(start_time, dict) else start_time,
                end_time=end_time[key] if isinstance(end_time, dict) else end_time,
                pop=key if isinstance(key, str) else None,
                source=key[0] if is_pair else None,
                dest=key[1] if is_pair else None,
                step_size=step_size
            )

        single_events = list(self.events.values())

        #: Population names.
        self.pop_names: List[str] = sorted({name for event in single_events for name in event.pop_names})

        #: Start time of the event.
        self.start_time: float = min(event.start_time for event in single_events)

        #: End time of the event.
        self.end_time: float = max(event.end_time for event in single_events)

    def _broadcast(self, epoch: Epoch):
        """
        Let every contained event adjust the end time of the epoch.

        :param epoch: Epoch.
        """
        for event in self.events.values():
            event._broadcast(epoch)

    def _apply(self, epoch: Epoch):
        """
        Apply every contained event to the given epoch if applicable.

        :param epoch: Epoch.
        """
        for event in self.events.values():
            event._apply(epoch)
class ExponentialRateChanges(DiscretizedRateChanges):
    """
    Demographic event for exponential rate changes of multiple populations or migration rates.
    """

    def __init__(
            self,
            initial_rate: Dict[Any, float],
            growth_rate: Dict[Any, float] | float,
            start_time: Dict[Any, float] | float,
            end_time: Dict[Any, float] | float = np.inf,
            step_size: float = 0.1
    ):
        """
        Create discretized rate changes following exponential trajectories.

        :param initial_rate: Initial rates. A dictionary mapping keys to values. Keys are either
            population names or tuples of population names for population sizes and migration rates,
            respectively.
        :param growth_rate: Exponential growth rates. A single value or a dictionary mapping keys to values.
        :param start_time: Start times of the growth. A single value or a dictionary mapping keys to values.
        :param end_time: End times of the growth.
        :param step_size: Step size used for the discretization.
        """

        def make_trajectory(key: Any) -> Callable[[float], float]:
            """
            Create the trajectory function for the given key.

            :param key: Key.
            :return: Trajectory function.
            """
            # resolve the parameters now so that each trajectory keeps its own values
            g = growth_rate[key] if isinstance(growth_rate, dict) else growth_rate
            t0 = start_time[key] if isinstance(start_time, dict) else start_time
            x0 = initial_rate[key] if isinstance(initial_rate, dict) else initial_rate

            def trajectory(t: float) -> float:
                """
                Rate at time ``t``, which equals ``x0`` at time ``t0``.
                """
                return x0 * np.exp(- g * (t - t0))

            return trajectory

        trajectories = {}
        for key in initial_rate:
            trajectories[key] = make_trajectory(key)

        super().__init__(
            trajectory=trajectories,
            start_time=start_time,
            end_time=end_time,
            step_size=step_size
        )
class ExponentialPopSizeChanges(ExponentialRateChanges):
    """
    Demographic event for exponential population size changes of multiple populations.
    """

    def __init__(
            self,
            initial_size: Dict[str, float],
            growth_rate: Dict[str, float] | float,
            start_time: Dict[str, float] | float,
            end_time: Dict[str, float] | float = np.inf,
            step_size: float = 0.1
    ):
        """
        Create exponential rate changes which only affect population sizes.

        :param initial_size: Initial population sizes. A dictionary mapping population names to sizes.
        :param growth_rate: Exponential growth rates. A single value or a dictionary mapping keys to values.
        :param start_time: Start times of the growth. A single value or a dictionary mapping keys to values.
        :param end_time: End times of the growth.
        :param step_size: Step size used for the discretization.
        """
        # the initial sizes are the initial rates of the parent class
        options = dict(
            initial_rate=initial_size,
            growth_rate=growth_rate,
            start_time=start_time,
            end_time=end_time,
            step_size=step_size
        )

        super().__init__(**options)