from __future__ import annotations
from pathlib import Path
from typing import Sequence, Literal
import numpy as np
import pandas as pd
import scipy.sparse as sp
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec
try:
from scipy.cluster.hierarchy import linkage, dendrogram
except Exception: # pragma: no cover
linkage = None
dendrogram = None
import anndata as ad
from ._style import set_style, _savefig
# -----------------------------------------------------------------------------
# Dendrogram helpers (aligned to 0.5-centered matrix coordinates)
# -----------------------------------------------------------------------------
def _scipy_leafpos_to_index(v: np.ndarray) -> np.ndarray:
"""
SciPy dendrogram leaf positions come in a 5,15,25,... spacing.
Convert to 0,1,2,... indices.
"""
v = np.asarray(v, dtype=float)
return (v - 5.0) / 10.0
def _plot_row_dendrogram_aligned(
    ax: plt.Axes,
    Z: np.ndarray,
    n_leaves: int,
    *,
    color: str = "0.4",
    lw: float = 1.2,
    invert_y: bool = True,
    mirror_x: bool = False,
) -> None:
    """
    Render the linkage ``Z`` sideways so that leaf k sits at y = k + 0.5.

    The merge height runs along x.  With ``mirror_x=True`` the tree is
    reflected about its largest merge height, so the leaves end up on the
    high-x side and the root at x = 0.  The y-limits are set to
    ``[0, n_leaves]`` (optionally inverted) and ticks/spines are hidden.
    """
    tree = dendrogram(Z, orientation="right", no_plot=True)
    heights = [np.asarray(d, dtype=float) for d in tree["dcoord"]]

    # The reflection axis is only needed (and only computed) when mirroring.
    depth = max([0.0] + [float(np.max(h)) for h in heights]) if mirror_x else 0.0

    for merge_heights, leaf_pos in zip(heights, tree["icoord"]):
        # Shift leaf indices by half a cell so branches hit the row centers.
        centers = _scipy_leafpos_to_index(np.asarray(leaf_pos)) + 0.5
        along_x = depth - merge_heights if mirror_x else merge_heights
        ax.plot(along_x, centers, color=color, lw=lw)

    ax.set_ylim(0.0, float(n_leaves))
    if invert_y:
        ax.invert_yaxis()

    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
def _plot_col_dendrogram_aligned(
    ax: plt.Axes,
    Z: np.ndarray,
    n_leaves: int,
    *,
    color: str = "0.4",
    lw: float = 1.2,
) -> None:
    """
    Render the linkage ``Z`` upright so that leaf k sits at x = k + 0.5.

    The merge height runs along y.  The x-limits are set to
    ``[0, n_leaves]`` and ticks/spines are hidden.
    """
    tree = dendrogram(Z, orientation="top", no_plot=True)

    for leaf_pos, merge_heights in zip(tree["icoord"], tree["dcoord"]):
        # Shift leaf indices by half a cell so branches hit the column centers.
        centers = _scipy_leafpos_to_index(np.asarray(leaf_pos)) + 0.5
        ax.plot(centers, merge_heights, color=color, lw=lw)

    ax.set_xlim(0.0, float(n_leaves))

    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
# -----------------------------------------------------------------------------
# Data extraction
# -----------------------------------------------------------------------------
def _get_feature_matrix(
adata: ad.AnnData,
*,
features: Sequence[str],
layer: str | None,
use_obs: bool,
) -> np.ndarray:
"""
Returns (n_obs x n_features) numeric matrix for either:
- var/features from X/layers (use_obs=False)
- obs columns (use_obs=True)
"""
feats = [str(f) for f in features]
if use_obs:
missing = [f for f in feats if f not in adata.obs.columns]
if missing:
raise KeyError(f"obs features not found in adata.obs (first 10): {missing[:10]}")
M = adata.obs[feats].apply(pd.to_numeric, errors="coerce").to_numpy(float)
return M
missing = [g for g in feats if g not in adata.var_names]
if missing:
raise KeyError(f"Genes not found in adata.var_names (first 10): {missing[:10]}")
X = adata.layers[layer] if (layer is not None and layer in adata.layers) else adata.X
gidx = [adata.var_names.get_loc(g) for g in feats]
M = X[:, gidx].toarray() if sp.issparse(X) else np.asarray(X[:, gidx], dtype=float)
return M
# -----------------------------------------------------------------------------
# Main function
# -----------------------------------------------------------------------------
# (stray "[docs]" Sphinx source-link marker from HTML extraction; not part of the module)
def _dotplot_standard_scale(values: np.ndarray, standard_scale: str | None) -> np.ndarray:
    """Return a rescaled float copy of a (groups x features) matrix for display."""
    scaled = np.array(values, dtype=float, copy=True)
    if standard_scale is None:
        return scaled
    if standard_scale not in ("var", "group", "zscore_var", "zscore_group"):
        raise ValueError("standard_scale must be one of: None, 'var', 'group', 'zscore_var', 'zscore_group'")

    # "*group" modes normalise within each row, "*var" modes within each column.
    axis = 1 if standard_scale.endswith("group") else 0

    if standard_scale.startswith("zscore"):
        mu = np.nanmean(scaled, axis=axis, keepdims=True)
        sd = np.nanstd(scaled, axis=axis, ddof=0, keepdims=True)
        return (scaled - mu) / np.where(sd == 0, 1.0, sd)

    lo = np.nanmin(scaled, axis=axis, keepdims=True)
    hi = np.nanmax(scaled, axis=axis, keepdims=True)
    span = hi - lo
    # Constant rows/columns would divide by zero; they map to 0 instead.
    return np.nan_to_num((scaled - lo) / np.where(span == 0, 1.0, span), nan=0.0)


def _dotplot_shrink_axis_height(axis_obj, factor: float) -> None:
    """Shrink an axes box to ``factor`` of its height, keeping its top edge fixed."""
    if axis_obj is None:
        return
    box = axis_obj.get_position()
    new_h = box.height * float(factor)
    axis_obj.set_position([box.x0, box.y0 + (box.height - new_h), box.width, new_h])


def dotplot(
    adata: ad.AnnData,
    *,
    # ---- what to plot ----
    var_names: Sequence[str] | None = None,
    var_groups: dict[str, Sequence[str]] | None = None,
    # allow plotting obs (e.g., metaprograms stored in adata.obs["MP_*"])
    obs_names: Sequence[str] | None = None,
    obs_groups: dict[str, Sequence[str]] | None = None,
    use_obs: bool = False,
    # ---- grouping ----
    groupby: str | Sequence[str] = "leiden",
    # ---- expression sources ----
    layer: str | None = "log1p_cpm",
    fraction_layer: str | None = "counts",
    # Scanpy-style naming:
    expression_cutoff: float = 0.0,
    mean_only_expressed: bool = False,
    # Back-compat alias (overrides expression_cutoff whenever it is not None):
    expr_threshold: float | None = None,
    # "var"/"group" perform min-max scaling to [0, 1]; "zscore_*" standardise
    standard_scale: Literal["var", "group", "zscore_var", "zscore_group"] | None = None,
    # ---- layout ----
    swap_axes: bool = False,
    row_spacing: float = 1.0,
    dendrogram_top: bool = False,
    dendrogram_rows: bool = False,
    row_dendrogram_position: Literal["right", "left", "outer_left"] = "right",
    cluster_rows: bool | None = None,
    cluster_cols: bool | None = None,
    # ---- coloring ----
    cmap: str = "Reds",
    vmin: float | None = None,
    vmax: float | None = None,
    # ---- size encoding ----
    dot_min: float | None = None,
    dot_max: float | None = None,
    # Scanpy-style name:
    size_exponent: float = 1.5,
    # Back-compat alias: if gamma is provided, it overrides size_exponent
    gamma: float | None = None,
    smallest_dot: float = 0.0,
    largest_dot: float = 200.0,
    # AUTO scaling to figsize/axes box
    scale_dots_to_fig: bool = True,
    dot_scale: float = 1.0,
    # Padding (units are "tick distance" where 1.0 is one cell)
    x_padding: float = 0.8,
    y_padding: float = 1.0,
    # ---- figure ----
    figsize: tuple[float, float] | None = None,
    invert_yaxis: bool = True,
    title: str | None = None,
    size_title: str = "Fraction of samples\nin group (%)",
    colorbar_title: str = "Mean expression\nin group",
    # allow size to represent an obs column directly (e.g. purity)
    size_obs_key: str | None = None,
    size_clip: tuple[float, float] | None = None,
    # ---- IO ----
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Draw a dot plot of per-group feature summaries.

    Colour encodes the per-group mean of each feature; dot size encodes either
    the fraction of observations above ``expression_cutoff`` or the per-group
    mean of ``adata.obs[size_obs_key]``.  Dots are placed on 0.5-centred cell
    coordinates.  Compared with a plain dot plot this function can also

    - plot numeric ``adata.obs`` columns instead of variables (``use_obs=True``),
    - scale the dots to the size of the drawing area (``scale_dots_to_fig``),
    - draw dendrograms above the columns and beside the rows.

    Parameters
    ----------
    adata
        Annotated data object.
    var_names, var_groups
        Variables to plot when ``use_obs=False``.  ``var_groups`` (a mapping
        whose values are concatenated in order) takes precedence.
    obs_names, obs_groups
        ``adata.obs`` columns to plot when ``use_obs=True``; ``obs_groups``
        takes precedence.
    use_obs
        Plot obs columns instead of variables.
    groupby
        One obs column, or a list/tuple of columns whose string values are
        joined with ``" | "``.  Category levels with no observations are not
        drawn.
    layer, fraction_layer
        Layers used for the mean and for the "expressed" fraction.  A layer
        that is ``None`` or absent falls back (fraction -> ``layer`` -> ``X``).
    expression_cutoff, expr_threshold
        Values strictly above the cutoff count as expressed.  ``expr_threshold``
        is an alias that wins when given.
    mean_only_expressed
        Average only over values above the cutoff (ignored with
        ``size_obs_key``); groups with no such value get a mean of 0.
    standard_scale
        ``"var"``/``"group"`` min-max scale per feature/group;
        ``"zscore_var"``/``"zscore_group"`` standardise per feature/group.
    swap_axes
        Put features on rows and groups on columns.
    row_spacing
        Factor applied to the height of the main and row-dendrogram axes.
    dendrogram_top, dendrogram_rows, row_dendrogram_position
        Draw dendrograms (only when the clustered axis has more than two
        entries).
    cluster_rows, cluster_cols
        Reorder rows/columns by average-linkage clustering; default to the
        matching ``dendrogram_*`` flag.
    cmap, vmin, vmax
        Colour map and limits (limits default to the data range).
    dot_min, dot_max
        Size values mapped to the smallest/largest dot.  For fractions they
        must satisfy ``0 <= dot_min <= dot_max <= 1``; with ``size_obs_key``
        only ``dot_min <= dot_max`` is required.
    size_exponent, gamma
        Exponent applied to the normalised size; ``gamma`` is an alias that
        wins when given.
    smallest_dot, largest_dot
        Marker areas (points^2) used when ``scale_dots_to_fig=False``.
    scale_dots_to_fig, dot_scale
        Derive the marker areas from the on-figure cell size, times ``dot_scale``.
    x_padding, y_padding
        Distance from the outermost tick to the axes edge, in cells.
    figsize, invert_yaxis, title, size_title, colorbar_title
        Figure size (auto when ``None``), y direction and text labels.
    size_obs_key, size_clip
        Obs column used for the dot size, optionally clipped to ``size_clip``.
    save, show
        Forwarded to ``_savefig`` / trigger ``plt.show()``.

    Returns
    -------
    (fig, ax)
        The figure and the main dot axes.

    Raises
    ------
    ImportError
        Clustering or dendrograms requested but scipy is unavailable.
    KeyError
        Unknown ``groupby``, ``size_obs_key`` or feature name.
    ValueError
        No features/groups to plot, bad ``standard_scale`` or bad dot range.
    """
    set_style()

    if cluster_rows is None:
        cluster_rows = dendrogram_rows
    if cluster_cols is None:
        cluster_cols = dendrogram_top

    # Clustering alone (without drawing a dendrogram) needs scipy as well.
    wants_scipy = dendrogram_top or dendrogram_rows or cluster_rows or cluster_cols
    if wants_scipy and (linkage is None or dendrogram is None):
        raise ImportError("Dendrograms require scipy (scipy.cluster.hierarchy).")

    # alias handling
    cutoff = float(expression_cutoff if expr_threshold is None else expr_threshold)
    if gamma is not None:
        size_exponent = float(gamma)

    # ---- groupby ----
    if isinstance(groupby, (list, tuple)):
        for g in groupby:
            if g not in adata.obs.columns:
                raise KeyError(f"groupby='{g}' not found in adata.obs")
        combined = adata.obs[list(groupby)].astype(str).agg(" | ".join, axis=1)
        groups_cat = pd.Categorical(combined)
    else:
        if groupby not in adata.obs.columns:
            raise KeyError(f"groupby='{groupby}' not found in adata.obs")
        groups_cat = pd.Categorical(adata.obs[groupby].astype("category"))

    # Levels without observations would yield empty-slice NaN rows (and make
    # the clustering fail), so only observed levels are plotted.
    groups_cat = groups_cat.remove_unused_categories()
    cats = list(groups_cat.categories)
    if not cats:
        raise ValueError("No groups to plot: the groupby column(s) contain no observed categories.")

    # ---- features (var or obs) ----
    if use_obs:
        if obs_groups is not None:
            obs_names = [str(f) for members in obs_groups.values() for f in members]
        elif obs_names is None:
            raise ValueError("With use_obs=True, provide obs_names=... or obs_groups={...}.")
        feature_names = [str(v) for v in obs_names]
    else:
        if var_groups is not None:
            var_names = [str(g) for members in var_groups.values() for g in members]
        elif var_names is None:
            raise ValueError("Provide either var_names=... or var_groups={...}.")
        feature_names = [str(v) for v in var_names]
    if not feature_names:
        raise ValueError("No features to plot: the feature list is empty.")

    # ---- matrices ----
    M_mean = _get_feature_matrix(adata, features=feature_names, layer=layer, use_obs=use_obs)

    size_values = None  # per-observation size source (size_obs_key mode)
    M_frac = None  # per-observation x feature matrix thresholded for fractions
    if size_obs_key is not None:
        if size_obs_key not in adata.obs.columns:
            raise KeyError(f"size_obs_key='{size_obs_key}' not found in adata.obs")
        size_values = pd.to_numeric(adata.obs[size_obs_key], errors="coerce").to_numpy(float)
        if size_clip is not None:
            lo, hi = map(float, size_clip)
            size_values = np.clip(size_values, lo, hi)
        size_mode = "mean"
    else:
        if use_obs:
            # for obs features, "fraction expressed" is the proportion above cutoff
            M_frac = M_mean
        else:
            # fraction_layer if present, otherwise the helper falls back to layer / X
            has_frac_layer = fraction_layer is not None and fraction_layer in adata.layers
            M_frac = _get_feature_matrix(
                adata,
                features=feature_names,
                layer=fraction_layer if has_frac_layer else layer,
                use_obs=False,
            )
        size_mode = "fraction"

    # ---- aggregate (groups x features) ----
    mean_expr = np.zeros((len(cats), len(feature_names)), dtype=float)
    frac_expr = np.zeros((len(cats), len(feature_names)), dtype=float)
    for i, c in enumerate(cats):
        mask = np.asarray(groups_cat == c)
        Xg = M_mean[mask, :]

        if mean_only_expressed and size_obs_key is None:
            # "expressed" is defined on the SAME values that are averaged
            expressed = Xg > cutoff
            num = np.nansum(np.where(expressed, Xg, 0.0), axis=0)
            den = expressed.sum(axis=0).astype(float)
            with np.errstate(invalid="ignore", divide="ignore"):
                mean_expr[i, :] = np.where(den > 0, num / den, 0.0)
        else:
            mean_expr[i, :] = np.nanmean(Xg, axis=0)

        if size_mode == "mean":
            # same size for every feature of the group
            frac_expr[i, :] = np.nanmean(size_values[mask])
        else:
            frac_expr[i, :] = np.nanmean(M_frac[mask, :] > cutoff, axis=0)

    # ---- standard_scale ----
    disp = _dotplot_standard_scale(mean_expr, standard_scale)

    # ---- plotted matrix (swap axes) ----
    if swap_axes:
        plot_vals, plot_frac = disp.T, frac_expr.T
        row_labels, col_labels = list(feature_names), list(cats)
    else:
        plot_vals, plot_frac = disp, frac_expr
        row_labels, col_labels = list(cats), list(feature_names)
    plot_vals = np.asarray(plot_vals, dtype=float)
    plot_frac = np.asarray(plot_frac, dtype=float)

    # ---- clustering (orders) ----
    row_order = np.arange(plot_vals.shape[0])
    col_order = np.arange(plot_vals.shape[1])
    Z_row = None
    Z_col = None
    # linkage() rejects non-finite input; cluster on a NaN-free copy.
    cluster_vals = np.nan_to_num(plot_vals, nan=0.0)
    if cluster_rows and plot_vals.shape[0] > 2:
        Z_row = linkage(cluster_vals, method="average", metric="euclidean")
        row_order = np.array(dendrogram(Z_row, no_plot=True)["leaves"], dtype=int)
    if cluster_cols and plot_vals.shape[1] > 2:
        Z_col = linkage(cluster_vals.T, method="average", metric="euclidean")
        col_order = np.array(dendrogram(Z_col, no_plot=True)["leaves"], dtype=int)

    plot_vals = plot_vals[row_order, :][:, col_order]
    plot_frac = plot_frac[row_order, :][:, col_order]
    row_labels = [row_labels[i] for i in row_order]
    col_labels = [col_labels[i] for i in col_order]
    n_rows, n_cols = plot_vals.shape

    # ---- autosize figure ----
    if figsize is None:
        figsize = (max(5.8, 0.50 * n_cols + 3.6), max(4.2, 0.42 * n_rows + 2.6))

    # ---- color scaling ----
    vmin_eff = vmin if vmin is not None else float(np.nanmin(plot_vals))
    vmax_eff = vmax if vmax is not None else float(np.nanmax(plot_vals))
    norm = mpl.colors.Normalize(vmin=vmin_eff, vmax=vmax_eff)
    cmap_obj = mpl.colormaps.get_cmap(cmap) if hasattr(mpl, "colormaps") else mpl.cm.get_cmap(cmap)

    has_row_dendro = bool(dendrogram_rows and (Z_row is not None) and (n_rows > 2))
    has_top_dendro = bool(dendrogram_top and (Z_col is not None) and (n_cols > 2))

    # ---- dot_min/dot_max ----
    finite_sizes = plot_frac[np.isfinite(plot_frac)]
    if finite_sizes.size == 0:
        finite_sizes = np.array([0.0], dtype=float)

    if dot_max is None:
        # round up to the next 0.1 for nicer legends
        dot_max_eff = float(np.ceil(finite_sizes.max() * 10.0) / 10.0)
    else:
        dot_max_eff = float(dot_max)
    if dot_min is not None:
        dot_min_eff = float(dot_min)
    elif size_mode == "fraction":
        dot_min_eff = 0.0
    else:
        # obs-derived sizes may be negative; never start above the data
        dot_min_eff = min(0.0, float(np.floor(finite_sizes.min() * 10.0) / 10.0))

    if size_mode == "fraction":
        if not (0.0 <= dot_min_eff <= dot_max_eff <= 1.0):
            raise ValueError("dot_min/dot_max must satisfy 0<=dot_min<=dot_max<=1")
    elif dot_min_eff > dot_max_eff:
        # obs-derived sizes are not fractions, so only the ordering is checked
        raise ValueError("dot_min must be <= dot_max")

    size_span = dot_max_eff - dot_min_eff

    def _normalised_size(raw: np.ndarray) -> np.ndarray:
        """Clip to [dot_min, dot_max], rescale to [0, 1] and apply the exponent."""
        if size_span <= 0:
            return np.zeros_like(raw, dtype=float)
        unit = (np.clip(raw, dot_min_eff, dot_max_eff) - dot_min_eff) / size_span
        return np.clip(unit, 0.0, 1.0) ** float(size_exponent)

    u = _normalised_size(plot_frac)

    # ---- layout ----
    outer_left = 0.20 if (has_row_dendro and row_dendrogram_position == "outer_left") else 0.001
    inner_left = 0.18 if (has_row_dendro and row_dendrogram_position == "left") else 0.001
    inner_right = 0.18 if (has_row_dendro and row_dendrogram_position == "right") else 0.001
    legends_w = 0.58

    fig = plt.figure(figsize=figsize, constrained_layout=False)
    gs = GridSpec(
        nrows=2,
        ncols=5,
        figure=fig,
        height_ratios=[0.22, 1.0] if has_top_dendro else [0.001, 1.0],
        width_ratios=[outer_left, inner_left, 1.0, inner_right, legends_w],
        hspace=0.05,
        wspace=0.14,
    )

    ax_top = fig.add_subplot(gs[0, 2]) if has_top_dendro else None
    ax = fig.add_subplot(gs[1, 2])
    ax_outer_left = fig.add_subplot(gs[1, 0]) if (has_row_dendro and row_dendrogram_position == "outer_left") else None
    ax_inner_left = fig.add_subplot(gs[1, 1]) if (has_row_dendro and row_dendrogram_position == "left") else None
    ax_inner_right = fig.add_subplot(gs[1, 3]) if (has_row_dendro and row_dendrogram_position == "right") else None

    gs_leg = gs[:, 4].subgridspec(nrows=2, ncols=1, height_ratios=[0.58, 0.42], hspace=0.25)
    ax_leg = fig.add_subplot(gs_leg[0, 0])
    ax_cbar = fig.add_subplot(gs_leg[1, 0])

    # ---- margins (make room for long y-labels) ----
    # Applied BEFORE measuring the axes so dot sizes match the final layout.
    if swap_axes:
        left = 0.42 if has_row_dendro else 0.30
        bottom = 0.18
    else:
        left = 0.32 if has_row_dendro else 0.20
        bottom = 0.14
    fig.subplots_adjust(left=left, right=0.98, top=0.92, bottom=bottom)

    # ---- optional row spacing compression (kept for compatibility) ----
    if float(row_spacing) != 1.0:
        for squeezed in (ax, ax_outer_left, ax_inner_left, ax_inner_right):
            _dotplot_shrink_axis_height(squeezed, float(row_spacing))

    # ---- dot sizes ----
    xpad = float(x_padding) - 0.5
    ypad = float(y_padding) - 0.5
    if scale_dots_to_fig:
        box = ax.get_position()
        ax_w_pts = figsize[0] * 72.0 * box.width
        ax_h_pts = figsize[1] * 72.0 * box.height
        # One cell is one data unit; the axes also spans the padding.
        x_span = n_cols + 2.0 * xpad
        y_span = n_rows + 2.0 * ypad
        if x_span <= 0:
            x_span = float(max(1, n_cols))
        if y_span <= 0:
            y_span = float(max(1, n_rows))
        cell = min(ax_w_pts / x_span, ax_h_pts / y_span)
        # scatter sizes are areas in points^2, hence the squared diameters
        largest_dot_eff = ((0.85 * cell) ** 2) * float(dot_scale)
        smallest_dot_eff = ((0.22 * cell) ** 2) * float(dot_scale)
    else:
        smallest_dot_eff = float(smallest_dot)
        largest_dot_eff = float(largest_dot)
    sizes = smallest_dot_eff + (largest_dot_eff - smallest_dot_eff) * u

    # ---- main dots (0.5-centred coordinates) ----
    yy, xx = np.indices((n_rows, n_cols))
    ax.scatter(
        xx.ravel().astype(float) + 0.5,
        yy.ravel().astype(float) + 0.5,
        s=sizes.ravel().astype(float),
        c=cmap_obj(norm(plot_vals.ravel().astype(float))),
        edgecolors="0.2",
        linewidths=0.35,
    )

    # ticks at centers
    ax.set_xticks(np.arange(n_cols, dtype=float) + 0.5)
    ax.set_yticks(np.arange(n_rows, dtype=float) + 0.5)
    ax.set_xticklabels([str(v) for v in col_labels], rotation=90, ha="center")
    ax.set_yticklabels([str(v) for v in row_labels])

    # axis limits with padding; y is inverted unless the caller opts out
    ax.set_xlim(-xpad, float(n_cols) + xpad)
    if invert_yaxis:
        ax.set_ylim(float(n_rows) + ypad, -ypad)
    else:
        ax.set_ylim(-ypad, float(n_rows) + ypad)

    # extra y tick label padding so labels don't crash into the dendrogram area
    ax.tick_params(axis="y", pad=10 if has_row_dendro else 6)
    # keep bottom x labels (more typical for bulk figures)
    ax.tick_params(axis="x", labeltop=False, labelbottom=True)

    ax.set_xlabel("")
    ax.set_ylabel("")
    if title is not None:
        # With a top dendrogram the title goes above it instead of over it.
        (ax_top if ax_top is not None else ax).set_title(title)

    # ---- dendrograms ----
    # The helpers use limits [0, n]; the main axes is padded, so the limits
    # are copied over afterwards to keep branches on the dot centres.
    if ax_top is not None:
        _plot_col_dendrogram_aligned(ax_top, Z_col, n_leaves=n_cols)
        ax_top.set_xlim(ax.get_xlim())

    # mirror_x=True flips the tree so its leaves face the matrix
    for dendro_ax, mirrored in (
        (ax_outer_left, True),
        (ax_inner_left, True),
        (ax_inner_right, False),
    ):
        if dendro_ax is None:
            continue
        _plot_row_dendrogram_aligned(
            dendro_ax, Z_row, n_leaves=n_rows, invert_y=invert_yaxis, mirror_x=mirrored
        )
        dendro_ax.set_ylim(ax.get_ylim())

    # ---- legends ----
    ax_leg.axis("off")
    ax_leg.text(0.0, 1.00, size_title, ha="left", va="top", transform=ax_leg.transAxes)

    # Five reference dots evenly spaced over (dot_min, dot_max]; with the
    # defaults these are 0.2, 0.4, ..., 1.0.
    ref = np.linspace(dot_min_eff, dot_max_eff, 6)[1:]
    ref_s = smallest_dot_eff + (largest_dot_eff - smallest_dot_eff) * _normalised_size(ref)

    x0, y0, dx = 0.12, 0.55, 0.16
    ref_x = x0 + dx * np.arange(ref.size)
    ax_leg.scatter(
        ref_x,
        np.full(ref.size, y0),
        s=ref_s.astype(float),
        color="0.55",
        edgecolors="0.2",
        linewidths=0.3,
        transform=ax_leg.transAxes,
        clip_on=False,
    )

    # Labels sit under their dots and reflect the actual size range.
    if size_obs_key is None:
        ref_text = [f"{100.0 * r:.3g}" for r in ref]
    else:
        ref_text = ["low"] + [""] * (ref.size - 2) + ["high"]
    for lx, label in zip(ref_x, ref_text):
        if label:
            ax_leg.text(lx, 0.25, label, ha="center", va="center", transform=ax_leg.transAxes)

    sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap_obj)
    cb = fig.colorbar(sm, cax=ax_cbar, orientation="horizontal")
    cb.set_label(colorbar_title)

    # shrink the colorbar to a thin strip at the top of its slot
    _dotplot_shrink_axis_height(ax_cbar, 0.2)

    if save is not None:
        _savefig(fig, save)
    if show:
        plt.show()

    return fig, ax