# Source code for bullkpy.pl.corrplot_obs

from __future__ import annotations

from pathlib import Path
from typing import Literal, Sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.sparse as sp

import anndata as ad

from ._style import set_style, _savefig


def _get_color_vector(
    adata: ad.AnnData,
    key: str,
    *,
    layer: str | None = None,
) -> tuple[np.ndarray, str, str]:
    """
    Resolve a color key to one value per observation.

    ``key`` is looked up in ``adata.obs`` columns first and, if it is not
    found there, in ``adata.var_names``.

    Parameters
    ----------
    adata
        Object exposing ``obs``, ``var_names``, ``X`` and ``layers``.
    key
        Name of an obs column or of an entry in ``var_names``.
    layer
        Layer to read values from when ``key`` is found in ``var_names``.
        If it is ``None`` or not present in ``adata.layers``, ``adata.X``
        is used instead.

    Returns
    -------
    (values, kind, label)
        ``values`` is a 1-D array; ``kind`` is ``"numeric"`` (float array)
        or ``"categorical"`` (object array of strings); ``label`` is ``key``.

    Raises
    ------
    KeyError
        If ``key`` is neither an obs column nor in ``var_names``.
    """
    # obs
    if key in adata.obs.columns:
        column = adata.obs[key]
        if pd.api.types.is_numeric_dtype(column):
            # Give the missing-value marker explicitly: nullable extension
            # dtypes (e.g. "Int64" holding pd.NA) may otherwise refuse the
            # conversion to a float array, depending on the pandas version.
            values = column.to_numpy(dtype=float, na_value=np.nan)
            return values, "numeric", key
        return column.astype(str).to_numpy(dtype=object), "categorical", key

    # gene
    if key in adata.var_names:
        use_layer = layer is not None and layer in adata.layers
        X = adata.layers[layer] if use_layer else adata.X
        gidx = int(adata.var_names.get_loc(key))
        if sp.issparse(X):
            vals = np.asarray(X[:, gidx].toarray()).ravel()
        else:
            vals = np.asarray(X[:, gidx], dtype=float).ravel()
        return vals.astype(float), "numeric", key

    raise KeyError(f"color/hue key '{key}' not found in adata.obs columns or adata.var_names.")


def _corr_stats(
    xvals: np.ndarray,
    yvals: np.ndarray,
    *,
    method: Literal["pearson", "spearman", "both"] = "both",
) -> dict[str, float | int]:
    """
    Compute correlation statistics between two paired samples.

    Pairs in which either value is NaN are dropped once, up front, so that
    Pearson and Spearman are evaluated on the same observations and ``n``
    reports how many pairs were actually used.

    Parameters
    ----------
    xvals, yvals
        Paired 1-D numeric samples of equal length.
    method
        ``"pearson"``, ``"spearman"`` or ``"both"``.

    Returns
    -------
    dict
        Always contains ``"n"``. Contains ``"pearson_r"``/``"pearson_p"``
        and/or ``"spearman_rho"``/``"spearman_p"`` depending on ``method``.
        A statistic that cannot be computed is reported as NaN instead of
        raising.
    """
    nan = float("nan")

    try:
        xs = np.asarray(xvals, dtype=float).ravel()
        ys = np.asarray(yvals, dtype=float).ravel()
        if xs.shape != ys.shape:
            raise ValueError("xvals and yvals must have the same length")
        complete = ~(np.isnan(xs) | np.isnan(ys))
        xs, ys = xs[complete], ys[complete]
    except (TypeError, ValueError):
        # Unusable input: fall through and report NaN statistics.
        xs = ys = None

    out: dict[str, float | int] = {
        "n": int(len(xvals)) if xs is None else int(xs.size)
    }

    def _safe(test_name: str) -> tuple[float, float]:
        """Run one scipy.stats test, returning (NaN, NaN) on any failure."""
        if xs is None:
            return nan, nan
        try:
            from scipy import stats as scipy_stats

            coef, pval = getattr(scipy_stats, test_name)(xs, ys)
            return float(coef), float(pval)
        except Exception:
            # Deliberately best-effort: these numbers only feed an
            # annotation, so a failed test (too few points, SciPy missing,
            # ...) must not abort the plot.
            return nan, nan

    if method in ("pearson", "both"):
        out["pearson_r"], out["pearson_p"] = _safe("pearsonr")
    if method in ("spearman", "both"):
        out["spearman_rho"], out["spearman_p"] = _safe("spearmanr")
    return out


def _plot_one(
    ax: plt.Axes,
    df: pd.DataFrame,
    *,
    x: str,
    y: str,
    color_key: str | None,
    layer: str | None,
    palette: str,
    cmap: str,
    point_size: float,
    alpha: float,
    add_regline: bool,
    annotate: bool,
    method: Literal["pearson", "spearman", "both"],
    title: str | None,
    legend: bool,
) -> dict[str, float | int]:
    """
    Draw one scatter panel of ``df[x]`` against ``df[y]`` on ``ax``.

    Parameters
    ----------
    ax
        Axes to draw on.
    df
        Frame holding the ``x`` and ``y`` columns. When ``color_key`` is
        given, ``df.attrs["_adata_"]`` must hold the AnnData object the
        frame was built from; it is used to resolve the color values.
    x, y
        Column names in ``df``; also used as axis labels.
    color_key
        Obs column or ``var_names`` entry to color points by, or ``None``
        for uncolored points.
    layer
        Forwarded to ``_get_color_vector``.
    palette
        Colormap name used for categorical colors; ``"tab20"`` is used when
        the name cannot be resolved.
    cmap
        Colormap used for numeric colors.
    point_size, alpha
        Marker size and opacity passed to ``ax.scatter``.
    add_regline
        Fit and draw a degree-1 least-squares line.
    annotate
        Write n and the correlation statistics in the top-left corner.
    method
        Forwarded to ``_corr_stats``.
    title
        Panel title; defaults to ``color_key`` or ``"<x> vs <y>"``.
    legend
        Draw a colorbar (numeric) or legend (categorical).

    Returns
    -------
    dict
        The result of ``_corr_stats`` plus, when ``add_regline`` is true,
        ``"slope"`` and ``"intercept"`` (NaN if no line could be fitted).
    """
    xvals = df[x].to_numpy(float)
    yvals = df[y].to_numpy(float)

    stats = _corr_stats(xvals, yvals, method=method)

    slope = intercept = None
    x_fit = None
    if add_regline:
        # Fit on finite pairs only: the frame can contain NaN when the
        # caller did not drop missing rows, and polyfit cannot handle them.
        finite = np.isfinite(xvals) & np.isfinite(yvals)
        if int(finite.sum()) >= 2:
            x_fit = xvals[finite]
            slope, intercept = np.polyfit(x_fit, yvals[finite], 1)
        stats["slope"] = float(slope) if slope is not None else float("nan")
        stats["intercept"] = float(intercept) if intercept is not None else float("nan")

    if color_key is None:
        ax.scatter(xvals, yvals, s=point_size, alpha=alpha, edgecolors="none")
    else:
        adata = df.attrs["_adata_"]
        vals, kind, label = _get_color_vector(adata=adata, key=color_key, layer=layer)
        # align to df index
        vals = pd.Series(vals, index=adata.obs_names).reindex(df.index).to_numpy()

        if kind == "numeric":
            sc = ax.scatter(
                xvals, yvals,
                c=vals.astype(float),
                cmap=cmap,
                s=point_size,
                alpha=alpha,
                edgecolors="none",
            )
            if legend:
                cbar = ax.figure.colorbar(sc, ax=ax, pad=0.02, fraction=0.05)
                cbar.set_label(label)
        else:
            cats = pd.Categorical(pd.Series(vals).astype(str))
            names = list(cats.categories)

            # plt.get_cmap is used instead of matplotlib.cm.get_cmap, which
            # newer Matplotlib releases no longer provide.
            pal = None
            if palette is not None:
                try:
                    pal = plt.get_cmap(palette)
                except (ValueError, KeyError, TypeError):
                    pal = None
            if pal is None:
                pal = plt.get_cmap("tab20")
            cmap_map = {str(n): pal(i % pal.N) for i, n in enumerate(names)}

            for name in names:
                m = np.asarray(cats == name)
                ax.scatter(
                    xvals[m],
                    yvals[m],
                    s=point_size,
                    alpha=alpha,
                    edgecolors="none",
                    color=cmap_map[str(name)],
                    label=str(name),
                )

            if legend:
                ax.legend(
                    title=label,
                    bbox_to_anchor=(1.02, 1),
                    loc="upper left",
                    frameon=False,
                )

    ax.set_xlabel(x)
    ax.set_ylabel(y)

    if slope is not None and intercept is not None and x_fit is not None:
        xs = np.linspace(x_fit.min(), x_fit.max(), 200)
        ax.plot(xs, slope * xs + intercept)

    if annotate:
        lines = [f"n={stats['n']}"]
        if "pearson_r" in stats:
            lines.append(f"Pearson r={stats['pearson_r']:.3g}, p={stats['pearson_p']:.2g}")
        if "spearman_rho" in stats:
            lines.append(f"Spearman ρ={stats['spearman_rho']:.3g}, p={stats['spearman_p']:.2g}")
        ax.text(0.02, 0.98, "\n".join(lines), transform=ax.transAxes, ha="left", va="top")

    ax.set_title(title or (color_key if color_key is not None else f"{x} vs {y}"))
    return stats


def corrplot_obs(
    adata: ad.AnnData,
    *,
    x: str,
    y: str,
    color: str | Sequence[str] | None = None,
    hue: str | Sequence[str] | None = None,  # alias for color
    layer: str | None = None,
    palette: str = "tab20",
    cmap: str = "viridis",
    legend: bool = True,
    method: Literal["pearson", "spearman", "both"] = "both",
    add_regline: bool = True,
    annotate: bool = True,
    dropna: bool = True,
    point_size: float = 18.0,
    alpha: float = 0.75,
    figsize: tuple[float, float] = (5.5, 4.5),
    panel_size: tuple[float, float] | None = None,
    title: str | None = None,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Scatter + correlations between two quantitative obs columns.

    Multi-panel:
      color=["DLL3","SOX10"] makes one panel per color key in a single row.

    Parameters
    ----------
    adata
        AnnData object; ``x`` and ``y`` must be columns of ``adata.obs``.
    x, y
        Obs columns to plot. Values are coerced with ``pd.to_numeric``;
        entries that cannot be parsed become NaN.
    color, hue
        One key or several keys (obs column or ``var_names`` entry) to color
        points by. ``hue`` is used only when ``color`` is ``None``.
    layer
        Layer used when a color key is found in ``var_names``.
    palette, cmap
        Colormap names for categorical and numeric coloring respectively.
    legend
        Draw a legend / colorbar on each panel.
    method
        Which correlation statistics to compute.
    add_regline, annotate
        Draw a fitted line / write the statistics on each panel.
    dropna
        Drop rows where ``x`` or ``y`` is NaN before plotting.
    point_size, alpha
        Marker size and opacity.
    figsize
        Size of one panel; the width is multiplied by the number of panels
        when there is more than one.
    panel_size
        If given, overrides ``figsize``; the width is always multiplied by
        the number of panels.
    title
        Title, applied only when there is a single panel.
    save
        If not ``None``, passed to ``_savefig`` together with the figure.
    show
        Call ``plt.show()`` before returning.

    Returns
    -------
    (fig, axes, stats_list)
        The figure, a 1-D object array of axes, and one statistics dict per
        panel.

    Raises
    ------
    KeyError
        If ``x`` or ``y`` is not a column of ``adata.obs``.
    ValueError
        If fewer than 3 rows remain to plot.
    """
    set_style()

    if hue is not None and color is None:
        color = hue

    if x not in adata.obs.columns:
        raise KeyError(f"x='{x}' not in adata.obs")
    if y not in adata.obs.columns:
        raise KeyError(f"y='{y}' not in adata.obs")

    df = adata.obs[[x, y]].copy()
    df[x] = pd.to_numeric(df[x], errors="coerce")
    df[y] = pd.to_numeric(df[y], errors="coerce")

    if dropna:
        df = df.dropna(subset=[x, y])

    if df.shape[0] < 3:
        raise ValueError(f"Not enough valid points for correlation: n={df.shape[0]}")

    # store for _plot_one. Attached only after filtering: pandas may copy
    # ``attrs`` into frames derived from this one, which would duplicate the
    # whole AnnData object.
    df.attrs["_adata_"] = adata

    # normalize color -> list
    if color is None:
        colors = [None]
    elif isinstance(color, str):
        colors = [color]
    else:
        try:
            colors = [str(c) for c in color]
        except TypeError:
            # Not iterable: treat as a single key, as before.
            colors = [str(color)]
        if not colors:
            # An empty sequence means "no coloring", not "zero panels".
            colors = [None]

    n = len(colors)
    if panel_size is not None:
        figsize_eff = (panel_size[0] * n, panel_size[1])
    else:
        figsize_eff = (figsize[0] * n, figsize[1]) if n > 1 else figsize

    fig, axes_grid = plt.subplots(
        1, n, figsize=figsize_eff, constrained_layout=True, squeeze=False
    )
    # squeeze=False always yields a 2-D object array; flatten the single row.
    axes = axes_grid.ravel()

    stats_list: list[dict[str, float | int]] = []
    for ax, c in zip(axes, colors):
        st = _plot_one(
            ax,
            df,
            x=x,
            y=y,
            color_key=c,
            layer=layer,
            palette=palette,
            cmap=cmap,
            point_size=point_size,
            alpha=alpha,
            add_regline=add_regline,
            annotate=annotate,
            method=method,
            title=(title if (title is not None and n == 1) else None),
            legend=legend,
        )
        stats_list.append(st)

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()

    return fig, np.array(axes, dtype=object), stats_list