# Source code for phasegen.spectrum

"""
SFS and 2-SFS classes.
"""

import copy
import logging
from typing import Dict, Iterable, Iterator

import jsonpickle
import numpy as np
# noinspection PyUnresolvedReferences
from fastdfe import Spectrum, Spectra

# Module logger: the 'spectrum' child of the package-wide 'phasegen' logger.
logger = logging.getLogger('phasegen').getChild('spectrum')


class SFS(Spectrum):
    """
    A (one-dimensional) site-frequency spectrum.

    Subclass of ``fastdfe``'s ``Spectrum`` that adds no behaviour of its own;
    it exists so the spectrum type can be referenced from this module.
    """
[docs] class SFS2(Iterable): """ A 2-dimensional site-frequency spectrum. """
[docs] def __init__(self, data: np.ndarray | list): """ Construct from data matrix. :param data: """ data = np.array(data).copy() if data.ndim != 2: raise ValueError('Data has to be 2-dimensional') if data.shape[0] != data.shape[1]: raise ValueError('Matrix has to be square.') self.n = data.shape[0] # width self.w = self.n // 2 + 1 if self.n % 2 == 1 else self.n // 2 self.data = data
[docs] def to_file(self, file): """ Save to file (in JSON format). :param file: File path. """ with open(file, 'w') as f: f.write(self.to_json())
[docs] def to_json(self) -> str: """ Convert data to JSON string. :return: JSON string """ obj = copy.deepcopy(self) # convert numpy array to list obj.data = obj.data.tolist() return jsonpickle.encode(obj)
[docs] @staticmethod def from_file(file: str) -> 'SFS2': """ Load from file. :param file: File path. :return: SFS2 """ with open(file, 'r') as f: return SFS2.from_json(f.read())
[docs] @staticmethod def from_json(json: str) -> 'SFS2': """ Load from JSON string. :param json: JSON string. :return: SFS2 """ obj = jsonpickle.decode(json) # convert list to numpy array obj.data = np.array(obj.data) return obj
[docs] def is_folded(self) -> bool: """ Check if the 2-SFS is folded. :return: Whether the 2-SFS is folded. """ return np.all(self.data == self.fold().data)
def __add__(self, other) -> 'SFS2': """ Add two 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self + other.data return SFS2(self.data + other) def __sub__(self, other) -> 'SFS2': """ Subtract two 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self - other.data return SFS2(self.data - other) def __mul__(self, other) -> 'SFS2': """ Multiply 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self * other.data return SFS2(self.data * other) def __floordiv__(self, other) -> 'SFS2': """ Divide 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self // other.data return SFS2(self.data // other) def __truediv__(self, other) -> 'SFS2': """ Divide 2-SFS. :param other: :return: """ if isinstance(other, SFS2): return self / other.data return SFS2(self.data / other) def __iter__(self) -> Iterator: """ Iterate over entries. :return: Iterator """ return self.data.__iter__() def __pow__(self, power) -> 'SFS2': """ Power operator. :param power: exponent :return: Spectrum """ return SFS2(self.data ** power)
[docs] def fold(self) -> 'SFS2': """ Fold 2-SFS by adding up ``i`` and ``n - i`` for both axes. Node that this only make sense for counts or frequencies. :return: Folded 2-SFS. """ data = self.data.copy() for _ in range(2): # compute left and right half and merge them left = np.concatenate((data[:self.w], np.zeros((self.n - self.w, self.n)))) right = np.concatenate((data[self.w:][::-1], np.zeros((self.w, self.n)))) # add parts and rotate data = (left + right).T return SFS2(data)
[docs] def copy(self) -> 'SFS2': """ Create deep copy. :return: Deep copy. """ return copy.deepcopy(self)
[docs] def symmetrize(self) -> 'SFS2': """ Symmetric SFS so that ``i, j`` and ``j, i`` are the same. :return: Symmetric 2-SFS. """ return SFS2((self.data + self.data.T) / 2)
[docs] def fill_monomorphic(self, fill_value=np.nan) -> 'SFS2': """ Remote the diagonal entries of the given array. :param fill_value: Value to fill diagonal entries with. :return: 2-SFS """ other = self.copy() other.data[:1, :] = fill_value other.data[-1:, :] = fill_value other.data[:, :1] = fill_value other.data[:, -1] = fill_value return other
[docs] def plot( self, ax: 'plt.Axes' = None, title: str = None, max_abs: float = None, log_scale: bool = False, cbar_kws: Dict = None, fill_diagonal_entries: bool = False, show: bool = True, ) -> 'plt.Axes': """ Plot as a heatmap. :param title: Title of the plot. :param ax: Axes to plot on. :param max_abs: Maximum absolute value to plot. :param log_scale: Use log scale. :param cbar_kws: Keyword arguments for color bar. :param fill_diagonal_entries: Fill diagonal entries. :param show: Whether to show the plot. :return: Axes. """ import matplotlib.pyplot as plt from matplotlib.colors import SymLogNorm import seaborn as sns if self.n < 3: logger.warning('Nothing to plot.') return plt.gca() if cbar_kws is None: cbar_kws = dict(pad=0.05) data = self.data.copy() # remove monomorphic entries data = self._remove_monomorphic(data) # mask diagonal entries if fill_diagonal_entries: data = self._fill_diagonals(data) # truncate data if folded if self.is_folded(): data = data[:self.w - 1, :self.w - 1] # determine colorbar bounds if not specified if max_abs is None: max_abs = self._get_max_abs_entry(data) # plot heatmap using a symmetric log norm ax = sns.heatmap( data, norm=SymLogNorm( linthresh=max_abs / 10, vmin=-max_abs, vmax=max_abs ), cmap='PuOr_r', cbar_kws=cbar_kws, ax=ax ) # invert y-axis and remove ticks ax.invert_yaxis() ax.axis('square') if log_scale: ax.set_xscale('log', base=1.001) ax.set_yscale('log', base=1.001) ax.set_xticklabels(range(1, len(data) + 1)) ax.set_yticklabels(range(1, len(data) + 1)) # remove confusing color bar ticks ax.collections[0].colorbar.ax.tick_params(size=0) # add frame around plot for _, spine in ax.spines.items(): spine.set_visible(True) spine.set_edgecolor('grey') if title is not None: ax.set_title(title) if show: plt.show() return ax
[docs] def plot_surface( self, ax: 'plt.Axes' = None, title: str = None, max_abs: float = None, vmin: float = None, vmax: float = None, fill_diagonal_entries: bool = False, fill_value: float = np.nan, log_scale: bool = False, show: bool = True, ) -> 'plt.Axes': """ Plot as a surface. :param title: :param ax: Axes to plot on. :param max_abs: Maximum absolute value to plot. :param vmin: Minimum value to plot. :param vmax: Maximum value to plot. :param fill_diagonal_entries: Fill diagonal entries. :param fill_value: Value to fill diagonal entries with. :param log_scale: Use log scale. :param show: Whether to show the plot. :return: Axes. """ import matplotlib.pyplot as plt from matplotlib.colors import SymLogNorm if self.n < 3: logger.warning('Nothing to plot.') return plt.gca() data = self.data.copy() # remove monomorphic entries data = self._remove_monomorphic(data) # mask diagonal entries if fill_diagonal_entries: data = self._fill_diagonals(data, fill_value) # truncate data if folded if self.is_folded(): data = data[:self.w - 1, :self.w - 1] # determine color bar bounds if not specified if max_abs is None: max_abs = self._get_max_abs_entry(data) or 1 x = np.arange(1, data.shape[0] + 1) y = np.arange(1, data.shape[0] + 1) x_grid, y_grid = np.meshgrid(x, y) if ax is None: _, ax = plt.subplots(subplot_kw={"projection": "3d"}) # vmin and vmax don't seem to work here ax.plot_surface( x_grid, y_grid, data, cmap='PuOr_r', vmin=vmin, vmax=vmax, norm=SymLogNorm( linthresh=max_abs / 10, vmin=-max_abs, vmax=max_abs ) ) # if log_scale: # ax.yaxis.set_scale('log') # ax.set_yscale('log', base=1.001) if title is not None: ax.set_title(title) if show: plt.show() return ax
@staticmethod def _remove_monomorphic(data: np.ndarray) -> np.ndarray: """ Remove monomorphic sites from given 2-SFS matrix. :return: The data without monomorphic sites. """ return data[1:-1, 1:-1] @staticmethod def _fill_diagonals(data: np.ndarray, fill_value=np.nan) -> np.ndarray: """ Remote the diagonal entries of the given array. :param data: The data to fill. :param fill_value: The value to fill the diagonal with. :return: The filled data. """ if len(data) == 0: return data if len(data) == 1: return np.array([[fill_value]]) np.fill_diagonal(data, fill_value) data = np.fliplr(data) np.fill_diagonal(data, fill_value) data = np.fliplr(data) return data @classmethod def _mask(cls, data: np.ndarray = None) -> np.ndarray: """ Remove diagonal and monomorphic entries. :param data: The data to mask. :return: The masked data. """ data = data.copy() data = cls._fill_diagonals(data) data = cls._remove_monomorphic(data) return data[~np.isnan(data) & ~np.isinf(data)] @classmethod def _get_max_abs_entry(cls, data: np.ndarray) -> float: """ Get the maximum absolute entry of the given data. :param data: The data. :return: The maximum absolute entry. """ entries = np.abs(cls._mask(data)) return entries.max() if len(entries) > 0 else 1