Source code for bullkpy.pl.heatmap_de

from __future__ import annotations

from pathlib import Path
from typing import Literal, Sequence

import numpy as np
import pandas as pd
import scipy.sparse as sp
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colors import to_hex

from ._style import set_style, _savefig, _apply_clustergrid_style

try:
    import seaborn as sns
except Exception:  # pragma: no cover
    sns = None

import anndata as ad

from ._style import set_style, _savefig


def _get_matrix(adata: ad.AnnData, layer: str | None) -> np.ndarray:
    X = adata.layers[layer] if (layer is not None and layer in adata.layers) else adata.X
    if sp.issparse(X):
        X = X.toarray()
    return np.asarray(X, dtype=float)


[docs] def heatmap_de( adata: ad.AnnData, *, contrast: str, de_key: str = "de", results_key: str = "results", layer: str | None = "log1p_cpm", groupby: str | None = None, groups: Sequence[str] | None = None, top_n: int = 50, mode: Literal["up", "down", "both", "abs"] = "both", sort_by: Literal["qval", "pval", "log2FC", "t"] = "qval", z_score: Literal["row", "none"] = "row", clip_z: float | None = 3.0, cmap: str = "vlag", center: float = 0.0, col_colors: str | Sequence[str] | None = None, # obs key(s) for column annotation dendrogram_rows: bool = True, dendrogram_cols: bool = True, show_sample_labels: bool = False, figsize: tuple[float, float] | None = None, cbar_label: str = "z-scored expression", save: str | Path | None = None, show: bool = True, ): """ Heatmap of top DE genes using results stored at: adata.uns[de_key][contrast][results_key] Expected DE columns: ['gene', 'log2FC', 't', 'pval', 'qval', 'mean_group', 'mean_ref'] - Heatmap values come from `layer` (default: log1p_cpm) - Gene selection uses DE table (default sort: qval) - Optionally subset/order samples by `groupby` and `groups` """ set_style() if sns is None: raise ImportError("heatmap_de requires seaborn. Please install seaborn.") # --- fetch DE results --- try: res = adata.uns[de_key][contrast][results_key] except KeyError as e: raise KeyError( f"Could not find DE results at adata.uns['{de_key}']['{contrast}']['{results_key}']" ) from e if not isinstance(res, pd.DataFrame): res = pd.DataFrame(res) required = {"gene", "log2FC"} missing = required - set(res.columns) if missing: raise ValueError(f"DE results missing required columns: {sorted(missing)}") # --- gene selection --- res = res.copy() res["gene"] = res["gene"].astype(str) # pick sorting column if sort_by not in res.columns: # fallback sort_by = "qval" if "qval" in res.columns else ("pval" if "pval" in res.columns else "log2FC") # p/q ascending, effect size descending if sort_by in {"qval", "pval"}: res = res.sort_values(sort_by, ascending=True) else: res = res.sort_values(sort_by, ascending=False) if mode == "up": pick = res[res["log2FC"] > 0].head(top_n) elif mode == "down": pick = res[res["log2FC"] < 0].head(top_n) elif mode == "abs": pick = res.assign(_abs=np.abs(res["log2FC"])).sort_values("_abs", ascending=False).head(top_n) else: # both n1 = top_n // 2 n2 = top_n - n1 up = res[res["log2FC"] > 0].head(n1) down = res[res["log2FC"] < 0].head(n2) pick = pd.concat([up, down], axis=0) genes = [g for g in pick["gene"].tolist() if g in adata.var_names] if len(genes) == 0: raise ValueError("None of the selected DE genes are present in adata.var_names.") # --- subset/order samples --- obs_mask = np.ones(adata.n_obs, dtype=bool) col_order = np.arange(adata.n_obs) if groupby is not None: if groupby not in adata.obs.columns: raise KeyError(f"groupby='{groupby}' not found in adata.obs") g = adata.obs[groupby].astype(str) if groups is not None: groups = [str(x) for x in groups] obs_mask = g.isin(groups).to_numpy() g_sub = g[obs_mask] # keep group order as provided if groups given; else alphabetical if groups is None: cat_order = sorted(g_sub.unique()) else: cat_order = [x for x in groups if x in set(g_sub)] col_order = np.argsort(pd.Categorical(g_sub, categories=cat_order, ordered=True)) else: g_sub = None # --- expression matrix: genes x samples --- X = _get_matrix(adata, layer) gidx = [adata.var_names.get_loc(g) for g in genes] Xg = X[:, gidx] # samples x genes Xg = Xg[obs_mask, :] # subset samples Xh = Xg.T # genes x samples # --- z-score per gene --- if z_score == "row": mu = Xh.mean(axis=1, keepdims=True) sd = Xh.std(axis=1, keepdims=True, ddof=0) sd[sd == 0] = 1.0 Xh = (Xh - mu) / sd if clip_z is not None: Xh = np.clip(Xh, -float(clip_z), float(clip_z)) # order columns if groupby is not None: Xh = Xh[:, col_order] g_sub = g_sub.iloc[col_order] # dataframe colnames = adata.obs_names[obs_mask].astype(str).tolist() if groupby is not None: colnames = [colnames[i] for i in col_order] df = pd.DataFrame(Xh, index=genes, columns=colnames) # --- column annotations --- # --- column annotations (map categories -> colors) --- col_colors_df = None if col_colors is not None: if isinstance(col_colors, str): col_colors = [col_colors] ann_colors = {} for key in col_colors: if key not in adata.obs.columns: raise KeyError(f"col_colors obs key '{key}' not found") vals = adata.obs.loc[df.columns, key] # categorical/stringify for stable mapping vals_str = vals.astype(str) cats = pd.Categorical(vals_str).categories # build a palette (tab20 is good default) palette = sns.color_palette("tab20", n_colors=len(cats)) lut = {cat: to_hex(palette[i]) for i, cat in enumerate(cats)} ann_colors[key] = vals_str.map(lut) col_colors_df = pd.DataFrame(ann_colors, index=df.columns) # --- autosize --- if figsize is None: w = max(6.0, 0.15 * df.shape[1] + 3.0) h = max(4.8, 0.18 * df.shape[0] + 2.2) figsize = (w, h) # --- plot --- cg = sns.clustermap( df, cmap=cmap, center=center if z_score == "row" else None, row_cluster=bool(dendrogram_rows), col_cluster=bool(dendrogram_cols), col_colors=col_colors_df, xticklabels=show_sample_labels, yticklabels=True, figsize=figsize, cbar_kws={"label": cbar_label}, ) _apply_clustergrid_style(cg) cg.fig.tight_layout() # title cg.ax_heatmap.set_title(f"{contrast} — top {top_n} ({mode})", pad=10) if save is not None: _savefig(cg.fig, save) if show: plt.show() return cg