from __future__ import annotations
from pathlib import Path
from typing import Literal, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import anndata as ad
from ._style import set_style, _savefig
# Column names accepted as ``metric`` by ari_resolution_heatmap.
Metric = Literal["ARI", "NMI", "cramers_v"]
def _contrast_text_color(rgba) -> str:
    """Return "black" or "white", whichever reads better on a cell filled with *rgba*."""
    r, g, b = (float(c) for c in rgba[:3])
    # Rec. 709 luma weights on the (gamma-encoded) channels: a cheap but
    # adequate brightness estimate for choosing a label colour.
    luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
    return "black" if luminance > 0.5 else "white"


def _annotate_heatmap_row(ax, image, values, y, fmt) -> None:
    """Label each finite entry of the (1, n) array *values* at height *y*.

    The text colour is chosen per cell from the colour *image* actually paints
    there, so labels stay legible at both ends of the colormap.
    """
    # Same norm + cmap the image uses, so this matches what is drawn; shape (1, n, 4).
    rgba = image.to_rgba(values)
    for j, raw in enumerate(values[0]):
        v = float(raw)
        if not np.isfinite(v):
            continue
        ax.text(
            j + 0.5,
            y,
            fmt(v),
            ha="center",
            va="center",
            fontsize=9,
            color=_contrast_text_color(rgba[0, j]),
        )


def ari_resolution_heatmap(
    adata: ad.AnnData,
    *,
    df: pd.DataFrame | None = None,
    store_key: str = "leiden_scan",
    metric: Metric = "ARI",
    show_n_clusters: bool = True,
    cmap: str = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Heatmap-like summary of clustering quality vs Leiden resolution.

    Draws one row of cells for ``metric``, coloured with ``cmap`` and annotated
    with the value to two decimals. When available, it adds a second greyscale
    row with the number of clusters. There is one column per resolution, sorted
    ascending.

    Expects `df` with at least:
    - 'resolution'
    - metric column: 'ARI' or 'NMI' or 'cramers_v'
    Optional:
    - 'n_clusters' (for a second row annotation)

    Parameters
    ----------
    adata
        AnnData whose ``.uns[store_key]`` is used when ``df`` is not given.
    df
        Scan results table. Non-DataFrame inputs are converted with
        ``pd.DataFrame(df)``.
    store_key
        Key in ``adata.uns`` to read when ``df`` is None.
    metric
        Column to colour-code in the first row.
    show_n_clusters
        Add the ``n_clusters`` row if that column is present.
    cmap, vmin, vmax
        Colour mapping for the metric row. The colorbar refers to this row only.
    figsize
        Figure size. By default it grows with the number of resolutions.
    title
        Axes title. Defaults to ``"<metric> vs Leiden resolution"``.
    save
        If given, the figure is written with ``_savefig``.
    show
        Call ``plt.show()`` before returning.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    KeyError
        If ``df`` is None and ``adata.uns[store_key]`` is missing, or if the
        table lacks the 'resolution' or ``metric`` column.
    ValueError
        If the table has no rows.

    Example:
        df = bk.tl.leiden_resolution_scan(...)
        bk.pl.ari_resolution_heatmap(adata, df=df, metric="ARI")
    """
    set_style()

    if df is None:
        if store_key not in adata.uns:
            raise KeyError(
                f"adata.uns['{store_key}'] not found. Run bk.tl.leiden_resolution_scan(...) first "
                f"or pass df=..."
            )
        df = adata.uns[store_key]
    if not isinstance(df, pd.DataFrame):
        df = pd.DataFrame(df)

    required = {"resolution", metric}
    missing = required - set(df.columns)
    if missing:
        raise KeyError(f"Missing columns in df: {sorted(missing)}")
    if df.shape[0] == 0:
        raise ValueError("No rows to plot: the resolution scan table is empty.")

    # sort_values returns a new frame, so the caller's df is never modified.
    d = df.sort_values("resolution").reset_index(drop=True)

    # columns = resolutions (as strings for tick labels)
    cols = [f"{r:g}" for r in d["resolution"].astype(float).to_numpy()]
    n = len(cols)

    metric_vals = d[metric].to_numpy(dtype=float)[None, :]  # (1, n)
    have_clusters = bool(show_n_clusters and ("n_clusters" in d.columns))
    cluster_vals = d["n_clusters"].to_numpy(dtype=float)[None, :] if have_clusters else None

    rows = [metric] + (["n_clusters"] if have_clusters else [])
    n_rows = len(rows)

    # default size
    if figsize is None:
        w = max(6.0, 0.55 * n + 1.8)
        h = 2.3 if n_rows == 1 else 3.2
        figsize = (w, h)

    fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)

    # Both rows share one axes through explicit extents:
    #   x in [0, n], y in [0, n_rows], each cell 1x1 centred at (j+0.5, i+0.5).
    im_metric = ax.imshow(
        metric_vals,
        aspect="auto",
        interpolation="nearest",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        extent=(0, n, 0, 1),  # row 0
        origin="lower",
    )
    im_clusters = None
    if have_clusters:
        im_clusters = ax.imshow(
            cluster_vals,
            aspect="auto",
            interpolation="nearest",
            cmap="Greys",
            extent=(0, n, 1, 2),  # row 1
            origin="lower",
        )

    # ---- ticks/labels ----
    ax.set_xlim(0, n)
    ax.set_ylim(0, n_rows)
    ax.set_xticks(np.arange(n) + 0.5)
    ax.set_xticklabels(cols, rotation=45, ha="right")
    ax.set_xlabel("Leiden resolution")
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(rows)

    # White gridlines on the cell borders (minor ticks), without tick marks.
    ax.set_xticks(np.arange(n + 1), minor=True)
    ax.set_yticks(np.arange(n_rows + 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=1.5)
    ax.tick_params(which="minor", bottom=False, left=False)

    # ---- cell annotations ----
    # The label colour follows each cell's fill. A fixed colour is unreadable on
    # bright viridis cells, and invisible on the largest n_clusters cell, which
    # "Greys" paints black.
    _annotate_heatmap_row(ax, im_metric, metric_vals, 0.5, lambda v: f"{v:.2f}")
    if im_clusters is not None:
        _annotate_heatmap_row(ax, im_clusters, cluster_vals, 1.5, lambda v: f"{int(round(v))}")

    ax.set_title(title or f"{metric} vs Leiden resolution")

    # colorbar for metric only
    cbar = fig.colorbar(im_metric, ax=ax, fraction=0.035, pad=0.02)
    cbar.set_label(metric)

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()
    return fig, ax
def _finite_metric(value) -> float | None:
    """Return *value* as a float if it is a finite number, else None.

    Guards title formatting against metrics that are absent, NaN, None or
    otherwise non-numeric. Calling ``np.isfinite(None)`` directly would raise
    TypeError.
    """
    try:
        out = float(value)
    except (TypeError, ValueError):
        return None
    return out if np.isfinite(out) else None


def categorical_confusion(
    adata: ad.AnnData,
    *,
    key1: str,
    key2: str,
    normalize: Literal["none", "row", "col", "all"] = "row",
    cmap: str = "Blues",
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    dropna: bool = True,
    min_count: int = 1,
    show: bool = True,
    save: str | Path | None = None,
):
    """
    Confusion-style heatmap for two categorical obs columns + association metrics.

    Uses bk.tl.categorical_confusion() for computation (counts + normalized matrix + metrics).

    Parameters
    ----------
    adata
        AnnData passed through to ``bullkpy.tl.categorical_confusion``.
    key1, key2
        The two obs columns to compare. Axis tick labels come from the
        returned table's index (y) and columns (x).
    normalize
        Forwarded to the computation. Any value other than ``"none"`` fixes
        the colour scale to [0, 1] and labels the colorbar "fraction".
        ``"none"`` autoscales and labels it "count".
    cmap
        Colormap for the matrix.
    figsize
        Figure size. By default it grows with the table dimensions.
    title
        Base title, defaulting to ``"<key1> vs <key2>"``. Cramér's V, ARI and
        NMI from the result's ``"metrics"`` are appended when they are finite
        numbers.
    dropna, min_count
        Forwarded unchanged to ``bullkpy.tl.categorical_confusion``.
    show
        Call ``plt.show()`` before returning.
    save
        If given, the figure is written with ``_savefig``.

    Returns
    -------
    (fig, ax, res)
        ``res`` is the dict returned by ``bullkpy.tl.categorical_confusion``.
    """
    set_style()

    # local import to avoid circular imports at module import time
    from bullkpy.tl import categorical_confusion as _categorical_confusion_tl

    res = _categorical_confusion_tl(
        adata,
        key1=key1,
        key2=key2,
        normalize=normalize,
        dropna=dropna,
        min_count=min_count,
    )
    tab: pd.DataFrame = res["table"]
    mat: np.ndarray = res["matrix"]
    metrics: dict = res.get("metrics", {}) or {}

    n_rows, n_cols = tab.shape

    # auto figsize
    if figsize is None:
        figsize = (max(6, 0.28 * n_cols + 3.0), max(4, 0.28 * n_rows + 2.0))

    fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)

    # Normalized matrices share a fixed [0, 1] scale; raw counts autoscale.
    fractional = normalize != "none"
    im = ax.imshow(mat, aspect="auto", cmap=cmap, vmin=0.0, vmax=1.0 if fractional else None)

    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(tab.columns.tolist(), rotation=90)
    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(tab.index.tolist())

    # Title with whichever association metrics are finite numbers.
    title_parts = [title or f"{key1} vs {key2}"]
    for metric_key, label in (("cramers_v", "Cramér’s V"), ("ari", "ARI"), ("nmi", "NMI")):
        value = _finite_metric(metrics.get(metric_key))
        if value is not None:
            title_parts.append(f"{label}={value:.3f}")
    ax.set_title(" | ".join(title_parts))

    cbar = fig.colorbar(im, ax=ax, pad=0.01)
    cbar.set_label("fraction" if fractional else "count")

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()
    return fig, ax, res