Inference

Inference class for gradient-based parameter inference with respect to a specified loss function, summary statistics, and a Coalescent distribution.

class Inference(bounds: Dict[str, Tuple[float, float]], coal: Callable[..., Coalescent], loss: Callable[[Coalescent, Any], float], x0: Dict[str, float] = None, observation: Any = None, resample: Callable[[Any, Generator], Any] = None, n_runs: int = 10, n_bootstraps: int = 100, do_bootstrap: bool = False, parallelize: bool = False, pbar: bool = True, seed: int = None, cache: bool = True, opts: Dict = None, method_mle: str = 'L-BFGS-B')

Bases: Serializable

Gradient-based parameter inference with respect to a specified loss function, summary statistics, and a Coalescent distribution. The optimization is performed with scipy.optimize.minimize, using the L-BFGS-B algorithm by default (configurable through method_mle).
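To make this concrete, here is a minimal, self-contained sketch of bounded, multi-start optimization with scipy.optimize.minimize and method='L-BFGS-B'. It does not use the library: a toy deterministic model (the hypothetical toy_coal) stands in for the Coalescent, and toy_loss and infer are likewise illustrative names. Starting each of the n_runs runs from a random point within the bounds is an assumption made for the sketch, not a statement of how Inference chooses its starting values.

```python
import numpy as np
from scipy.optimize import minimize


def toy_coal(a: float, b: float) -> np.ndarray:
    # Hypothetical stand-in for a Coalescent: returns a modelled statistic
    # that depends on two parameters. Not part of the library.
    k = np.arange(1, 6)
    return a / k + b


def toy_loss(modelled: np.ndarray, observation: np.ndarray) -> float:
    # Sum of squared errors between modelled and observed statistics.
    return float(np.sum((modelled - observation) ** 2))


def infer(bounds, coal, loss, observation, n_runs=10, seed=None):
    names = list(bounds)
    lower = np.array([bounds[name][0] for name in names], dtype=float)
    upper = np.array([bounds[name][1] for name in names], dtype=float)
    rng = np.random.default_rng(seed)

    def objective(x: np.ndarray) -> float:
        # Map the flat parameter vector back to keyword arguments,
        # since the coal callback receives its parameters by name.
        return loss(coal(**dict(zip(names, x))), observation)

    best = None
    for _ in range(n_runs):
        # Assumption: each run starts from a uniform random point within the bounds.
        x0 = rng.uniform(lower, upper)
        res = minimize(objective, x0, method="L-BFGS-B", bounds=list(zip(lower, upper)))
        # Keep only the run with the lowest loss.
        if best is None or res.fun < best.fun:
            best = res
    return dict(zip(names, best.x)), float(best.fun)


observation = toy_coal(a=2.0, b=0.5)
params, best_loss = infer(
    {"a": (0.1, 10.0), "b": (0.0, 5.0)}, toy_coal, toy_loss, observation, n_runs=5, seed=1
)
```

Keeping only the lowest-loss run corresponds to what this page reports as the loss and parameters of the best optimization run (loss_inferred, params_inferred).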

Note

There are currently problems when pickling this object if it has already been unpickled previously.

default_opts = {}

Default options passed to the optimization algorithm. See https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb

__init__(bounds: Dict[str, Tuple[float, float]], coal: Callable[..., Coalescent], loss: Callable[[Coalescent, Any], float], x0: Dict[str, float] = None, observation: Any = None, resample: Callable[[Any, Generator], Any] = None, n_runs: int = 10, n_bootstraps: int = 100, do_bootstrap: bool = False, parallelize: bool = False, pbar: bool = True, seed: int = None, cache: bool = True, opts: Dict = None, method_mle: str = 'L-BFGS-B')

Initialize the class with the provided parameters.

Parameters:
  • bounds (Dict[str, Tuple[float, float]]) – Dictionary of tuples representing the bounds for each parameter in x0.

  • coal (Callable[..., Coalescent]) – Callback returning the configured coalescent distribution on which the inference is based. The parameters specified in x0 and bounds are passed to it as keyword arguments.

  • loss (Callable[[Coalescent, Any], float]) – The loss function. This function must return a single numerical value that is to be minimized. It receives as first argument the coalescent distribution returned by the coal callback, and as second argument the value passed to the observation argument (if any).

  • x0 (Dict[str, float]) – Dictionary of initial numeric guesses for parameters to optimize.

  • observation (Any) – The observed summary statistic the inference is based on. This is passed as second argument to the loss function, and is only required if you want to use automatic bootstrapping.

  • resample (Callable[[Any, Generator], Any]) – Callback that is used to resample the observation. This is required for automatic bootstrapping. The resample function must accept the observation as first argument and a random number generator as second argument, and must return a resampled observation.

  • n_runs (int) – Number of independent optimization runs.

  • n_bootstraps (int) – Number of bootstrap replicates.

  • do_bootstrap (bool) – Whether to perform automatic bootstrapping.

  • parallelize (bool) –

    Whether to parallelize the computations across available CPU cores.

    Note

    Parallelization across multiple CPU cores is not always faster than single-threaded execution. It can also lead to hanging processes due to pickling issues, depending on how the provided callback function is defined. For more scalable parallelization, consider using the create_run() and create_bootstrap() methods to create new Inference objects that can be run independently, and whose results can be merged subsequently.

  • pbar (bool) – Whether to show a progress bar.

  • seed (int) – Seed for the random number generator.

  • cache (bool) – Whether to cache the state spaces across optimization iterations whenever they are equivalent. This can significantly speed up the optimization, as the complete state spaces do not need to be recomputed for each iteration. This only improves performance when optimizing demographic parameters such as population sizes or migration rates.

  • opts (Dict) – Additional options passed to the optimization algorithm. See https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb

  • method_mle (str) – Method to use for optimization. See scipy.optimize.minimize for available methods.
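The callbacks described above might look as follows. This is a sketch under stated assumptions: ToyCoalescent is a hypothetical stand-in for the library's Coalescent exposing a single expected attribute, the Poisson loss is just one possible loss function, and the resample callback performs a parametric Poisson bootstrap. Only the call conventions are taken from this page: coal receives the parameters as keyword arguments, loss receives (distribution, observation), and resample receives (observation, random number generator).

```python
from dataclasses import dataclass

import numpy as np
from numpy.random import Generator


@dataclass
class ToyCoalescent:
    # Hypothetical stand-in for the library's Coalescent; it only exposes
    # an expected summary statistic (e.g. an SFS-like count vector).
    expected: np.ndarray


def coal(theta: float, growth: float) -> ToyCoalescent:
    # Keyword argument names match the keys used in bounds and x0.
    k = np.arange(1, 10)
    return ToyCoalescent(expected=theta / k * np.exp(-growth * k))


def loss(dist: ToyCoalescent, observation: np.ndarray) -> float:
    # Negative Poisson log-likelihood up to a constant; lower is better.
    mu = dist.expected
    return float(np.sum(mu - observation * np.log(mu)))


def resample(observation: np.ndarray, rng: Generator) -> np.ndarray:
    # Parametric bootstrap: redraw each count from a Poisson distribution.
    return rng.poisson(observation)
```

Each term of this loss is minimized when the modelled value equals the observed one, so the true parameters give a lower loss than nearby wrong ones.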

bounds: Dict[str, Tuple[float, float]]

Dictionary of tuples representing the bounds for each parameter in x0.

coal: Callable[..., Coalescent]

Callback returning the configured coalescent distribution.

loss: Callable[[Coalescent, Any], float]

Loss function.

observation: Any

The observed summary statistic the inference is based on.

resample: Callable[[Any, Generator], Any]

Callback that is used to resample the observation.

n_runs: int

Number of optimization runs.

n_bootstraps: int

Number of bootstrap replicates.

do_bootstrap: bool

Whether to perform automatic bootstrapping.

parallelize: bool

Whether to parallelize the computations.

pbar: bool

Whether to show a progress bar.

seed: int | None

Seed for the random number generator.

cache: bool

Whether to cache the state spaces.

opts: Dict

Optimization options.

method_mle: str

Optimization method.

result: OptimizeResult | None

Optimization result.

params_inferred: Dict

Inferred parameters.

loss_inferred: float | None

Loss of the best optimization run.

dist_inferred: Coalescent | None

Coalescent distribution of the best run.

bootstraps: DataFrame

Bootstrap parameters.

runs: DataFrame

Initial optimization runs.

property param_names: List[str]

Get the names of the parameters.

property x0: Dict[str, float]

Initial parameters.

get_coal(**kwargs)

Get the (possibly cached) coalescent distribution.

Parameters:

kwargs – Keyword arguments passed to the callback specified as coal.

Return type:

Coalescent

Returns:

Coalescent distribution.
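As a rough illustration of returning a possibly cached result, the sketch below memoizes a coal-style callback by its keyword arguments. CachedCoal is a hypothetical helper, not the library's implementation: the library describes its caching in terms of reusing equivalent state spaces (see the cache parameter), which is finer-grained than the whole-object memoization shown here.

```python
from typing import Any, Callable, Dict, Tuple


class CachedCoal:
    """Hypothetical helper: memoize a coal-style callback by its keyword arguments."""

    def __init__(self, factory: Callable[..., Any]):
        self.factory = factory
        self._cache: Dict[Tuple, Any] = {}
        self.calls = 0  # number of times the underlying callback actually ran

    def get(self, **kwargs) -> Any:
        # Sort by name so the key does not depend on keyword order.
        key = tuple(sorted(kwargs.items()))
        if key not in self._cache:
            self.calls += 1
            self._cache[key] = self.factory(**kwargs)
        return self._cache[key]
```

Keys compare parameter values by exact equality, so during an optimization a hit only occurs when the optimizer requests exactly the same parameter values again.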

run()

Execute the optimization.

bootstrap()

Perform bootstrapping.

Returns:

Bootstrap replicates.
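Bootstrapping repeats the inference on resampled observations and summarizes the spread of the re-inferred parameters. The sketch below is not the library's code: bootstrap, percentile_interval, resample, and fit are illustrative names, fit is a closed-form Poisson mean estimate standing in for a full optimization run, and the percentile interval is one common way to summarize the replicates.

```python
import numpy as np


def bootstrap(observation, resample, fit, n_bootstraps=100, seed=None):
    # One shared generator makes the whole set of replicates reproducible.
    rng = np.random.default_rng(seed)
    # Resample the observation and re-fit the model to each replicate.
    return [fit(resample(observation, rng)) for _ in range(n_bootstraps)]


def percentile_interval(replicates, name, level=0.95):
    # Percentile bootstrap interval for a single parameter.
    values = np.array([replicate[name] for replicate in replicates])
    alpha = (1.0 - level) / 2.0
    low, high = np.quantile(values, [alpha, 1.0 - alpha])
    return float(low), float(high)


def resample(obs, rng):
    # Parametric (Poisson) resampling of each observed count.
    return rng.poisson(obs)


def fit(obs):
    # Closed-form Poisson MLE standing in for a full optimization run.
    return {"rate": float(np.mean(obs))}


observation = np.full(50, 10.0)
replicates = bootstrap(observation, resample, fit, n_bootstraps=100, seed=42)
low, high = percentile_interval(replicates, "rate")
```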

plot_bootstraps(title: str | List[str] = None, show: bool = True, file: str = None, subplots: bool = True, kind: Literal['hist', 'kde'] = 'hist', ax: plt.Axes | List[plt.Axes] = None, kwargs: dict = None)

Plot bootstrapped parameters.

Parameters:
  • title (Union[str, List[str]]) – Title or list of titles.

  • show (bool) – Whether to show the plot.

  • file (str) – File to save the plot.

  • subplots (bool) – Whether to plot subplots.

  • kind (Literal['hist', 'kde']) – Kind of plot. Either ‘hist’ or ‘kde’.

  • ax (Union[plt.Axes, List[plt.Axes]]) – Axes or list of axes.

  • kwargs (dict) – Additional keyword arguments passed to the pandas plot function.

Return type:

Union[plt.Axes, List[plt.Axes]]

Returns:

Axes or list of axes.
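Since kwargs are passed on to the pandas plot function, the direct pandas call below shows the kind of figure involved: one histogram per parameter on a made-up bootstrap table. The table layout (one row per replicate, one column per parameter) is an assumption about the bootstraps DataFrame, and the call is not necessarily what plot_bootstraps does internally.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Made-up bootstrap table: one row per replicate, one column per parameter.
bootstraps = pd.DataFrame({
    "Ne": rng.normal(1.0, 0.1, size=100),
    "t": rng.normal(0.5, 0.05, size=100),
})

# One histogram per parameter, each in its own subplot.
axes = bootstraps.plot(kind="hist", subplots=True, bins=20)
```

Switching to kind="kde" also requires scipy, because pandas estimates the density with scipy.stats.gaussian_kde.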

plot_demography(t: ndarray = None, include_bootstraps: bool = True, show: bool = True, file: str = None, kwargs: dict = None, ax: List[plt.Axes] | None = None)

Plot inferred demography.

Parameters:
  • t (ndarray) – Time points. By default, 100 time points are used that extend from 0 to the 99th percentile of the tree height distribution.

  • include_bootstraps (bool) – Whether to include bootstraps.

  • show (bool) – Whether to show the plot.

  • file (str) – File to save the plot.

  • kwargs (dict) – Additional keyword arguments passed to the plot function.

  • ax (Optional[List[plt.Axes]]) – List of axes to plot on.

Return type:

List[plt.Axes]

Returns:

List of axes.
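The default time grid for t can be reproduced once a quantile function for the tree height is available. For a checkable example, the sketch uses an exponential distribution with rate 1, which is the tree height distribution for a sample of two lineages in a constant-size population in standard coalescent time units. default_time_grid and exp_quantile are hypothetical helpers, even spacing of the points is an assumption, and the library derives the 99th percentile from its own tree height distribution.

```python
import numpy as np


def default_time_grid(tree_height_quantile, n=100, q=0.99):
    # n points from 0 up to the q-quantile of the tree height distribution.
    return np.linspace(0.0, tree_height_quantile(q), n)


def exp_quantile(p, rate=1.0):
    # Quantile function of Exp(rate): -ln(1 - p) / rate.
    return -np.log1p(-p) / rate


t = default_time_grid(exp_quantile)
```

Passing a custom array such as t to the t argument of the plotting methods overrides the default grid.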

plot_pop_sizes(t: ndarray = None, show: bool = True, include_bootstraps: bool = True, file: str = None, kwargs: dict = None, ax: plt.Axes | None = None)

Plot inferred population sizes.

Parameters:
  • t (ndarray) – Time points. By default, 100 time points are used that extend from 0 to the 99th percentile of the tree height distribution.

  • show (bool) – Whether to show the plot.

  • include_bootstraps (bool) – Whether to include bootstraps.

  • file (str) – File to save the plot.

  • kwargs (dict) – Additional keyword arguments passed to the plot function.

  • ax (Optional[plt.Axes]) – Axes to plot on.

Return type:

plt.Axes

Returns:

Axes.

plot_migration(t: ndarray = None, show: bool = True, file: str = None, include_bootstraps: bool = True, kwargs: dict = None, ax: plt.Axes | None = None)

Plot inferred migration rates.

Parameters:
  • t (ndarray) – Time points. By default, 100 time points are used that extend from 0 to the 99th percentile of the tree height distribution.

  • show (bool) – Whether to show the plot.

  • file (str) – File to save the plot.

  • include_bootstraps (bool) – Whether to include bootstraps.

  • kwargs (dict) – Additional keyword arguments passed to the plot function.

  • ax (Optional[plt.Axes]) – Axes to plot on.

Return type:

plt.Axes

Returns:

Axes.

create_run(x0: Dict[str, float] = None)

Create a new Inference object which can be run independently. This is useful when parallelizing runs on a cluster. Completed runs can be merged back into this object using the add_run method.

Parameters:

x0 (Dict[str, float]) – Initial parameters.

Return type:

Inference

Returns:

Inference object.

add_run(inference: Inference)

Merge the main optimization result from another Inference object into the current Inference object. We only store the result of the run with the lowest loss.

Parameters:

inference (Inference) – Inference object.

Raises:

RuntimeError – If the main optimization has not been run yet.

add_runs(inferences: Iterable[Inference])

Merge the main optimization results from an iterable of Inference objects with the current Inference object. We only store the result of the run with the lowest loss.

Parameters:

inferences (Iterable[Inference]) – Iterable of Inference objects.
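The merging rule for add_run and add_runs (keep only the run with the lowest loss) amounts to taking a minimum over candidate runs. RunResult and merge_runs are hypothetical names used only for this sketch; the RuntimeError mirrors the documented behaviour when the main optimization has not been run yet.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Optional


@dataclass
class RunResult:
    # Hypothetical record of one finished optimization run.
    params: Dict[str, float]
    loss: float


def merge_runs(current: Optional[RunResult], others: Iterable[RunResult]) -> RunResult:
    if current is None:
        # Mirrors the documented RuntimeError for an unrun main optimization.
        raise RuntimeError("The main optimization has not been run yet.")
    # Keep only the run with the lowest loss.
    return min([current, *others], key=lambda run: run.loss)
```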

create_bootstrap(n_runs: int = 1)

Resample the observation and return a new Inference object with the resampled observation. This is useful when parallelizing bootstraps on a cluster. Completed bootstraps can be merged back into this object using the add_bootstrap method.

Return type:

Inference

Returns:

Inference object with the resampled observation.

add_bootstrap(bootstrap: Inference)

Add main optimization result from another Inference object as a bootstrap to the current Inference object.

Parameters:

bootstrap (Inference) – Either an Inference object or a dictionary of inferred parameters.

Raises:

RuntimeError – If the main optimization has not been run yet.

add_bootstraps(data: Iterable[Inference] | Iterable[Dict[str, float]])

Add bootstraps from an iterable of Inference objects.

Parameters:

data (Union[Iterable[Inference], Iterable[Dict[str, float]]]) – Iterable of Inference objects or dictionaries of inferred parameters.
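A possible cluster workflow for bootstraps relies only on methods documented on this page: create_bootstrap, run, to_file, from_file, and add_bootstraps. The helper functions, the file naming scheme, and how jobs are dispatched to workers are assumptions; the helpers use duck typing, so they can be exercised with a stub object. Since add_bootstrap is documented to raise a RuntimeError when the main optimization has not been run, run the main inference before collecting.

```python
from pathlib import Path
from typing import List


def run_bootstrap_job(inference, index: int, out_dir: Path) -> Path:
    # Worker side: resample, optimize, and persist one bootstrap replicate.
    boot = inference.create_bootstrap()
    boot.run()
    path = out_dir / f"bootstrap_{index}.json"  # hypothetical naming scheme
    boot.to_file(str(path))
    return path


def collect_bootstraps(inference, paths: List[Path]) -> None:
    # Main side: reload each replicate and merge it as a bootstrap.
    cls = type(inference)
    inference.add_bootstraps([cls.from_file(str(path)) for path in paths])
```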

classmethod from_file(file: str, classes=None)

Load object from file.

Parameters:
  • classes – Classes to be used for deserialization.

  • file (str) – File to load from.

Return type:

Self

classmethod from_json(json: str, classes=None)

Deserialize object.

Parameters:
  • classes – Classes to be used for deserialization.

  • json (str) – JSON string.

Return type:

Self

to_file(file: str)

Save object to file (JSON).

Parameters:

file (str) – File path.

to_json()

Serialize object.

Return type:

str

Returns:

JSON string.

class WeightedLoss(weights: Dict[str, float], n_max: int | None = 100)

Bases: object

Weighs the components of the loss function based on the average of the observed and modelled values.

__init__(weights: Dict[str, float], n_max: int | None = 100)

Initialize the class with the provided parameters.

Parameters:
  • weights (Dict[str, float]) – Dictionary of weights for each component of the loss function.

  • n_max (int | None) – Maximum number of recent values to consider for the average. Use None to consider all values.

weights: Dict[str, float]

Weights for each component of the loss function.

n_max: int

Maximum number of recent values to consider for the average.

keys: List[str]

Keys of the weights.

cache: Dict[str, ndarray]

Cached values.

property average

Average of the cached values.

compute(loss: Dict[str, float])

Compute the weighted loss.

Parameters:

loss (Dict[str, float]) – Dictionary of loss values for each component of the loss function.

Return type:

float

Returns:

Weighted loss.
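The exact weighting formula of WeightedLoss is not spelled out here. The sketch below shows one plausible scheme consistent with the attributes listed above (per-component weights, a cache of recent values limited to n_max, and an average): each component's loss is divided by the running mean of its recent values before weighting. RunningWeightedLoss is a hypothetical name, and the normalization rule is an assumption, not the class's actual implementation.

```python
from collections import deque
from typing import Deque, Dict, Optional


class RunningWeightedLoss:
    """Hypothetical sketch: combine loss components, each scaled by its running mean."""

    def __init__(self, weights: Dict[str, float], n_max: Optional[int] = 100):
        self.weights = weights
        # deque(maxlen=None) is unbounded, mirroring "use None to consider all values".
        self.history: Dict[str, Deque[float]] = {key: deque(maxlen=n_max) for key in weights}

    def compute(self, loss: Dict[str, float]) -> float:
        total = 0.0
        for key, weight in self.weights.items():
            self.history[key].append(loss[key])
            average = sum(self.history[key]) / len(self.history[key])
            # Scaling by the running mean puts components of different magnitude
            # on a comparable footing before weighting; a zero average is skipped.
            if average != 0:
                total += weight * loss[key] / average
        return total
```

With n_max=2, only the two most recent values of each component enter its average, so older values stop influencing the scaling.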