from __future__ import annotations
from pathlib import Path
from typing import Mapping, Sequence
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from ._style import set_style, _savefig
[docs]
def gsea_bubbleplot(
df_gsea: pd.DataFrame,
*,
pathways: Mapping[str, Sequence[str]] | Sequence[str],
comparison_col: str = "comparison",
term_col: str = "Term",
nes_col: str = "NES",
fdr_col: str = "FDR q-val",
# ordering
comparison_order: Sequence[str] | None = None,
drop_empty_comparisons: bool = True,
# size mapping
size_from: str = "fdr", # "fdr" or "pval" (if you want)
min_q: float = 1e-300,
size_min: float = 10.0,
size_max: float = 350.0,
fdr_floor: float = 1e-50,
size_clip_quantile: float | None = 0.99, # None disables
# color mapping
cmap: str = "RdBu_r",
center: float = 0.0,
vmin: float | None = None,
vmax: float | None = None,
# cosmetics
figsize: tuple[float, float] | None = None,
row_spacing: float = 1.0,
col_spacing: float = 1.0,
row_height: float = 0.32, # inches per row (Scanpy-ish)
col_width: float = 0.32, # inches per column
dot_edgecolor: str = "0.15",
dot_linewidth: float = 0.35,
show_grid: bool = False,
group_label_rotation: float = 90,
xtick_rotation: float = 90,
title: str | None = None,
# output
save: str | Path | None = None,
show: bool = True,
):
"""
Bubble plot matrix for GSEA results.
Rows: comparisons (contrasts)
Cols: pathways (terms)
Color: NES (diverging, centered at `center`)
Size: -log10(FDR q-val) with floor & optional clipping
`pathways` can be:
- dict: {"Immune": [term1, term2], "Metabolism": [term3]}
- list: [term1, term2, ...]
"""
set_style()
if not isinstance(df_gsea, pd.DataFrame):
raise TypeError("df_gsea must be a pandas DataFrame.")
for c in (comparison_col, term_col, nes_col, fdr_col):
if c not in df_gsea.columns:
raise KeyError(
f"'{c}' not found in df_gsea columns. Available: {list(df_gsea.columns)[:30]} ..."
)
# ---- flatten pathways + keep group spans ----
if isinstance(pathways, Mapping):
terms: list[str] = []
spans: list[tuple[int, int, str]] = []
start = 0
for grp, lst in pathways.items():
lst = [str(x) for x in lst]
terms.extend(lst)
end = start + len(lst)
spans.append((start, end, str(grp)))
start = end
else:
terms = [str(x) for x in pathways]
spans = []
if len(terms) == 0:
raise ValueError("No pathways provided.")
# ---- subset and make pivot matrices ----
sub = df_gsea[df_gsea[term_col].astype(str).isin(terms)].copy()
sub[term_col] = sub[term_col].astype(str)
sub[comparison_col] = sub[comparison_col].astype(str)
if comparison_order is None:
comps = list(pd.Categorical(sub[comparison_col]).categories) if sub.shape[0] else []
if not comps:
comps = sorted(df_gsea[comparison_col].astype(str).unique().tolist())
else:
comps = [str(x) for x in comparison_order]
nes_mat = sub.pivot_table(index=comparison_col, columns=term_col, values=nes_col, aggfunc="mean")
q_mat = sub.pivot_table(index=comparison_col, columns=term_col, values=fdr_col, aggfunc="mean")
nes_mat = nes_mat.reindex(index=comps, columns=terms)
q_mat = q_mat.reindex(index=comps, columns=terms)
if drop_empty_comparisons:
keep = ~(nes_mat.isna().all(axis=1))
nes_mat = nes_mat.loc[keep]
q_mat = q_mat.loc[keep]
comps_final = nes_mat.index.tolist()
if len(comps_final) == 0:
raise ValueError("No comparisons have any of the selected pathways in df_gsea (after filtering).")
# ---- size: -log10(q) with floor & clipping ----
q_vals = q_mat.to_numpy(dtype=float)
q_vals = np.where(np.isfinite(q_vals), q_vals, np.nan)
# floor: treat 0 or negative as fdr_floor
q_vals = np.where(q_vals <= 0, fdr_floor, q_vals)
q_vals = np.clip(q_vals, fdr_floor, 1.0)
size_signal = -np.log10(q_vals)
if size_clip_quantile is not None:
cap = np.nanquantile(size_signal[np.isfinite(size_signal)], float(size_clip_quantile))
size_signal = np.minimum(size_signal, cap)
finite = np.isfinite(size_signal)
if finite.any():
smin = float(np.nanmin(size_signal[finite]))
smax = float(np.nanmax(size_signal[finite]))
if smax == smin:
smax = smin + 1.0
u = (size_signal - smin) / (smax - smin)
sizes = size_min + (size_max - size_min) * np.clip(u, 0, 1)
else:
smin, smax = 0.0, 1.0
sizes = np.full_like(size_signal, size_min, dtype=float)
# missing NES → do not draw dot
nes_vals = nes_mat.to_numpy(dtype=float)
sizes[~np.isfinite(nes_vals)] = 0.0
# ---- color scaling (CRITICAL: same norm used by scatter + colorbar) ----
nes_all = nes_vals[np.isfinite(nes_vals)]
if nes_all.size == 0:
# fallback
vmin_eff, vmax_eff = -1.0, 1.0
else:
if vmin is None or vmax is None:
vmax_abs = float(np.nanmax(np.abs(nes_all)))
vmax_abs = vmax_abs if vmax_abs > 0 else 1.0
vmin_eff, vmax_eff = center - vmax_abs, center + vmax_abs
else:
vmin_eff, vmax_eff = float(vmin), float(vmax)
# if user gives asymmetric bounds but wants centering, you can still
# keep TwoSlopeNorm; it will work, but not symmetric.
# We'll keep as given.
norm = mpl.colors.TwoSlopeNorm(vmin=vmin_eff, vcenter=float(center), vmax=vmax_eff)
cmap_obj = mpl.cm.get_cmap(cmap)
# ---- autosize figure ----
if figsize is None:
w = max(6.0, col_width * len(terms) + 3.2) # room for legends
h = max(3.5, row_height * len(comps_final) + 2.0) # room for labels
figsize = (w, h)
fig, ax = plt.subplots(figsize=figsize, constrained_layout=False)
fig.subplots_adjust(right=0.78)
# coords
xs = np.arange(len(terms), dtype=float) * float(col_spacing)
ys = np.arange(len(comps_final), dtype=float) * float(row_spacing)
# ---- draw dots (IMPORTANT: c must be numeric NES, not RGBA) ----
# We draw all points in one scatter so colorbar matches exactly.
XX, YY = np.meshgrid(xs, ys)
XX = XX.ravel()
YY = YY.ravel()
S = sizes.ravel()
C = nes_vals.ravel()
mask = np.isfinite(C) & (S > 0)
sc = ax.scatter(
XX[mask],
YY[mask],
s=S[mask],
c=C[mask], # numeric NES
cmap=cmap_obj,
norm=norm,
edgecolors=dot_edgecolor,
linewidths=dot_linewidth,
)
# axes ticks
ax.set_xticks(xs)
ax.set_yticks(ys)
ax.set_xticklabels(terms, rotation=xtick_rotation, ha="right")
ax.set_yticklabels(comps_final)
# limits
if len(xs) == 1:
ax.set_xlim(xs[0] - 0.6 * col_spacing, xs[0] + 0.6 * col_spacing)
else:
ax.set_xlim(xs.min() - 0.6 * col_spacing, xs.max() + 0.6 * col_spacing)
if len(ys) == 1:
ax.set_ylim(ys[0] - 0.6 * row_spacing, ys[0] + 0.6 * row_spacing)
else:
ax.set_ylim(ys.min() - 0.6 * row_spacing, ys.max() + 0.6 * row_spacing)
ax.invert_yaxis() # scanpy-like
if show_grid:
ax.grid(True, linewidth=0.4, alpha=0.35)
else:
ax.grid(False)
ax.set_xlabel("")
ax.set_ylabel("")
if title:
ax.set_title(title)
# ---- pathway group brackets (if dict provided) ----
if spans:
top_y = ys.min() - 1.1 * row_spacing
for start, end, label in spans:
x1 = xs[start] - 0.5 * col_spacing
x2 = xs[end - 1] + 0.5 * col_spacing
ax.plot(
[x1, x1, x2, x2],
[top_y + 0.2 * row_spacing, top_y, top_y, top_y + 0.2 * row_spacing],
lw=1.0,
color="0.2",
clip_on=False,
)
ax.text(
(x1 + x2) / 2,
top_y - 0.15 * row_spacing,
label,
ha="center",
va="bottom",
rotation=group_label_rotation,
clip_on=False,
)
# ---- colorbar (now matches bubbles) ----
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
cax = inset_axes(
ax,
width="3%",
height="55%",
loc="upper left",
bbox_to_anchor=(1.02, 0.25, 1, 1),
bbox_transform=ax.transAxes,
borderpad=0,
)
cbar = fig.colorbar(sc, cax=cax)
cbar.set_label("NES")
# ---- size legend (uses the same scaling, including floor/clipping) ----
ref_q = np.array([0.05, 0.01, 0.001, fdr_floor], dtype=float)
ref_q = np.clip(ref_q, fdr_floor, 1.0)
ref_sig = -np.log10(ref_q)
if size_clip_quantile is not None and finite.any():
cap = np.nanquantile((-np.log10(q_vals[np.isfinite(q_vals)])), float(size_clip_quantile))
ref_sig = np.minimum(ref_sig, cap)
if finite.any():
uref = (ref_sig - smin) / (smax - smin)
sref = size_min + (size_max - size_min) * np.clip(uref, 0, 1)
else:
sref = np.full_like(ref_sig, size_min, dtype=float)
labels = [f"q={q:g}" for q in ref_q[:-1]] + [f"q≤{fdr_floor:g}"]
handles = [
plt.Line2D(
[0],
[0],
marker="o",
linestyle="none",
markerfacecolor="0.6",
markeredgecolor=dot_edgecolor,
markersize=float(np.sqrt(sr)), # area->approx marker size
label=lab,
)
for sr, lab in zip(sref, labels)
]
ax.legend(
handles=handles,
title="-log10(FDR)",
bbox_to_anchor=(1.02, 0.22),
loc="upper left",
frameon=False,
)
if save is not None:
_savefig(fig, save)
if show:
plt.show()
return fig, ax