# Source code for bullkpy.pl.corrplot

from __future__ import annotations

from pathlib import Path
from typing import Literal, Sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.sparse as sp

import anndata as ad

from ._style import set_style, _savefig


def _get_color_vector(
    adata: ad.AnnData,
    key: str,
    *,
    layer: str | None = None,
) -> tuple[np.ndarray, str, str]:
    """
    Resolve ``key`` into a per-observation vector usable for point coloring.

    ``key`` is looked up in ``adata.obs`` columns first, then in
    ``adata.var_names``.

    Returns
    -------
    (values, kind, label)
        ``kind`` is ``"numeric"`` or ``"categorical"``; ``label`` is ``key``.
        Numeric obs columns and genes come back as float arrays; any other
        obs column is stringified and returned as an object array.

    Raises
    ------
    KeyError
        If ``key`` is neither an obs column nor a variable name.
    """
    in_obs = key in adata.obs.columns
    if not in_obs and key not in adata.var_names:
        raise KeyError(f"color/hue key '{key}' not found in adata.obs columns or adata.var_names.")

    if in_obs:
        column = adata.obs[key]
        if not pd.api.types.is_numeric_dtype(column):
            return column.astype(str).to_numpy(dtype=object), "categorical", key
        return column.to_numpy(dtype=float), "numeric", key

    # Gene expression: use the requested layer when it exists, otherwise X.
    use_layer = layer is not None and layer in adata.layers
    matrix = adata.layers[layer] if use_layer else adata.X
    col = int(adata.var_names.get_loc(key))
    expr = matrix[:, col]
    if sp.issparse(matrix):
        dense = np.asarray(expr.toarray())
    else:
        dense = np.asarray(expr, dtype=float)
    return dense.ravel().astype(float), "numeric", key


def _get_numeric_vector(
    adata: ad.AnnData,
    key: str,
    *,
    source: Literal["auto", "obs", "gene"] = "auto",
    layer: str | None = None,
) -> tuple[np.ndarray, str]:
    """
    Fetch ``key`` as a float vector with one entry per observation.

    Parameters
    ----------
    source
        ``"obs"`` reads an ``adata.obs`` column (non-numeric entries become
        NaN via ``pd.to_numeric(errors="coerce")``); ``"gene"`` reads a column
        of the expression matrix; any other value tries obs first, then genes.
    layer
        Layer to read gene values from. When it is ``None`` or not present in
        ``adata.layers``, ``adata.X`` is used.

    Returns
    -------
    (values, label)
        ``label`` is ``key``.

    Raises
    ------
    KeyError
        If ``key`` cannot be found in the requested source.
    """

    def _obs_values() -> tuple[np.ndarray, str]:
        if key not in adata.obs.columns:
            raise KeyError(f"'{key}' not found in adata.obs (source='obs').")
        coerced = pd.to_numeric(adata.obs[key], errors="coerce")
        return coerced.to_numpy(dtype=float), key

    def _gene_values() -> tuple[np.ndarray, str]:
        if key not in adata.var_names:
            raise KeyError(f"Gene '{key}' not found in adata.var_names (source='gene').")
        use_layer = layer is not None and layer in adata.layers
        matrix = adata.layers[layer] if use_layer else adata.X
        column = matrix[:, int(adata.var_names.get_loc(key))]
        dense = column.toarray() if sp.issparse(column) else column
        return np.asarray(dense, dtype=float).ravel(), key

    # Explicit sources win; otherwise obs columns take precedence over genes.
    if source == "obs":
        loader = _obs_values
    elif source == "gene":
        loader = _gene_values
    elif key in adata.obs.columns:
        loader = _obs_values
    elif key in adata.var_names:
        loader = _gene_values
    else:
        raise KeyError(f"'{key}' not found in adata.obs columns or adata.var_names (source='auto').")
    return loader()


def _corr_stats(
    xvals: np.ndarray,
    yvals: np.ndarray,
    *,
    method: Literal["pearson", "spearman", "both"] = "both",
) -> dict[str, float | int]:
    """
    Compute Pearson and/or Spearman correlation between two vectors.

    Pairs in which either value is NaN are discarded once, up front, so both
    coefficients are computed on the same observations and ``n`` reports the
    number of pairs actually used.

    Parameters
    ----------
    xvals, yvals
        Numeric vectors of equal length.
    method
        Which coefficient(s) to compute.

    Returns
    -------
    dict
        Always contains ``"n"``. Depending on ``method`` it also contains
        ``"pearson_r"``/``"pearson_p"`` and/or
        ``"spearman_rho"``/``"spearman_p"``. A coefficient that cannot be
        computed (fewer than two usable pairs, mismatched lengths, or a
        ``ValueError`` from SciPy) is reported as NaN instead of raising.
    """
    # Function-scope import: the module only imports scipy.sparse at the top.
    from scipy.stats import pearsonr, spearmanr

    x = np.asarray(xvals, dtype=float).ravel()
    y = np.asarray(yvals, dtype=float).ravel()

    comparable = x.shape == y.shape
    if comparable:
        keep = ~(np.isnan(x) | np.isnan(y))
        x, y = x[keep], y[keep]

    out: dict[str, float | int] = {"n": int(x.size)}
    nan = float("nan")

    def _safe(test) -> tuple[float, float]:
        # Correlation is undefined for fewer than two paired observations.
        if not comparable or x.size < 2:
            return nan, nan
        try:
            stat, pval = test(x, y)
        except ValueError:
            return nan, nan
        return float(stat), float(pval)

    if method in ("pearson", "both"):
        out["pearson_r"], out["pearson_p"] = _safe(pearsonr)
    if method in ("spearman", "both"):
        out["spearman_rho"], out["spearman_p"] = _safe(spearmanr)
    return out


def _plot_one(
    ax: plt.Axes,
    df: pd.DataFrame,
    *,
    x: str,
    y: str,
    color_key: str | None,
    layer: str | None,
    palette: str,
    cmap: str,
    point_size: float,
    alpha: float,
    add_regline: bool,
    annotate: bool,
    method: Literal["pearson", "spearman", "both"],
    title: str | None,
    legend: bool,
) -> dict[str, float | int]:
    """
    Draw one correlation scatter panel on ``ax`` and return its statistics.

    Parameters
    ----------
    ax
        Axes to draw on.
    df
        Frame holding columns ``x`` and ``y``. When ``color_key`` is given,
        ``df.attrs["_adata_"]`` must hold the AnnData object the frame was
        built from; the color vector is looked up there and aligned to
        ``df.index``.
    x, y
        Column names in ``df``; also used as axis labels.
    color_key
        obs column or gene used to color points, or ``None`` for one color.
    layer
        Forwarded to ``_get_color_vector`` for gene lookups.
    palette, cmap
        Colormap names for categorical and numeric coloring respectively.
        An unregistered ``palette`` name falls back to ``"tab20"``.
    point_size, alpha
        Scatter marker size and opacity.
    add_regline
        Fit and draw a degree-1 least-squares line.
    annotate
        Write n and the correlation coefficients in the top-left corner.
    method
        Forwarded to ``_corr_stats``.
    title
        Panel title. Defaults to ``color_key`` or ``"<x> vs <y>"``.
    legend
        Draw a colorbar (numeric) or legend (categorical).

    Returns
    -------
    dict
        The output of ``_corr_stats``, plus ``"slope"`` and ``"intercept"``
        when ``add_regline`` is true (NaN if fewer than two finite pairs).
    """
    xvals = df[x].to_numpy(float)
    yvals = df[y].to_numpy(float)

    stats = _corr_stats(xvals, yvals, method=method)

    # Fit only on complete, finite pairs: np.polyfit cannot handle NaN/inf,
    # which can reach this point when the caller keeps rows with missing data.
    finite = np.isfinite(xvals) & np.isfinite(yvals)
    slope = intercept = None
    if add_regline:
        if int(finite.sum()) >= 2:
            slope, intercept = np.polyfit(xvals[finite], yvals[finite], 1)
            stats["slope"] = float(slope)
            stats["intercept"] = float(intercept)
        else:
            stats["slope"] = float("nan")
            stats["intercept"] = float("nan")

    if color_key is None:
        ax.scatter(xvals, yvals, s=point_size, alpha=alpha, edgecolors="none")
    else:
        adata = df.attrs["_adata_"]
        vals, kind, label = _get_color_vector(adata=adata, key=color_key, layer=layer)
        # Align the full-length color vector to the (possibly filtered) rows.
        vals = pd.Series(vals, index=adata.obs_names).reindex(df.index).to_numpy()

        if kind == "numeric":
            sc = ax.scatter(
                xvals, yvals,
                c=vals.astype(float),
                cmap=cmap,
                s=point_size,
                alpha=alpha,
                edgecolors="none",
            )
            if legend:
                cbar = ax.figure.colorbar(sc, ax=ax, pad=0.02, fraction=0.05)
                cbar.set_label(label)
        else:
            cats = pd.Categorical(pd.Series(vals).astype(str))
            names = list(cats.categories)

            # plt.get_cmap replaces mpl.cm.get_cmap, which was removed in
            # Matplotlib 3.9.
            pal = plt.get_cmap(palette if palette in plt.colormaps() else "tab20")
            cmap_map = {str(n): pal(i % pal.N) for i, n in enumerate(names)}

            for name in names:
                m = (cats == name)
                ax.scatter(
                    xvals[m],
                    yvals[m],
                    s=point_size,
                    alpha=alpha,
                    edgecolors="none",
                    color=cmap_map[str(name)],
                    label=str(name),
                )

            if legend:
                ax.legend(
                    title=label,
                    bbox_to_anchor=(1.02, 1),
                    loc="upper left",
                    frameon=False,
                )

    ax.set_xlabel(x)
    ax.set_ylabel(y)

    if slope is not None and intercept is not None:
        x_fit = xvals[finite]
        xs = np.linspace(x_fit.min(), x_fit.max(), 200)
        ax.plot(xs, slope * xs + intercept)

    if annotate:
        lines = [f"n={stats['n']}"]
        if "pearson_r" in stats:
            lines.append(f"Pearson r={stats['pearson_r']:.3g}, p={stats['pearson_p']:.2g}")
        if "spearman_rho" in stats:
            lines.append(f"Spearman ρ={stats['spearman_rho']:.3g}, p={stats['spearman_p']:.2g}")
        ax.text(0.02, 0.98, "\n".join(lines), transform=ax.transAxes, ha="left", va="top")

    ax.set_title(title or (color_key if color_key is not None else f"{x} vs {y}"))
    return stats


def corrplot(
    adata: ad.AnnData,
    *,
    x: str,
    y: str,
    x_source: Literal["auto", "obs", "gene"] = "auto",
    y_source: Literal["auto", "obs", "gene"] = "auto",
    color: str | Sequence[str] | None = None,
    hue: str | Sequence[str] | None = None,  # alias for color
    layer: str | None = None,
    palette: str = "tab20",
    cmap: str = "viridis",
    legend: bool = True,
    method: Literal["pearson", "spearman", "both"] = "both",
    add_regline: bool = True,
    annotate: bool = True,
    dropna: bool = True,
    point_size: float = 18.0,
    alpha: float = 0.75,
    figsize: tuple[float, float] = (5.5, 4.5),
    panel_size: tuple[float, float] | None = None,
    title: str | None = None,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Scatter + correlations between two quantitative vectors.

    Correlation scatter plot between two vectors. x and y can be:

    - obs columns
    - genes (from X or layer)

    Supports:

    - obs vs obs
    - gene vs gene
    - gene vs obs

    Parameters
    ----------
    adata
        Annotated data matrix.
    x, y
        Keys of the two vectors to correlate.
    x_source, y_source
        Where to look each key up: ``"obs"``, ``"gene"`` or ``"auto"``
        (obs columns first, then genes).
    color
        obs column or gene used to color points. A list or tuple produces
        one panel per entry.
    hue
        Alias for ``color``; only used when ``color`` is ``None``.
    layer
        Layer to read gene values from (``adata.X`` when ``None`` or absent).
    palette, cmap
        Colormap names for categorical and numeric coloring respectively.
    legend
        Draw a legend or colorbar for the coloring variable.
    method
        Correlation coefficient(s) to compute.
    add_regline
        Overlay a linear fit.
    annotate
        Write n and the correlation statistics on each panel.
    dropna
        Drop observations where x or y is NaN before plotting.
    point_size, alpha
        Scatter marker size and opacity.
    figsize
        Figure size for one panel; the width is multiplied by the number of
        panels when there are several.
    panel_size
        Per-panel ``(width, height)``; overrides ``figsize`` when given.
    title
        Panel title. Only applied when there is a single panel.
    save
        If given, the figure is passed to ``_savefig`` with this path.
    show
        Call ``plt.show()`` before returning.

    Returns
    -------
    fig, axes, stats_list
        The figure, an object array of its axes, and one statistics dict per
        panel.

    Raises
    ------
    ValueError
        If fewer than three observations remain.
    KeyError
        If ``x``, ``y`` or a color key cannot be resolved.

    Examples
    --------
    # gene vs gene
    bk.pl.corrplot(adata, x="MKI67", y="TOP2A", x_source="gene", y_source="gene", layer="log1p_cpm")

    # gene vs obs
    bk.pl.corrplot(adata, x="MKI67", y="Proliferation_score", x_source="gene", y_source="obs")

    # obs vs obs (auto)
    bk.pl.corrplot(adata, x="mp_entropy", y="purity")
    """
    set_style()

    if hue is not None and color is None:
        color = hue

    xvals, xlabel = _get_numeric_vector(adata, x, source=x_source, layer=layer)
    yvals, ylabel = _get_numeric_vector(adata, y, source=y_source, layer=layer)

    df = pd.DataFrame({xlabel: xvals, ylabel: yvals}, index=adata.obs_names.astype(str))

    if dropna:
        df = df.dropna(subset=[xlabel, ylabel])

    # Attach the AnnData handle only after filtering, so it is guaranteed to
    # be on the frame handed to _plot_one and is not carried through dropna
    # (pandas may copy attrs when deriving a new frame).
    df.attrs["_adata_"] = adata

    if df.shape[0] < 3:
        raise ValueError(f"Not enough valid points for correlation: n={df.shape[0]}")

    # normalize color -> list
    if color is None:
        colors = [None]
    elif isinstance(color, (list, tuple)):
        colors = [str(c) for c in color]
    else:
        colors = [str(color)]

    n = len(colors)
    if panel_size is not None:
        figsize_eff = (panel_size[0] * n, panel_size[1])
    else:
        figsize_eff = (figsize[0] * n, figsize[1]) if n > 1 else figsize

    fig, axes = plt.subplots(1, n, figsize=figsize_eff, constrained_layout=True)
    if n == 1:
        axes = [axes]

    stats_list: list[dict[str, float | int]] = []
    for ax, c in zip(axes, colors):
        st = _plot_one(
            ax,
            df,
            x=xlabel,
            y=ylabel,
            color_key=c,
            layer=layer,
            palette=palette,
            cmap=cmap,
            point_size=point_size,
            alpha=alpha,
            add_regline=add_regline,
            annotate=annotate,
            method=method,
            title=(title if (title is not None and n == 1) else None),
            legend=legend,
        )
        stats_list.append(st)

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()

    return fig, np.array(axes, dtype=object), stats_list