from __future__ import annotations
from pathlib import Path
from typing import Literal, Sequence
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.sparse as sp
import anndata as ad
from ._style import set_style, _savefig
def _get_color_vector(
    adata: ad.AnnData,
    key: str,
    *,
    layer: str | None = None,
) -> tuple[np.ndarray, str, str]:
    """
    Resolve a colour/hue key to a per-observation vector.

    ``key`` is looked up first among ``adata.obs`` columns, then among
    ``adata.var_names`` (genes).

    Parameters
    ----------
    adata
        Annotated data matrix.
    key
        obs column name or gene name.
    layer
        Layer to read gene values from. ``None`` uses ``adata.X``. A name that
        is not in ``adata.layers`` emits a ``UserWarning`` and uses ``adata.X``.

    Returns
    -------
    values, kind, label
        ``values`` is a float array when ``kind == "numeric"`` (numeric obs
        column or gene), or an object array of strings when
        ``kind == "categorical"`` (any other obs column). ``label`` is ``key``.

    Raises
    ------
    KeyError
        If ``key`` is neither an obs column nor a gene.
    """
    if key in adata.obs.columns:
        col = adata.obs[key]
        if pd.api.types.is_numeric_dtype(col):
            # Nullable dtypes (Int64, boolean, Float64) holding pd.NA cannot be
            # cast to float without an explicit missing-value marker.
            return col.to_numpy(dtype=float, na_value=np.nan), "numeric", key
        return col.astype(str).to_numpy(dtype=object), "categorical", key

    if key in adata.var_names:
        if layer is None:
            X = adata.X
        elif layer in adata.layers:
            X = adata.layers[layer]
        else:
            import warnings

            # Falling back silently would plot a different matrix than requested.
            warnings.warn(
                f"Layer '{layer}' not found in adata.layers; using adata.X instead.",
                UserWarning,
                stacklevel=2,
            )
            X = adata.X
        gidx = int(adata.var_names.get_loc(key))
        column = X[:, gidx]
        if sp.issparse(column):
            column = column.toarray()
        return np.asarray(column, dtype=float).ravel(), "numeric", key

    raise KeyError(f"color/hue key '{key}' not found in adata.obs columns or adata.var_names.")
def _get_numeric_vector(
    adata: ad.AnnData,
    key: str,
    *,
    source: Literal["auto", "obs", "gene"] = "auto",
    layer: str | None = None,
) -> tuple[np.ndarray, str]:
    """
    Return ``(values, label)`` where ``values`` is a float array aligned to ``adata.obs_names``.

    Parameters
    ----------
    adata
        Annotated data matrix.
    key
        obs column name or gene name.
    source
        ``"obs"``: read ``adata.obs[key]`` via ``pd.to_numeric(errors="coerce")``
        (unparseable entries become NaN). ``"gene"``: read the gene column from
        ``adata.X`` or ``layer``. ``"auto"``: try obs columns first, then genes.
    layer
        Layer to read gene values from. ``None`` uses ``adata.X``. A name that
        is not in ``adata.layers`` emits a ``UserWarning`` and uses ``adata.X``.

    Returns
    -------
    values, label
        Float vector and the key it was read from.

    Raises
    ------
    ValueError
        If ``source`` is not one of ``"auto"``, ``"obs"``, ``"gene"``.
    KeyError
        If ``key`` is not found in the requested source(s).
    """
    if source not in ("auto", "obs", "gene"):
        raise ValueError(f"source must be 'auto', 'obs' or 'gene', got {source!r}.")

    def _from_obs(k: str) -> tuple[np.ndarray, str]:
        # Read an obs column, coercing non-numeric entries to NaN.
        if k not in adata.obs.columns:
            raise KeyError(f"'{k}' not found in adata.obs (source='obs').")
        s = pd.to_numeric(adata.obs[k], errors="coerce")
        # Nullable results (e.g. Int64 with pd.NA) need an explicit NaN marker.
        return s.to_numpy(dtype=float, na_value=np.nan), k

    def _from_gene(g: str) -> tuple[np.ndarray, str]:
        # Read one gene column from the selected expression matrix.
        if g not in adata.var_names:
            raise KeyError(f"Gene '{g}' not found in adata.var_names (source='gene').")
        if layer is None:
            X = adata.X
        elif layer in adata.layers:
            X = adata.layers[layer]
        else:
            import warnings

            # Falling back silently would correlate a different matrix than requested.
            warnings.warn(
                f"Layer '{layer}' not found in adata.layers; using adata.X instead.",
                UserWarning,
                stacklevel=3,
            )
            X = adata.X
        j = int(adata.var_names.get_loc(g))
        v = X[:, j]
        if sp.issparse(v):
            v = v.toarray()
        return np.asarray(v, dtype=float).ravel(), g

    if source == "obs":
        return _from_obs(key)
    if source == "gene":
        return _from_gene(key)
    # auto: obs columns take precedence over genes with the same name
    if key in adata.obs.columns:
        return _from_obs(key)
    if key in adata.var_names:
        return _from_gene(key)
    raise KeyError(f"'{key}' not found in adata.obs columns or adata.var_names (source='auto').")
def _corr_stats(
    xvals: np.ndarray,
    yvals: np.ndarray,
    *,
    method: Literal["pearson", "spearman", "both"] = "both",
) -> dict[str, float | int]:
    """
    Compute Pearson and/or Spearman correlation on the finite (x, y) pairs.

    Pairs where either value is NaN or infinite are excluded from every
    statistic, so Pearson and Spearman are computed on the same points.

    Parameters
    ----------
    xvals, yvals
        Equal-length numeric vectors.
    method
        Which statistic(s) to compute.

    Returns
    -------
    dict
        ``n`` (number of finite pairs used), plus ``pearson_r``/``pearson_p``
        and/or ``spearman_rho``/``spearman_p``. A statistic that scipy cannot
        compute (e.g. too few points) is reported as NaN.

    Raises
    ------
    ValueError
        If ``xvals`` and ``yvals`` have different lengths.
    """
    x = np.asarray(xvals, dtype=float).ravel()
    y = np.asarray(yvals, dtype=float).ravel()
    if x.shape != y.shape:
        raise ValueError(f"x and y must have the same length, got {x.size} and {y.size}.")
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]

    out: dict[str, float | int] = {"n": int(finite.sum())}
    nan = float("nan")
    if method in ("pearson", "both"):
        try:
            from scipy.stats import pearsonr

            r, p = pearsonr(x, y)
            out["pearson_r"] = float(r)
            out["pearson_p"] = float(p)
        except Exception:
            # Best-effort: scipy raises e.g. for fewer than two points.
            out["pearson_r"] = nan
            out["pearson_p"] = nan
    if method in ("spearman", "both"):
        try:
            from scipy.stats import spearmanr

            rho, p = spearmanr(x, y)
            out["spearman_rho"] = float(rho)
            out["spearman_p"] = float(p)
        except Exception:
            # Best-effort, same policy as the Pearson branch.
            out["spearman_rho"] = nan
            out["spearman_p"] = nan
    return out
def _plot_one(
    ax: plt.Axes,
    df: pd.DataFrame,
    *,
    x: str,
    y: str,
    color_key: str | None,
    layer: str | None,
    palette: str,
    cmap: str,
    point_size: float,
    alpha: float,
    add_regline: bool,
    annotate: bool,
    method: Literal["pearson", "spearman", "both"],
    title: str | None,
    legend: bool,
    adata: ad.AnnData | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
) -> dict[str, float | int]:
    """
    Draw one correlation scatter panel on ``ax`` and return its statistics.

    Parameters
    ----------
    ax
        Target axes.
    df
        Frame indexed by observation names, holding the two vectors in
        columns ``x`` and ``y``.
    x, y
        Column names in ``df``.
    color_key
        obs column or gene used to colour points, or ``None`` for one colour.
    layer
        Layer forwarded to ``_get_color_vector`` when colouring by a gene.
    palette
        Colormap name for categorical colours; unknown names use ``"tab20"``.
    cmap
        Colormap for numeric colours.
    point_size, alpha
        Scatter marker size and opacity.
    add_regline
        Fit and draw a least-squares line on the finite (x, y) pairs.
    annotate
        Write n and correlation statistics in the top-left corner.
    method
        Forwarded to ``_corr_stats``.
    title
        Panel title; defaults to ``color_key`` or ``"<xlabel> vs <ylabel>"``.
    legend
        Draw a legend (categorical colours) or colorbar (numeric colours).
    adata
        AnnData the colour vector is read from. When ``None``,
        ``df.attrs["_adata_"]`` is used (legacy calling convention).
    xlabel, ylabel
        Axis labels; default to ``x`` / ``y``.

    Returns
    -------
    dict
        Statistics from ``_corr_stats``, plus ``slope``/``intercept`` when a
        regression line was fitted.
    """
    xlabel = x if xlabel is None else xlabel
    ylabel = y if ylabel is None else ylabel
    xvals = df[x].to_numpy(dtype=float)
    yvals = df[y].to_numpy(dtype=float)
    stats = _corr_stats(xvals, yvals, method=method)

    # np.polyfit fails on NaN/inf, which reach here when the caller keeps
    # incomplete rows (dropna=False); fit on complete pairs only.
    finite = np.isfinite(xvals) & np.isfinite(yvals)
    slope = intercept = None
    if add_regline and int(finite.sum()) >= 2:
        slope, intercept = np.polyfit(xvals[finite], yvals[finite], 1)
        stats["slope"] = float(slope)
        stats["intercept"] = float(intercept)

    if color_key is None:
        ax.scatter(xvals, yvals, s=point_size, alpha=alpha, edgecolors="none")
    else:
        src = adata if adata is not None else df.attrs["_adata_"]
        vals, kind, label = _get_color_vector(adata=src, key=color_key, layer=layer)
        # df may be a filtered subset of the observations: align by name, using
        # the same str-cast names the frame index was built from.
        vals = pd.Series(vals, index=src.obs_names.astype(str)).reindex(df.index).to_numpy()
        if kind == "numeric":
            sc = ax.scatter(
                xvals,
                yvals,
                c=vals.astype(float),
                cmap=cmap,
                s=point_size,
                alpha=alpha,
                edgecolors="none",
            )
            if legend:
                cbar = ax.figure.colorbar(sc, ax=ax, pad=0.02, fraction=0.05)
                cbar.set_label(label)
        else:
            cats = pd.Categorical(pd.Series(vals).astype(str))
            names = list(cats.categories)
            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the registry.
            pal = mpl.colormaps[palette] if palette in mpl.colormaps else mpl.colormaps["tab20"]
            n_cat = len(names)
            if pal.N >= 256 and n_cat > 1:
                # Continuous colormap: integer lookups 0..k would all land at the
                # very start of its 256-entry table and look identical, so
                # sample evenly across [0, 1] instead.
                colours = [pal(i / (n_cat - 1)) for i in range(n_cat)]
            else:
                # Qualitative colormap: one table entry per category, cycling.
                colours = [pal(i % pal.N) for i in range(n_cat)]
            colour_of = dict(zip(names, colours))
            for name in names:
                mask = np.asarray(cats == name)
                ax.scatter(
                    xvals[mask],
                    yvals[mask],
                    s=point_size,
                    alpha=alpha,
                    edgecolors="none",
                    color=colour_of[name],
                    label=str(name),
                )
            if legend:
                ax.legend(
                    title=label,
                    bbox_to_anchor=(1.02, 1),
                    loc="upper left",
                    frameon=False,
                )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if slope is not None and intercept is not None:
        xs = np.linspace(xvals[finite].min(), xvals[finite].max(), 200)
        ax.plot(xs, slope * xs + intercept)
    if annotate:
        lines = [f"n={stats['n']}"]
        if "pearson_r" in stats:
            lines.append(f"Pearson r={stats['pearson_r']:.3g}, p={stats['pearson_p']:.2g}")
        if "spearman_rho" in stats:
            lines.append(f"Spearman ρ={stats['spearman_rho']:.3g}, p={stats['spearman_p']:.2g}")
        ax.text(0.02, 0.98, "\n".join(lines), transform=ax.transAxes, ha="left", va="top")
    ax.set_title(title or (color_key if color_key is not None else f"{xlabel} vs {ylabel}"))
    return stats


def corrplot(
    adata: ad.AnnData,
    *,
    x: str,
    y: str,
    x_source: Literal["auto", "obs", "gene"] = "auto",
    y_source: Literal["auto", "obs", "gene"] = "auto",
    color: str | Sequence[str] | None = None,
    hue: str | Sequence[str] | None = None,  # alias for color
    layer: str | None = None,
    palette: str = "tab20",
    cmap: str = "viridis",
    legend: bool = True,
    method: Literal["pearson", "spearman", "both"] = "both",
    add_regline: bool = True,
    annotate: bool = True,
    dropna: bool = True,
    point_size: float = 18.0,
    alpha: float = 0.75,
    figsize: tuple[float, float] = (5.5, 4.5),
    panel_size: tuple[float, float] | None = None,
    title: str | None = None,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Correlation scatter plot between two quantitative vectors.

    ``x`` and ``y`` can each be an obs column or a gene (read from ``adata.X``
    or ``adata.layers[layer]``), so obs-vs-obs, gene-vs-gene and gene-vs-obs
    comparisons are all supported. One panel is drawn per colour key.

    Parameters
    ----------
    adata
        Annotated data matrix.
    x, y
        Keys of the two vectors to correlate.
    x_source, y_source
        Where to look up ``x`` / ``y``: ``"obs"``, ``"gene"``, or ``"auto"``
        (obs columns first, then genes).
    color
        obs column(s) or gene(s) used to colour points. A sequence draws one
        panel per key; ``None`` (or an empty sequence) draws a single
        uncoloured panel.
    hue
        Alias for ``color``; used only when ``color`` is ``None``.
    layer
        Layer used for gene lookups (x, y and colour).
    palette
        Colormap name for categorical colours (unknown names use ``"tab20"``).
    cmap
        Colormap for numeric colours.
    legend
        Draw a legend (categorical colours) or colorbar (numeric colours).
    method
        Correlation statistic(s) to compute.
    add_regline
        Fit and draw a least-squares line (on finite x/y pairs only).
    annotate
        Write n and the correlation statistics on each panel.
    dropna
        Drop observations where ``x`` or ``y`` is NaN before plotting.
    point_size, alpha
        Scatter marker size and opacity.
    figsize
        Size of a single panel's figure; the width is multiplied by the
        number of panels.
    panel_size
        Per-panel size; overrides ``figsize`` when given.
    title
        Panel title for a single panel, or figure suptitle for several panels.
    save
        Destination passed to ``_savefig``.
    show
        Call ``plt.show()`` before returning.

    Returns
    -------
    fig, axes, stats
        The figure, a 1-D object array of axes (one per panel), and a list
        with one statistics dict per panel.

    Raises
    ------
    ValueError
        If fewer than 3 complete (finite) x/y pairs remain.

    Examples
    --------
    # gene vs gene
    bk.pl.corrplot(adata, x="MKI67", y="TOP2A", x_source="gene", y_source="gene", layer="log1p_cpm")
    # gene vs obs
    bk.pl.corrplot(adata, x="MKI67", y="Proliferation_score", x_source="gene", y_source="obs")
    # obs vs obs (auto)
    bk.pl.corrplot(adata, x="mp_entropy", y="purity")
    """
    set_style()
    if hue is not None and color is None:
        color = hue

    xvals, xlabel = _get_numeric_vector(adata, x, source=x_source, layer=layer)
    yvals, ylabel = _get_numeric_vector(adata, y, source=y_source, layer=layer)

    # Fixed internal column names: keying by the labels would collapse x and y
    # into one column when both resolve to the same name (e.g. a gene and an
    # obs column that share it). AnnData is passed to _plot_one explicitly
    # rather than via df.attrs, which pandas deep-copies on every derived frame.
    df = pd.DataFrame({"x": xvals, "y": yvals}, index=adata.obs_names.astype(str))
    if dropna:
        df = df.dropna(subset=["x", "y"])
    n_valid = int(np.isfinite(df.to_numpy(dtype=float)).all(axis=1).sum())
    if n_valid < 3:
        raise ValueError(f"Not enough valid points for correlation: n={n_valid}")

    # Normalize colour spec -> list of keys (None = uncoloured panel).
    if color is None:
        colors: list[str | None] = [None]
    elif isinstance(color, (str, bytes)) or not hasattr(color, "__iter__"):
        colors = [str(color)]
    else:
        colors = [str(c) for c in color] or [None]

    n_panels = len(colors)
    if panel_size is not None:
        figsize_eff = (panel_size[0] * n_panels, panel_size[1])
    elif n_panels > 1:
        figsize_eff = (figsize[0] * n_panels, figsize[1])
    else:
        figsize_eff = figsize
    fig, axes_grid = plt.subplots(
        1, n_panels, figsize=figsize_eff, constrained_layout=True, squeeze=False
    )
    axes = axes_grid[0]

    stats_list: list[dict[str, float | int]] = []
    for ax, c in zip(axes, colors):
        stats_list.append(
            _plot_one(
                ax,
                df,
                x="x",
                y="y",
                xlabel=xlabel,
                ylabel=ylabel,
                adata=adata,
                color_key=c,
                layer=layer,
                palette=palette,
                cmap=cmap,
                point_size=point_size,
                alpha=alpha,
                add_regline=add_regline,
                annotate=annotate,
                method=method,
                title=(title if (title is not None and n_panels == 1) else None),
                legend=legend,
            )
        )
    if title is not None and n_panels > 1:
        # Per-panel titles show the colour key; keep the user's title visible.
        fig.suptitle(title)

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()
    return fig, np.array(list(axes), dtype=object), stats_list