from __future__ import annotations
from pathlib import Path
from typing import Sequence, Literal
import numpy as np
import pandas as pd
import scipy.sparse as sp
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec
import anndata as ad
from ._style import set_style, _savefig
from ..logging import warn
def _get_expr_matrix(
    adata: ad.AnnData,
    genes: Sequence[str],
    *,
    layer: str | None = "log1p_cpm",
) -> pd.DataFrame:
    """
    Extract a dense (samples x genes) expression table for ``genes``.

    Parameters
    ----------
    adata
        Object exposing ``var_names``, ``obs_names``, ``layers`` and ``X``.
    genes
        Gene identifiers to extract; the output columns follow this order.
    layer
        Key in ``adata.layers`` to read from. When it is ``None`` or not
        present in ``adata.layers``, ``adata.X`` is used instead (silent
        fallback, kept from the original behavior).

    Returns
    -------
    pandas.DataFrame
        Float-valued table indexed by ``adata.obs_names`` (cast to ``str``)
        with one column per requested gene.

    Raises
    ------
    KeyError
        If any requested gene is absent from ``adata.var_names``.
    """
    # Materialize once: `genes` is traversed several times below and may be
    # a one-shot iterable.
    genes = list(genes)
    var_names = adata.var_names

    missing = [g for g in genes if g not in var_names]
    if missing:
        raise KeyError(f"Genes not in adata.var_names (first 10): {missing[:10]}")

    use_layer = layer is not None and layer in adata.layers
    X = adata.layers[layer] if use_layer else adata.X

    gidx = [var_names.get_loc(g) for g in genes]
    sub = X[:, gidx]
    if sp.issparse(sub):
        sub = sub.toarray()
    # Cast both the sparse and the dense path to float so the returned dtype
    # does not depend on how the matrix happens to be stored.
    M = np.asarray(sub, dtype=float)

    return pd.DataFrame(
        M,
        index=adata.obs_names.astype(str),
        columns=[str(g) for g in genes],
    )
def _binary_from_obs(adata: ad.AnnData, cols: Sequence[str]) -> pd.DataFrame:
    """
    Convert mutation columns of ``adata.obs`` into a 0/1 integer table.

    Each value is interpreted as follows:

    * values that ``pandas.to_numeric`` can parse (numbers, numeric strings,
      booleans) are used as numbers;
    * remaining values are matched case-insensitively, after stripping
      whitespace, against the tokens ``"1"``, ``"0"``, ``"true"``, ``"false"``;
    * anything else (missing values, unknown strings) counts as 0.

    Numbers are then limited to the range [0, 1] and truncated to integers,
    so any value >= 1 becomes 1 and any value < 1 becomes 0.

    Parameters
    ----------
    adata
        Object exposing an ``obs`` DataFrame.
    cols
        Column names in ``adata.obs`` to convert; output columns keep this order.

    Returns
    -------
    pandas.DataFrame
        Integer 0/1 table indexed by ``adata.obs.index`` cast to ``str``.

    Raises
    ------
    KeyError
        If any requested column is absent from ``adata.obs``.
    """
    cols = list(cols)
    obs = adata.obs

    miss = [c for c in cols if c not in obs.columns]
    if miss:
        raise KeyError(f"Mutation columns not found in adata.obs (first 10): {miss[:10]}")

    # Float values so the mapped result is always float and NaN-capable.
    tokens = {"1": 1.0, "0": 0.0, "true": 1.0, "false": 0.0}

    data: dict = {}
    for c in cols:
        s = obs[c]

        # Numeric interpretation first. Resolving each value individually
        # (instead of choosing one strategy per column from its NaN share)
        # keeps a numeric column with many missing entries from being
        # string-mapped, which would lose values such as "1.0".
        vals = pd.to_numeric(s, errors="coerce").to_numpy(dtype=float, na_value=np.nan)

        unparsed = np.isnan(vals)
        if unparsed.any():
            mapped = (
                s.astype(str)
                .str.strip()
                .str.lower()
                .map(tokens)
                .to_numpy(dtype=float, na_value=np.nan)
            )
            vals = np.where(unparsed, mapped, vals)

        vals = np.where(np.isnan(vals), 0.0, vals)
        # Clip before the integer cast: same result as truncate-then-clip for
        # finite values, and it keeps +/-inf from overflowing the cast.
        data[c] = np.clip(vals, 0.0, 1.0).astype(int)

    # Built from plain arrays, so values are placed positionally and cannot be
    # lost to index alignment when the obs index is not already string-typed.
    return pd.DataFrame(data, index=obs.index.astype(str))
def _sort_genes_by_freq(mut_samp_x_gene: pd.DataFrame) -> list[str]:
    """
    Return the column names ordered by decreasing mutation frequency.

    Frequency is the column mean of the (samples x genes) 0/1 table. A stable
    sort is used so that genes with equal frequency keep their original
    column order instead of being ordered arbitrarily.
    """
    freq = mut_samp_x_gene.mean(axis=0)
    ranked = freq.sort_values(ascending=False, kind="stable")
    return [str(gene) for gene in ranked.index]
def _sort_samples_mutation_first(mut_samp_x_gene: pd.DataFrame, gene_order: Sequence[str]) -> pd.Index:
    """
    Order samples cBioPortal-style, mutants first.

    Samples mutated in ``gene_order[0]`` come first; within those, samples
    mutated in ``gene_order[1]`` come first, and so on. This is a
    lexicographic sort of each sample's mutation vector, with larger values
    first. Samples with identical vectors keep their input order.

    Parameters
    ----------
    mut_samp_x_gene
        (samples x genes) table of integer-convertible mutation calls.
    gene_order
        Columns to sort by, most significant first.

    Returns
    -------
    pandas.Index
        The sample labels of ``mut_samp_x_gene`` in sorted order.
    """
    genes = list(gene_order)
    if not genes:
        # Nothing to sort by: keep the input order.
        return mut_samp_x_gene.index

    M = mut_samp_x_gene.loc[:, genes].to_numpy(dtype=int)  # (n_samples x n_genes)

    # np.lexsort treats the LAST key as the primary one, so the gene axis is
    # reversed; values are negated so that mutants (1) sort before wild type
    # (0). No separate burden key is needed: rows that tie on every gene
    # column have the same burden by construction.
    order = np.lexsort(-M.T[::-1])
    return mut_samp_x_gene.index[order]
def _apply_group_contiguity(
    mut: pd.DataFrame,
    groups: pd.Series,
    *,
    group_order: Sequence[str] | None = None,
    within_group_sort: Literal["mut_first", "burden", "none"] = "mut_first",
) -> tuple[pd.DataFrame, pd.Series]:
    """
    Reorder samples so that each group forms one contiguous block.

    Parameters
    ----------
    mut
        (samples x genes) mutation table.
    groups
        Group label per sample; must contain every label of ``mut.index``.
        Labels are compared as strings.
    group_order
        Order of the blocks. Levels listed here come first, in the given
        order (repeated entries count once, levels without samples are
        skipped). Groups present in the data but not listed are appended
        afterwards in sorted order, so no sample is dropped. When ``None``,
        blocks follow the sorted group labels.
    within_group_sort
        Ordering inside each block: ``"mut_first"`` (lexicographic, mutants
        first, using the column order of ``mut``), ``"burden"`` (row sum,
        descending, ties keep input order) or ``"none"``.

    Returns
    -------
    tuple[pandas.DataFrame, pandas.Series]
        The reordered mutation table and the matching string group labels,
        both indexed by sample in the new order.

    Raises
    ------
    ValueError
        If ``within_group_sort`` is not one of the accepted values.
    """
    # Validate up front so a bad value is reported even if no block exists.
    if within_group_sort not in ("mut_first", "burden", "none"):
        raise ValueError("within_group_sort must be 'mut_first', 'burden', or 'none'.")

    labels = groups.loc[mut.index].astype(str)
    present = sorted(labels.unique())

    if group_order is None:
        levels = present
    else:
        # De-duplicate while preserving the caller's order.
        listed = list(dict.fromkeys(str(x) for x in group_order))
        listed_set = set(listed)
        levels = listed + [lv for lv in present if lv not in listed_set]

    blocks = []
    block_labels = []
    for level in levels:
        mask = (labels == level).to_numpy()
        if not mask.any():
            continue
        sub = mut.loc[mask]

        if within_group_sort == "mut_first":
            sub = sub.loc[_sort_samples_mutation_first(sub, list(sub.columns))]
        elif within_group_sort == "burden":
            # Stable sort: samples with equal burden keep their input order.
            ranked = sub.sum(axis=1).sort_values(ascending=False, kind="stable")
            sub = sub.loc[ranked.index]

        blocks.append(sub)
        block_labels.append(pd.Series([level] * sub.shape[0], index=sub.index))

    mut2 = pd.concat(blocks, axis=0) if blocks else mut
    g2 = pd.concat(block_labels) if block_labels else labels
    return mut2, g2
# [docs]
def oncoprint(
    adata: ad.AnnData,
    *,
    mut_cols: Sequence[str],
    expr_genes: Sequence[str] | None = None,
    layer: str | None = "log1p_cpm",
    # ordering
    sort_genes: bool = True,
    sort_samples: Literal["mut_first", "burden", "none"] = "mut_first",
    drop_all_wt: bool = True,
    max_samples: int | None = None,
    # group blocks
    groupby: str | None = None,
    group_order: Sequence[str] | None = None,
    group_blocks: bool = True,
    within_group_sort: Literal["mut_first", "burden", "none"] = "mut_first",
    # styling
    show_sample_labels: bool = False,
    mut_color: str = "#222222",
    wt_color: str = "#FFFFFF",
    grid_color: str = "0.85",
    expr_cmap: str = "viridis",
    expr_vmin: float | None = None,
    expr_vmax: float | None = None,
    expr_zscore: bool = False,
    cell_size: float | None = None,  # inches per sample (auto if None)
    row_height: float = 0.35,
    expr_row_height: float = 0.25,
    top_annotation_height: float = 0.35,
    title: str | None = None,
    figsize: tuple[float, float] | None = None,  # overrides the automatic size
    fontsize: float = 10.0,  # base font size
    # output safety
    save_dpi: int | None = None,  # auto-reduced if needed
    max_pixels: int = 60000,  # keep < 2^16 per dimension
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Draw a binary oncoprint (mutated vs. wild type) from ``adata.obs`` columns.

    The figure consists of an optional group colour strip, the mutation
    matrix (genes as rows, samples as columns) with a per-gene frequency bar
    on the right, and an optional expression heatmap below.

    Parameters
    ----------
    adata
        Annotated data object; ``adata.obs`` supplies mutation and group columns.
    mut_cols
        ``adata.obs`` columns holding the mutation calls (one gene per column).
    expr_genes
        Genes to show in the expression heatmap; ``None`` or empty disables it.
    layer
        Layer passed to the expression lookup.
    sort_genes
        Order genes by decreasing mutation frequency.
    sort_samples
        Sample ordering used when no group blocks are drawn:
        ``"mut_first"``, ``"burden"`` or ``"none"``.
    drop_all_wt
        Drop samples with no mutation in any of ``mut_cols``.
    max_samples
        Keep only the first ``max_samples`` samples (applied before sorting);
        a warning is emitted when truncation happens.
    groupby
        ``adata.obs`` column used for the group strip.
    group_order
        Block order forwarded to the grouping helper.
    group_blocks
        Keep each group contiguous and sort inside each block using
        ``within_group_sort``; otherwise ``sort_samples`` is applied globally.
    within_group_sort
        Ordering inside each group block.
    show_sample_labels
        Show sample names on the x axis of the mutation panel.
    mut_color, wt_color, grid_color
        Colours of mutated cells, panel background and grid lines.
    expr_cmap, expr_vmin, expr_vmax
        Colormap and colour limits of the expression heatmap; limits default
        to the data minimum/maximum.
    expr_zscore
        Z-score each expression gene across the plotted samples.
    cell_size
        Width per sample in inches; chosen automatically when ``None``.
    row_height, expr_row_height, top_annotation_height
        Heights (inches) of a mutation row, an expression row and the group strip.
    title
        Title of the mutation panel; defaults to ``"Oncoprint"``.
    figsize
        Explicit ``(width, height)`` in inches, overriding the automatic size.
    fontsize
        Base font size; the other text sizes are derived from it.
    save_dpi
        DPI used for saving; defaults to the ``savefig.dpi`` rc setting.
    max_pixels
        Maximum pixel count per dimension; the save DPI is reduced to stay below it.
    save
        Output path; parent directories are created. ``None`` disables saving.
    show
        Call ``plt.show()`` at the end.

    Returns
    -------
    tuple
        ``(fig, ax)`` where ``ax`` is the axes of the mutation panel.

    Raises
    ------
    ValueError
        If ``drop_all_wt`` removes every sample, or ``sort_samples`` is
        invalid (checked only when it is actually used).
    KeyError
        If ``groupby`` is not a column of ``adata.obs``.
    """
    set_style()
    fontsize = float(fontsize)

    # ------------------------------------------------------------------ data
    mut = _binary_from_obs(adata, mut_cols)  # (samples x genes)
    mut.index = mut.index.astype(str)

    if drop_all_wt:
        mut = mut.loc[mut.sum(axis=1) > 0]
        if mut.shape[0] == 0:
            raise ValueError("After drop_all_wt=True, no samples have any mutations in mut_cols.")

    if sort_genes:
        gene_order = _sort_genes_by_freq(mut)
        mut = mut[gene_order]
    else:
        gene_order = list(mut.columns)

    if max_samples is not None and mut.shape[0] > int(max_samples):
        mut = mut.iloc[: int(max_samples)].copy()
        warn(f"Truncated to first max_samples={max_samples} samples for plotting.")

    groups = None
    if groupby is not None:
        if groupby not in adata.obs.columns:
            raise KeyError(f"groupby='{groupby}' not found in adata.obs")
        groups = adata.obs.loc[mut.index, groupby].astype(str)

    # -------------------------------------------------------- sample ordering
    if groups is not None and group_blocks:
        # Keep groups contiguous and sort inside each block.
        mut, groups = _apply_group_contiguity(
            mut,
            groups,
            group_order=group_order,
            within_group_sort=within_group_sort,
        )
    else:
        if sort_samples == "none":
            pass
        elif sort_samples == "burden":
            # Stable sort: samples with equal burden keep their input order.
            burden = mut.sum(axis=1).sort_values(ascending=False, kind="stable")
            mut = mut.loc[burden.index]
        elif sort_samples == "mut_first":
            mut = mut.loc[_sort_samples_mutation_first(mut, gene_order=gene_order)]
        else:
            raise ValueError("sort_samples must be 'mut_first', 'burden', or 'none'.")
        if groups is not None:
            groups = groups.loc[mut.index]

    # ------------------------------------------------- expression (optional)
    expr_df = None
    if expr_genes is not None and len(expr_genes) > 0:
        expr_df = _get_expr_matrix(adata, expr_genes, layer=layer).loc[mut.index]
        if expr_zscore:
            arr = expr_df.to_numpy(dtype=float)
            mu = np.nanmean(arr, axis=0, keepdims=True)
            sd = np.nanstd(arr, axis=0, keepdims=True)
            sd[sd == 0] = 1.0  # constant genes: avoid division by zero
            expr_df = pd.DataFrame((arr - mu) / sd, index=expr_df.index, columns=expr_df.columns)

    n_samples = mut.shape[0]
    n_genes = mut.shape[1]
    n_expr = 0 if expr_df is None else expr_df.shape[1]

    # ---------------------------------------------------------- figure size
    if cell_size is None:
        # Target a manageable total width even for large cohorts.
        cell_size = float(np.clip(18.0 / max(n_samples, 1), 0.02, 0.22))

    width = max(6.0, cell_size * n_samples + 2.8)
    height = (
        (top_annotation_height if groups is not None else 0.0)
        + row_height * n_genes
        + (0.35 if n_expr > 0 else 0.0)
        + expr_row_height * n_expr
        + 1.2
    )
    if figsize is not None:
        width, height = map(float, figsize)

    # Build, save and show inside one rc_context so the base font size
    # applies to this figure without leaking into the global rcParams.
    with mpl.rc_context({"font.size": fontsize}):
        fig = plt.figure(figsize=(width, height), constrained_layout=False)

        heights = []
        if groups is not None:
            heights.append(top_annotation_height)
        heights.append(row_height * n_genes)
        if n_expr > 0:
            heights.append(expr_row_height * n_expr + 0.35)

        gs = GridSpec(
            nrows=len(heights),
            ncols=2,
            figure=fig,
            width_ratios=[1.0, 0.10],
            height_ratios=heights,
            hspace=0.15,
            wspace=0.08,
        )
        row_i = 0

        # --- top group strip + separators ---
        ax_top = None
        group_boundaries = []
        if groups is not None:
            ax_top = fig.add_subplot(gs[row_i, 0])
            row_i += 1

            cats = [str(c) for c in pd.Categorical(groups).categories]
            if len(cats) <= 20:
                cmap = plt.get_cmap("tab20")
                denom = max(1, len(cats) - 1)
            else:
                # hsv is cyclic: sampling both 0 and 1 would give the first
                # and last group nearly the same colour, so stop short of 1.
                cmap = plt.get_cmap("hsv")
                denom = len(cats)
            palette = {c: mpl.colors.to_hex(cmap(i / denom)) for i, c in enumerate(cats)}

            g_list = groups.astype(str).tolist()
            group_boundaries = [i for i in range(1, len(g_list)) if g_list[i] != g_list[i - 1]]

            # One rectangle per run of equal labels instead of one per sample.
            run_edges = [0, *group_boundaries, len(g_list)]
            for start, stop in zip(run_edges[:-1], run_edges[1:]):
                if stop <= start:
                    continue
                ax_top.add_patch(
                    mpl.patches.Rectangle(
                        (start, 0),
                        stop - start,
                        1,
                        facecolor=palette.get(g_list[start], "#cccccc"),
                        edgecolor="none",
                    )
                )

            ax_top.set_xlim(0, n_samples)
            ax_top.set_ylim(0, 1)
            ax_top.set_xticks([])
            ax_top.set_yticks([])
            for spn in ax_top.spines.values():
                spn.set_visible(False)

            handles = [mpl.patches.Patch(color=palette[c], label=c) for c in cats]
            ax_top.legend(
                handles=handles,
                title=groupby,
                bbox_to_anchor=(1.22, 0.0),
                loc="lower left",
                frameon=False,
                borderaxespad=0.0,
                fontsize=max(7.0, fontsize * 0.85),
                title_fontsize=max(8.0, fontsize * 0.90),
            )

        # --- mutation panel ---
        ax = fig.add_subplot(gs[row_i, 0])
        row_i += 1
        ax.set_facecolor(wt_color)

        A = mut.to_numpy(dtype=int).T  # genes x samples
        for gi in range(n_genes):
            y = n_genes - 1 - gi  # first gene is drawn in the top row
            for x in np.flatnonzero(A[gi, :] == 1):
                # rasterized to keep vector output small on large cohorts
                ax.add_patch(
                    mpl.patches.Rectangle(
                        (x, y), 1, 1, facecolor=mut_color, edgecolor="none", rasterized=True
                    )
                )

        # grid on minor ticks (cell edges)
        ax.set_xlim(0, n_samples)
        ax.set_ylim(0, n_genes)
        ax.set_xticks(np.arange(0, n_samples + 1, 1), minor=True)
        ax.set_yticks(np.arange(0, n_genes + 1, 1), minor=True)
        ax.tick_params(axis="y", labelsize=max(7.0, fontsize * 0.85))
        ax.tick_params(axis="x", labelsize=max(7.0, fontsize * 0.75))
        ax.grid(which="minor", color=grid_color, linewidth=0.6)
        ax.tick_params(which="minor", bottom=False, left=False)

        # separators between group blocks
        for b in group_boundaries:
            ax.axvline(b, color="0.5", lw=1.2, alpha=0.8)

        # labels on major ticks (cell centres)
        ax.set_yticks(np.arange(n_genes) + 0.5)
        ax.set_yticklabels([str(c) for c in mut.columns][::-1])
        ax.set_xticks(np.arange(n_samples) + 0.5)
        if show_sample_labels:
            ax.set_xticklabels(
                mut.index.astype(str).tolist(),
                rotation=90,
                fontsize=max(6.0, fontsize * 0.70),
            )
        else:
            ax.set_xticklabels([])
        ax.tick_params(axis="x", length=0)
        ax.tick_params(axis="y", length=0)
        # Explicit size: relative label sizes are resolved at draw time and
        # would otherwise follow whatever rcParams are active after return.
        ax.set_xlabel("Samples", fontsize=fontsize)
        ax.set_ylabel("Mutations", fontsize=fontsize)
        ax.set_title(title or "Oncoprint", pad=8, fontsize=max(10.0, fontsize * 1.15))

        # --- side bar: per-gene mutation frequency ---
        ax_bar = fig.add_subplot(gs[(0 if ax_top is None else 1), 1])
        freq = mut.mean(axis=0).to_numpy(dtype=float)[::-1]
        ax_bar.barh(np.arange(n_genes) + 0.5, freq, height=0.85)
        ax_bar.set_ylim(0, n_genes)
        ax_bar.set_yticks([])
        ax_bar.xaxis.set_label_position("top")
        ax_bar.xaxis.tick_top()
        ax_bar.set_xlabel("Freq", fontsize=max(8.0, fontsize * 0.85))
        ax_bar.tick_params(axis="x", labelsize=max(7.0, fontsize * 0.75))
        for spn in ["right", "bottom", "left"]:
            ax_bar.spines[spn].set_visible(False)

        # --- expression panel ---
        if expr_df is not None and n_expr > 0:
            ax_expr = fig.add_subplot(gs[row_i, 0])
            X = expr_df.to_numpy(dtype=float).T  # genes x samples
            vmin = float(np.nanmin(X)) if expr_vmin is None else float(expr_vmin)
            vmax = float(np.nanmax(X)) if expr_vmax is None else float(expr_vmax)
            im = ax_expr.imshow(
                X,
                aspect="auto",
                interpolation="nearest",
                cmap=expr_cmap,
                vmin=vmin,
                vmax=vmax,
            )
            # imshow centres cells on integers, hence the half-cell shift
            for b in group_boundaries:
                ax_expr.axvline(b - 0.5, color="0.5", lw=1.2, alpha=0.8)
            ax_expr.set_yticks(np.arange(n_expr))
            ax_expr.set_yticklabels(
                expr_df.columns.astype(str).tolist(), fontsize=max(7.0, fontsize * 0.85)
            )
            ax_expr.set_xticks([])
            ax_expr.tick_params(axis="y", length=0)

            # compact colorbar
            cax = fig.add_axes([0.92, 0.12, 0.015, 0.16])
            cb = fig.colorbar(im, cax=cax)
            cb.set_label(
                ("Z-score" if expr_zscore else "Expression"),
                rotation=90,
                fontsize=max(8.0, fontsize * 0.85),
            )
            cb.ax.tick_params(labelsize=max(7.0, fontsize * 0.75))

        fig.subplots_adjust(left=0.26, right=0.86, top=0.92, bottom=0.10)

        # --- safe save: stay below the per-dimension pixel limit ---
        if save is not None:
            if save_dpi is not None:
                dpi_use = int(save_dpi)
            else:
                # "savefig.dpi" may be the string "figure", meaning
                # "use the figure's own dpi"; int() on it would raise.
                rc_dpi = mpl.rcParams.get("savefig.dpi", "figure")
                dpi_use = int(fig.dpi if rc_dpi == "figure" else rc_dpi)

            px_w = int(width * dpi_use)
            px_h = int(height * dpi_use)
            if px_w > int(max_pixels) or px_h > int(max_pixels):
                dpi_w = int(max(20, max_pixels // max(width, 1e-6)))
                dpi_h = int(max(20, max_pixels // max(height, 1e-6)))
                dpi_use = int(max(20, min(dpi_use, dpi_w, dpi_h)))
                warn(f"Reducing save DPI to {dpi_use} to avoid huge image ({px_w}x{px_h}px).")

            save_path = Path(save)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=dpi_use, bbox_inches="tight")

        if show:
            plt.show()

    return fig, ax