Source code for phasegen.spectrum

"""
Classes for working with the site-frequency spectrum (SFS) and 2-SFS.
"""

import copy
import logging
from typing import Dict, Iterable, Iterator, Sequence, Tuple

import jsonpickle
import numpy as np
# noinspection PyUnresolvedReferences
from fastdfe import Spectrum, Spectra

logger = logging.getLogger('phasegen').getChild('spectrum')


class SFS(Spectrum):
    """
    A site-frequency spectrum.

    This class derives from ``fastdfe``'s ``Spectrum`` and adds no attributes or methods of its own.
    """
[docs] class SFS2(Iterable): """ A 2-dimensional site-frequency spectrum. """
[docs] def __init__(self, data: np.ndarray | list): """ Construct from data matrix. :param data: """ data = np.array(data).copy() if data.ndim != 2: raise ValueError('Data has to be 2-dimensional') if data.shape[0] != data.shape[1]: raise ValueError('Matrix has to be square.') self.n = data.shape[0] # width self.w = self.n // 2 + 1 if self.n % 2 == 1 else self.n // 2 self.data = data
[docs] def to_file(self, file): """ Save to file (in JSON format). :param file: File path. """ with open(file, 'w') as f: f.write(self.to_json())
[docs] def to_json(self) -> str: """ Convert data to JSON string. :return: JSON string """ obj = copy.deepcopy(self) # convert numpy array to list obj.data = obj.data.tolist() return jsonpickle.encode(obj)
[docs] @staticmethod def from_file(file: str) -> 'SFS2': """ Load from file. :param file: File path. :return: SFS2 """ with open(file, 'r') as f: return SFS2.from_json(f.read())
[docs] @staticmethod def from_json(json: str) -> 'SFS2': """ Load from JSON string. :param json: JSON string. :return: SFS2 """ obj = jsonpickle.decode(json) # convert list to numpy array obj.data = np.array(obj.data) return obj
[docs] def is_folded(self) -> bool: """ Check if the 2-SFS is folded. :return: Whether the 2-SFS is folded. """ return np.all(self.data == self.fold().data)
def __add__(self, other) -> 'SFS2': """ Add two 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self + other.data return SFS2(self.data + other) def __sub__(self, other) -> 'SFS2': """ Subtract two 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self - other.data return SFS2(self.data - other) def __mul__(self, other) -> 'SFS2': """ Multiply 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self * other.data return SFS2(self.data * other) def __floordiv__(self, other) -> 'SFS2': """ Divide 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self // other.data return SFS2(self.data // other) def __truediv__(self, other) -> 'SFS2': """ Divide 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self / other.data return SFS2(self.data / other) def __iter__(self) -> Iterator: """ Iterate over entries. :return: Iterator """ return self.data.__iter__() def __pow__(self, power) -> 'SFS2': """ Power operator. :param power: exponent :return: Spectrum """ return SFS2(self.data ** power)
[docs] def fold(self) -> 'SFS2': """ Fold 2-SFS by adding up ``i`` and ``n - i`` for both axes. Node that this only make sense for counts or frequencies. :return: Folded 2-SFS. """ data = self.data.copy() for _ in range(2): # compute left and right half and merge them left = np.concatenate((data[:self.w], np.zeros((self.n - self.w, self.n)))) right = np.concatenate((data[self.w:][::-1], np.zeros((self.w, self.n)))) # add parts and rotate data = (left + right).T return SFS2(data)
[docs] def copy(self) -> 'SFS2': """ Create deep copy. :return: Deep copy. """ return copy.deepcopy(self)
[docs] def symmetrize(self) -> 'SFS2': """ Symmetric SFS so that ``i, j`` and ``j, i`` are the same. :return: Symmetric 2-SFS. """ return SFS2((self.data + self.data.T) / 2)
[docs] def fill_monomorphic(self, fill_value=np.nan) -> 'SFS2': """ Remote the diagonal entries of the given array. :param fill_value: Value to fill diagonal entries with. :return: 2-SFS """ other = self.copy() other.data[:1, :] = fill_value other.data[-1:, :] = fill_value other.data[:, :1] = fill_value other.data[:, -1] = fill_value return other
[docs] def plot( self, ax: 'plt.Axes' = None, title: str = None, max_abs: float = None, log_scale: bool = False, cbar_kws: Dict = None, show: bool = True, ) -> 'plt.Axes': """ Plot as a heatmap. :param title: Title of the plot. :param ax: Axes to plot on. :param max_abs: Maximum absolute value to plot. :param log_scale: Use log scale. :param cbar_kws: Keyword arguments for color bar. :param show: Whether to show the plot. :return: Axes. """ import matplotlib.pyplot as plt from matplotlib.colors import SymLogNorm import seaborn as sns if self.n < 3: logger.warning('Nothing to plot.') return plt.gca() if cbar_kws is None: cbar_kws = dict(pad=0.05) if max_abs is None: max_abs = self.get_max_abs() or 1 # remove monomorphic sites data = self.data[1:-1, 1:-1] # truncate data if folded if self.is_folded(): data = data[:self.w - 1, :self.w - 1] # plot heatmap using a symmetric log norm ax = sns.heatmap( data, norm=SymLogNorm( linthresh=max_abs / 10, vmin=-max_abs, vmax=max_abs ), cmap='PuOr_r', cbar_kws=cbar_kws, ax=ax ) # invert y-axis and remove ticks ax.invert_yaxis() ax.axis('square') if log_scale: ax.set_xscale('log', base=1.001) ax.set_yscale('log', base=1.001) ax.set_xticks(ax.get_yticks()) ax.set_xticklabels([str(int(label + 1)) for label in ax.get_xticks()]) ax.set_yticklabels([str(int(label + 1)) for label in ax.get_yticks()]) # remove confusing color bar ticks ax.collections[0].colorbar.ax.tick_params(size=0) # add frame around plot for _, spine in ax.spines.items(): spine.set_visible(True) spine.set_edgecolor('grey') if title is not None: ax.set_title(title) if show: plt.show() return ax
[docs] def plot_surface( self, ax: 'plt.Axes' = None, title: str = None, max_abs: float = None, vmin: float = None, vmax: float = None, show: bool = True, ) -> 'plt.Axes': """ Plot as a surface. :param title: :param ax: Axes to plot on. :param max_abs: Maximum absolute value to plot. :param vmin: Minimum value to plot. :param vmax: Maximum value to plot. :param show: Whether to show the plot. :return: Axes. """ import matplotlib.pyplot as plt from matplotlib.colors import SymLogNorm if self.n < 3: logger.warning('Nothing to plot.') return plt.gca() if max_abs is None: max_abs = self.get_max_abs() or 1 # remove monomorphic sites data = self.data[1:-1, 1:-1] # truncate data if folded if self.is_folded(): data = data[:self.w - 1, :self.w - 1] x = np.arange(1, data.shape[0] + 1) y = np.arange(1, data.shape[0] + 1) x_grid, y_grid = np.meshgrid(x, y) if ax is None: _, ax = plt.subplots(subplot_kw={"projection": "3d"}) # vmin and vmax don't seem to work here ax.plot_surface( x_grid, y_grid, data, cmap='PuOr_r', vmin=vmin, vmax=vmax, norm=SymLogNorm( linthresh=max_abs / 10, vmin=-max_abs, vmax=max_abs ) ) if title is not None: ax.set_title(title) if show: plt.show() return ax
[docs] def mask_diagonal(self, fill_value=np.nan) -> 'SFS2': """ Mask both the primary and secondary diagonal entries of the 2-SFS matrix. The primary diagonal runs from the top-left to the bottom-right, and the secondary diagonal runs from the top-right to the bottom-left. :param fill_value: The value to fill the diagonal entries with. :return: A new SFS2 object with both diagonals masked. """ data = self.data.copy() np.fill_diagonal(data, fill_value) data = np.fliplr(data) np.fill_diagonal(data, fill_value) data = np.fliplr(data) return SFS2(data)
[docs] def get_max_abs(self) -> float: """ Get the maximum absolute entry of the 2-SFS matrix. :return: The maximum absolute entry. """ return np.nanmax(np.abs(self.data))
[docs] def mask_upper(self, fill_value=np.nan) -> 'SFS2': """ Mask the upper triangular entries of the 2-SFS matrix. :param fill_value: The value to fill the upper triangular entries with. :return: A new SFS2 object with upper triangular entries masked. """ data = self.copy().data data[np.tril(np.ones_like(data, dtype=bool), k=-1)] = fill_value return SFS2(data)
class TwoLocusSFS(SFS2):
    """
    The two-locus site-frequency spectrum under recombination.

    This is a square matrix whose entry ``(i, j)`` is the expected product of the branch length
    subtending ``i`` samples at locus 0 and the branch length subtending ``j`` samples at locus 1,
    where the two loci are separated by recombination rate ``r``.

    The two limiting cases are:

    * ``r = 0`` (fully linked): the within-tree cross-moment of the SFS, which equals
      ``Coalescent.sfs.cov`` plus the outer product of the marginal means.
    * ``r → ∞`` (independent loci): the outer product of the marginal SFS.

    For intermediate ``r`` the spectrum interpolates between the two.

    No functionality is added on top of :class:`SFS2`; its container machinery
    (plotting, folding, arithmetic, serialization) is inherited as is.
    """
[docs] class JointSFS(Iterable): """ A joint (multi-population) site-frequency spectrum. The data is a ``P``-dimensional array of shape ``(n_0 + 1, ..., n_{P-1} + 1)`` where ``P`` is the number of populations and entry ``(k_0, ..., k_{P-1})`` corresponds to branches subtending ``k_p`` samples from population ``p``. For two populations this is a 2-dimensional array (analogous to but generally rectangular, unlike the square :class:`SFS2`); for three populations it is a 3-dimensional array, and so on. """
[docs] def __init__(self, data: np.ndarray | list): """ Construct from a data array. :param data: A ``P``-dimensional array. """ data = np.asarray(data) if data.ndim < 1: raise ValueError('Data has to be at least 1-dimensional.') #: The joint SFS array. self.data: np.ndarray = data
@property def n_pops(self) -> int: """ Number of populations (dimensions of the joint SFS). """ return self.data.ndim @property def shape(self) -> Tuple[int, ...]: """ Shape of the joint SFS array. """ return self.data.shape def __array__(self, dtype=None) -> np.ndarray: """ Numpy array interface so the joint SFS can be used directly in numpy operations. :param dtype: Optional dtype. :return: The underlying array. """ return self.data if dtype is None else self.data.astype(dtype) def __iter__(self) -> Iterator: """ Iterate over the first axis of the joint SFS. :return: Iterator. """ return self.data.__iter__() def __getitem__(self, item): """ Index into the joint SFS array. :param item: Index. :return: Indexed value or sub-array. """ return self.data[item] def __add__(self, other) -> 'JointSFS': return JointSFS(self.data + (other.data if isinstance(other, JointSFS) else other)) def __sub__(self, other) -> 'JointSFS': return JointSFS(self.data - (other.data if isinstance(other, JointSFS) else other)) def __mul__(self, other) -> 'JointSFS': return JointSFS(self.data * (other.data if isinstance(other, JointSFS) else other)) def __truediv__(self, other) -> 'JointSFS': return JointSFS(self.data / (other.data if isinstance(other, JointSFS) else other)) def __pow__(self, power) -> 'JointSFS': return JointSFS(self.data ** power)
[docs] def copy(self) -> 'JointSFS': """ Create a deep copy. :return: Deep copy. """ return copy.deepcopy(self)
[docs] def marginalize(self, pops: Sequence[int]) -> 'JointSFS': """ Marginalize the joint SFS onto a subset of populations by summing over the other populations. This is useful for example to obtain a 2-dimensional view of a higher-dimensional joint SFS. :param pops: The population indices to keep, in the desired axis order. :return: A joint SFS over the specified populations. """ keep = tuple(int(p) for p in pops) if any(p < 0 or p >= self.n_pops for p in keep): raise ValueError(f'Population indices must be in [0, {self.n_pops - 1}].') drop = tuple(i for i in range(self.n_pops) if i not in keep) data = self.data.sum(axis=drop) if drop else self.data # reorder the remaining axes (which are in ascending order) to match the requested order order = [sorted(keep).index(p) for p in keep] return JointSFS(np.transpose(data, order))
[docs] def plot( self, pops: Tuple[int, int] = (0, 1), ax: 'plt.Axes' = None, title: str = None, log_scale: bool = False, mask_monomorphic: bool = True, cbar_kws: Dict = None, show: bool = True, ) -> 'plt.Axes': """ Plot the joint SFS as a 2-dimensional heatmap. For more than two populations, the joint SFS is first marginalized onto the two requested populations (summing over the others). :param pops: The two population indices to plot (y-axis, x-axis). :param ax: Axes to plot on. :param title: Title of the plot. :param log_scale: Whether to use a logarithmic color scale. :param mask_monomorphic: Whether to mask the monomorphic corners (all-zero and all-derived). :param cbar_kws: Keyword arguments for the color bar. :param show: Whether to show the plot. :return: Axes. """ import matplotlib.pyplot as plt from matplotlib.colors import LogNorm import seaborn as sns if len(pops) != 2: raise ValueError('Exactly two populations must be specified for a 2-dimensional plot.') # reduce to the two requested populations data = (self.marginalize(pops) if self.n_pops > 2 else self).data.astype(float).copy() if data.ndim != 2: raise ValueError('Plotting requires a 2-dimensional (marginalized) joint SFS.') if mask_monomorphic: data[0, 0] = np.nan data[-1, -1] = np.nan if cbar_kws is None: cbar_kws = dict(pad=0.05) # create a fresh 2-D axes if none is given (so we never draw onto a leftover 3-D axes from plot_surface) if ax is None: _, ax = plt.subplots() ax = sns.heatmap( data, norm=LogNorm() if log_scale else None, cmap='viridis', cbar_kws=cbar_kws, ax=ax ) # put the origin at the bottom left ax.invert_yaxis() ax.set_xlabel(f'allele count pop {pops[1]}') ax.set_ylabel(f'allele count pop {pops[0]}') # square cells, a grey frame, and unobtrusive color bar ticks (as for the 2-SFS plot) ax.set_aspect('equal') ax.collections[0].colorbar.ax.tick_params(size=0) for spine in ax.spines.values(): spine.set_visible(True) spine.set_edgecolor('grey') if title is not None: ax.set_title(title) 
if show: plt.show() return ax
[docs] def plot_surface( self, pops: Tuple[int, int] = (0, 1), ax: 'plt.Axes' = None, title: str = None, log_scale: bool = False, mask_monomorphic: bool = True, cmap: str = 'viridis', show: bool = True, ) -> 'plt.Axes': """ Plot the joint SFS as a 3-dimensional surface. For more than two populations, the joint SFS is first marginalized onto the two requested populations (summing over the others). :param pops: The two population indices to plot (y-axis, x-axis). :param ax: Axes to plot on. :param title: Title of the plot. :param log_scale: Whether to use a logarithmic color scale. :param mask_monomorphic: Whether to mask the monomorphic corners (all-zero and all-derived). :param cmap: The colormap. :param show: Whether to show the plot. :return: Axes. """ import matplotlib.pyplot as plt from matplotlib.colors import LogNorm if len(pops) != 2: raise ValueError('Exactly two populations must be specified for a surface plot.') # reduce to the two requested populations data = (self.marginalize(pops) if self.n_pops > 2 else self).data.astype(float).copy() if data.ndim != 2: raise ValueError('Plotting requires a 2-dimensional (marginalized) joint SFS.') if mask_monomorphic: data[0, 0] = np.nan data[-1, -1] = np.nan # allele-count grid (0..n_p) for each of the two populations x_grid, y_grid = np.meshgrid(np.arange(data.shape[1]), np.arange(data.shape[0])) if ax is None: _, ax = plt.subplots(subplot_kw={'projection': '3d'}) ax.plot_surface(x_grid, y_grid, data, cmap=cmap, norm=LogNorm() if log_scale else None) ax.set_xlabel(f'allele count pop {pops[1]}') ax.set_ylabel(f'allele count pop {pops[0]}') ax.set_zlabel('branch length') if title is not None: ax.set_title(title) if show: plt.show() return ax
[docs] def to_file(self, file: str): """ Save to file (in JSON format). :param file: File path. """ with open(file, 'w') as f: f.write(self.to_json())
[docs] def to_json(self) -> str: """ Convert to a JSON string. :return: JSON string. """ obj = copy.deepcopy(self) # convert numpy array to list obj.data = obj.data.tolist() return jsonpickle.encode(obj)
[docs] @staticmethod def from_file(file: str) -> 'JointSFS': """ Load from file. :param file: File path. :return: JointSFS """ with open(file, 'r') as f: return JointSFS.from_json(f.read())
[docs] @staticmethod def from_json(json: str) -> 'JointSFS': """ Load from a JSON string. :param json: JSON string. :return: JointSFS """ obj = jsonpickle.decode(json) # convert list back to numpy array obj.data = np.array(obj.data) return obj