from __future__ import annotations
from pathlib import Path
from typing import Literal, Sequence, Mapping, Any, Optional, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import anndata as ad
import scipy.sparse as sp
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests
from ._style import set_style, _savefig
# FIX: update this import to your new TL module location:
from ..tl.associations import categorical_association
# If instead you kept tl/categorical_association.py, use:
# from ..tl.categorical_association import categorical_association
[docs]
def association_heatmap(
    df: pd.DataFrame,
    *,
    index: str,
    columns: str,
    values: str,
    agg: Literal["mean", "max", "min"] = "mean",
    cmap: str = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Render a long-form association table as a heatmap.

    The table is pivoted into an ``index`` x ``columns`` matrix of ``values``
    and drawn with ``imshow``; a colorbar labelled with ``values`` is added.

    Parameters
    ----------
    df
        Long-form (tidy) dataframe holding the association results.
    index, columns, values
        Names of the dataframe columns that become, respectively, the heatmap
        rows, the heatmap columns and the cell values.
    agg
        Aggregation applied when several rows fall into the same
        (index, columns) cell.
    cmap, vmin, vmax
        Colormap and colour-scale limits forwarded to ``imshow``.
    figsize
        Figure size in inches. When ``None`` it is derived from the pivoted
        matrix shape, with a minimum of 6 x 4 inches.
    title
        Plot title. Falls back to ``"Heatmap: <values>"`` when empty or None.
    save
        If not ``None``, the figure is handed to ``_savefig`` with this path.
    show
        If true, ``plt.show()`` is called before returning.

    Returns
    -------
    fig, ax
        The Matplotlib figure and the axes containing the heatmap.
    """
    set_style()

    pivoted = pd.pivot_table(
        df,
        index=index,
        columns=columns,
        values=values,
        aggfunc=agg,
    )
    grid = pivoted.to_numpy(dtype=float)
    n_rows, n_cols = pivoted.shape

    if figsize is None:
        # Grow a quarter inch per cell, but never shrink below 6 x 4 inches.
        width = max(6, 0.25 * n_cols + 3)
        height = max(4, 0.25 * n_rows + 2)
        figsize = (width, height)

    fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
    image = ax.imshow(grid, aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax)

    # One tick per pivot column/row, labelled with the stringified key.
    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(list(map(str, pivoted.columns)), rotation=90)
    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(list(map(str, pivoted.index)))

    ax.set_title(title if title else f"Heatmap: {values}")

    colorbar = fig.colorbar(image, ax=ax, pad=0.01)
    colorbar.set_label(values)

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()
    return fig, ax
[docs]
def boxplot_with_stats(
    adata: ad.AnnData,
    *,
    y: str,
    groupby: str,
    figsize: tuple[float, float] = (7, 3.5),
    kind: Literal["box", "violin"] = "violin",
    show_points: bool = True,
    point_size: float = 2.0,
    point_alpha: float = 0.3,
    title: str | None = None,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Box/violin plot of a numeric obs column across categorical groups, with a
    simple global p-value annotation in the title.

    - If 2 groups: Mann–Whitney U (two-sided)
    - If >2 groups: Kruskal–Wallis

    Parameters
    ----------
    adata
        Annotated data object; only ``adata.obs`` is read.
    y
        Name of the obs column to plot. Values are coerced to numeric;
        entries that cannot be converted are dropped.
    groupby
        Name of the obs column defining the groups. Labels are converted to
        strings; observations with a missing label are dropped.
    figsize
        Figure size in inches.
    kind
        ``"box"`` draws box plots (without fliers); any other value draws
        violins with the median marked.
    show_points
        Overlay the individual observations with horizontal jitter.
    point_size, point_alpha
        Marker size and transparency of the overlaid points.
    title
        Plot title. Falls back to ``"<y> by <groupby>"`` when empty or None.
    save
        If not ``None``, the figure is handed to ``_savefig`` with this path.
    show
        If true, ``plt.show()`` is called before returning.

    Returns
    -------
    fig, ax

    Raises
    ------
    KeyError
        If ``y`` or ``groupby`` is not a column of ``adata.obs``.
    ValueError
        If no observation has both a numeric ``y`` and a group label.
    """
    from scipy.stats import kruskal, mannwhitneyu

    set_style()
    if y not in adata.obs.columns:
        raise KeyError(f"'{y}' not in adata.obs")
    if groupby not in adata.obs.columns:
        raise KeyError(f"'{groupby}' not in adata.obs")

    s = pd.to_numeric(adata.obs[y], errors="coerce")
    raw_g = adata.obs[groupby]
    # astype(str) would turn missing labels into the literal string "nan" and
    # hide them from dropna(); mask them back to NaN so they are excluded.
    g = raw_g.astype(str).where(raw_g.notna())
    df = pd.DataFrame({"y": s, "g": g}).dropna()

    if df.empty:
        # Fail early with a clear message, before a figure is created, rather
        # than letting Matplotlib choke on an empty dataset list.
        raise ValueError(
            f"No valid observations to plot for '{y}' grouped by '{groupby}'"
        )

    # Categories are taken from the filtered frame, so every group below
    # contains at least one observation.
    cats = list(pd.Categorical(df["g"]).categories)
    groups = [df.loc[df["g"] == c, "y"].to_numpy(dtype=float) for c in cats]

    pval = np.nan
    if len(groups) >= 2:
        try:
            if len(groups) == 2:
                pval = mannwhitneyu(
                    groups[0], groups[1], alternative="two-sided"
                ).pvalue
            else:
                pval = kruskal(*groups).pvalue
        except ValueError:
            # The test can reject degenerate input (for instance, some SciPy
            # versions raise when every value is identical). The p-value is
            # only an annotation, so skip it instead of aborting the plot.
            pval = np.nan

    fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
    positions = np.arange(len(cats))

    if kind == "box":
        ax.boxplot(groups, positions=positions, showfliers=False)
    else:
        parts = ax.violinplot(
            groups,
            positions=positions,
            showmeans=False,
            showextrema=False,
            showmedians=True,
        )
        for pc in parts.get("bodies", []):
            pc.set_alpha(0.8)

    if show_points:
        # Fixed seed so the jitter is identical from call to call.
        rng = np.random.RandomState(0)
        for i, v in enumerate(groups):
            xj = rng.normal(loc=i, scale=0.06, size=v.size)
            ax.scatter(xj, v, s=point_size, alpha=point_alpha, edgecolors="none")

    ax.set_xticks(positions)
    ax.set_xticklabels(cats, rotation=90)
    ax.set_ylabel(y)

    ttl = title or f"{y} by {groupby}"
    if np.isfinite(pval):
        ttl += f" (p={pval:.2e})"
    ax.set_title(ttl)

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()
    return fig, ax
def categorical_confusion(
    adata: ad.AnnData,
    *,
    key1: str,
    key2: str,
    normalize: Literal["none", "row", "col", "all"] = "row",
    cmap: str = "Blues",
    figsize: tuple[float, float] | None = None,
    title: str | None = None,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Confusion-style heatmap for two categorical obs columns, with association
    metrics appended to the title.

    The contingency table and the metrics come from
    ``categorical_association(adata, key1=key1, key2=key2)``. Each of
    Cramér's V, ARI and NMI is added to the title only when the result
    provides a finite value for it.

    Parameters
    ----------
    adata
        Annotated data object passed through to ``categorical_association``.
    key1, key2
        Names of the two categorical obs columns; ``key1`` vs ``key2`` is the
        default title.
    normalize
        ``"none"`` plots the table as returned; ``"row"``, ``"col"`` and
        ``"all"`` divide by the row sums, the column sums or the grand total.
    cmap
        Matplotlib colormap.
    figsize
        Figure size in inches. When ``None`` it is derived from the table
        shape, with a minimum of 6 x 4 inches.
    title
        Leading part of the plot title; metrics are appended after it.
    save
        If not ``None``, the figure is handed to ``_savefig`` with this path.
    show
        If true, ``plt.show()`` is called before returning.

    Returns
    -------
    fig, ax

    Raises
    ------
    ValueError
        If ``normalize`` is not one of the supported modes.
    """
    # NOTE(review): set_style() is not applied here (the call was disabled in
    # the original code); confirm whether that is intentional.
    from ..tl.associations import categorical_association

    if normalize not in ("none", "row", "col", "all"):
        # An unknown mode used to fall through un-normalised while still being
        # clipped to [0, 1] and labelled "fraction"; reject it instead.
        raise ValueError(
            "normalize must be one of 'none', 'row', 'col', 'all'; "
            f"got {normalize!r}"
        )

    res = categorical_association(adata, key1=key1, key2=key2)
    tab: pd.DataFrame = res["table"]
    mat = tab.to_numpy(dtype=float)

    if normalize == "row":
        totals = mat.sum(axis=1, keepdims=True)
    elif normalize == "col":
        totals = mat.sum(axis=0, keepdims=True)
    elif normalize == "all":
        totals = mat.sum()
    else:
        totals = None
    if totals is not None:
        # Clamp the denominator at 1 so an all-zero row/column/table does not
        # divide by zero (presumably the table holds counts — TODO confirm).
        mat = mat / np.maximum(totals, 1.0)

    if figsize is None:
        figsize = (
            max(6, 0.25 * tab.shape[1] + 3),
            max(4, 0.25 * tab.shape[0] + 2),
        )

    normalized = normalize != "none"
    fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)
    im = ax.imshow(
        mat,
        aspect="auto",
        cmap=cmap,
        vmin=0,
        vmax=1 if normalized else None,
    )
    ax.set_xticks(np.arange(tab.shape[1]))
    ax.set_xticklabels(tab.columns.tolist(), rotation=90)
    ax.set_yticks(np.arange(tab.shape[0]))
    ax.set_yticklabels(tab.index.tolist())

    ttl = title or f"{key1} vs {key2}"
    # Treat all three metrics alike: show one only if it is present and
    # finite, so a missing value neither prints "nan" nor breaks formatting.
    for label, metric_key in (
        ("Cramér’s V", "cramers_v"),
        ("ARI", "ari"),
        ("NMI", "nmi"),
    ):
        value = res.get(metric_key)
        if value is not None and np.isfinite(value):
            ttl += f" | {label}={value:.3f}"
    ax.set_title(ttl)

    cbar = fig.colorbar(im, ax=ax, pad=0.01)
    cbar.set_label("fraction" if normalized else "count")

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()
    return fig, ax