# Source code for bullkpy.pl.gsea_leading_edge

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Sequence, Literal

import numpy as np
import pandas as pd

try:
    import seaborn as sns  # type: ignore
except Exception:
    sns = None

import matplotlib as mpl
import matplotlib.pyplot as plt

try:
    import scipy.sparse as sp  # type: ignore
except Exception:
    sp = None

try:
    from scipy.cluster.hierarchy import linkage, fcluster  # type: ignore
    from scipy.spatial.distance import squareform  # type: ignore
except Exception:
    linkage = None
    fcluster = None
    squareform = None

try:
    from sklearn.manifold import MDS
except Exception:
    MDS = None

# ---------------------------------------------------------------------
# Your existing style helpers (adapt as needed)
# ---------------------------------------------------------------------
def set_style():
    """Apply this module's minimal plotting defaults (figure dpi = 120).

    This is only a hook: a project-wide bullkpy ``set_style`` can replace it.
    Any error raised while updating matplotlib's rcParams is ignored on purpose,
    so styling can never abort a plot.
    """
    try:
        rc = plt.rcParams
        rc["figure.dpi"] = 120
    except Exception:
        # best-effort styling: never let a styling failure propagate
        return None


def _savefig(fig, save: str | Path):
    save = Path(save)
    save.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save, bbox_inches="tight")


def _rotate_gene_labels(ax: plt.Axes, fontsize: float = 7.0):
    for lab in ax.get_xticklabels():
        lab.set_rotation(90)
        lab.set_ha("center")
        lab.set_va("top")
        lab.set_fontsize(float(fontsize))


# ---------------------------------------------------------------------
# GSEA result handling (supports pre_res object OR DataFrame)
# ---------------------------------------------------------------------
def _get_res2d(pre_res) -> pd.DataFrame:
    """
    Accepts either:
      - gseapy prerank result object with .res2d
      - pandas DataFrame already (res2d-like)
    """
    if pre_res is None:
        raise ValueError("pre_res cannot be None")
    if isinstance(pre_res, pd.DataFrame):
        return pre_res
    if hasattr(pre_res, "res2d"):
        df = pre_res.res2d
        if not isinstance(df, pd.DataFrame):
            raise TypeError("pre_res.res2d must be a pandas DataFrame")
        return df
    raise TypeError("pre_res must be a gseapy.prerank result object with .res2d OR a pandas DataFrame.")


def _find_col(df: pd.DataFrame, candidates: Sequence[str]) -> str:
    for c in candidates:
        if c in df.columns:
            return c
    raise KeyError(f"None of the candidate columns found: {list(candidates)}")


def _split_leading_edge(x) -> list[str]:
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return []
    s = str(x).strip()
    if s == "" or s.lower() == "nan":
        return []
    # common separators: '/', ';', ',', whitespace
    for sep in [";", ",", "/", "|"]:
        s = s.replace(sep, " ")
    genes = [g for g in s.split() if g]
    return genes


def _normalize_term_selector(
    df: pd.DataFrame,
    *,
    term_idx=None,
    terms: Sequence[str] | None = None,
) -> pd.DataFrame:
    """
    Filter by either:
      - terms (explicit names), OR
      - term_idx (slice | int | Sequence[int]) using iloc positional indices.
    """
    if terms is not None and len(terms) > 0:
        term_col = _find_col(df, ["Term", "term", "pathway", "Pathway", "NAME", "name"])
        terms_set = set(map(str, terms))
        out = df[df[term_col].astype(str).isin(terms_set)].copy()
        if out.shape[0] == 0:
            raise ValueError("No rows matched the provided `terms`.")
        return out

    if term_idx is None:
        return df

    if isinstance(term_idx, int):
        out = df.iloc[[int(term_idx)]].copy()
    elif isinstance(term_idx, slice):
        out = df.iloc[term_idx].copy()
    elif isinstance(term_idx, (list, tuple, np.ndarray, pd.Index)):
        idx = [int(i) for i in list(term_idx)]
        n = df.shape[0]
        bad = [i for i in idx if i < 0 or i >= n]
        if bad:
            raise IndexError(f"term_idx contains out-of-range indices (0..{n-1}): {bad[:10]}")
        out = df.iloc[idx].copy()
    else:
        raise TypeError("term_idx must be None | int | slice | Sequence[int]")

    if out.shape[0] == 0:
        raise ValueError("term_idx selection produced an empty result table.")
    return out


def _leading_edge_sets(
    pre_res,
    *,
    term_idx=None,
    terms: Sequence[str] | None = None,
) -> tuple[list[str], dict[str, set[str]]]:
    """
    Extract leading-edge gene sets for the selected pathways.

    The results table comes from ``_get_res2d(pre_res)`` and is filtered by
    ``_normalize_term_selector``. The term column is the first of "Term", "term",
    "pathway", "Pathway", "NAME", "name". The leading-edge column is the first of
    "Lead_genes", "lead_genes", "ledge_genes", "ledge", "Lead_genes ".

    If a term name occurs on several rows, it appears once, at its first
    position, and its gene set is the union over those rows. This keeps
    ``term_names`` and the keys of ``le_sets`` in one-to-one correspondence.
    Terms with an empty leading edge are dropped.

    Returns
    -------
    term_names : list[str]
        Unique term names in table order.
    le_sets : dict[str, set[str]]
        term -> set of leading-edge genes (non-empty).

    Raises
    ------
    ValueError if no selected term has any leading-edge gene.
    """
    df = _get_res2d(pre_res)
    df = _normalize_term_selector(df, term_idx=term_idx, terms=terms)

    term_col = _find_col(df, ["Term", "term", "pathway", "Pathway", "NAME", "name"])
    le_col = _find_col(df, ["Lead_genes", "lead_genes", "ledge_genes", "ledge", "Lead_genes "])

    # dict preserves first-insertion order, which gives the term ordering
    merged: dict[str, set[str]] = {}
    for term, raw in zip(df[term_col].tolist(), df[le_col].tolist()):
        merged.setdefault(str(term), set()).update(map(str, _split_leading_edge(raw)))

    le_sets = {t: genes for t, genes in merged.items() if genes}
    term_names = list(le_sets)

    if len(term_names) == 0:
        raise ValueError("No leading-edge genes found for the selected terms/indices.")
    return term_names, le_sets


# ---------------------------------------------------------------------
# Core computation helpers
# ---------------------------------------------------------------------
def _jaccard_matrix(term_names: list[str], le_sets: dict[str, set[str]]) -> pd.DataFrame:
    n = len(term_names)
    J = np.zeros((n, n), dtype=float)
    for i, ti in enumerate(term_names):
        Ai = le_sets[ti]
        for j, tj in enumerate(term_names):
            Aj = le_sets[tj]
            inter = len(Ai & Aj)
            union = len(Ai | Aj)
            J[i, j] = (inter / union) if union > 0 else 0.0
    return pd.DataFrame(J, index=term_names, columns=term_names)


def _cohesion(dfJ: pd.DataFrame, labels: np.ndarray) -> pd.DataFrame:
    """
    Simple cluster cohesion metrics from similarity matrix.
    - within_mean: mean similarity within cluster (off-diagonal)
    - between_mean: mean similarity to outside cluster
    """
    terms = dfJ.index.to_list()
    out = []
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        if idx.size <= 1:
            within = np.nan
        else:
            sub = dfJ.to_numpy()[np.ix_(idx, idx)]
            # exclude diagonal
            within = float((sub.sum() - np.trace(sub)) / (idx.size * (idx.size - 1)))
        # between
        other = np.where(labels != c)[0]
        if other.size == 0:
            between = np.nan
        else:
            subb = dfJ.to_numpy()[np.ix_(idx, other)]
            between = float(np.nanmean(subb)) if subb.size else np.nan
        out.append({"cluster": int(c), "n_terms": int(idx.size), "within_mean": within, "between_mean": between})
    return pd.DataFrame(out).sort_values(["n_terms", "within_mean"], ascending=[False, False]).reset_index(drop=True)


# ---------------------------------------------------------------------
# NEW 1) Pathway clustering (nodules) + metrics
# ---------------------------------------------------------------------
def leading_edge_pathway_clusters(
    pre_res,
    *,
    term_idx=None,
    terms: Sequence[str] | None = None,
    method: str = "average",
    threshold: float | None = None,
    n_clusters: int | None = None,
    min_shared_genes: int = 0,
) -> dict:
    """
    Cluster pathways into 'nodules' using leading-edge Jaccard similarity.

    Hierarchical clustering (scipy ``linkage`` with ``method``) is run on the
    distance ``1 - Jaccard``. Specify exactly one of:
      - threshold: cut the dendrogram at this distance (``criterion="distance"``;
        a smaller threshold gives more clusters), OR
      - n_clusters: request at most this many clusters (``criterion="maxclust"``).

    If ``min_shared_genes > 0``, the similarity of any pair of distinct pathways
    sharing fewer than that many genes is set to 0 before clustering.

    When only one pathway is selected, no linkage is computed. That pathway is
    assigned to cluster 1 and ``linkage`` is an empty (0, 4) array.

    Returns
    -------
    dict with:
      - term_names
      - le_sets
      - dfJ (Jaccard similarity, after the min_shared_genes de-noising)
      - clusters (pd.Series: pathway -> cluster_id)
      - metrics (cluster cohesion table)
      - linkage (scipy linkage matrix)

    Raises
    ------
    ImportError if scipy is unavailable.
    ValueError  if not exactly one of threshold / n_clusters is given.
    """
    if linkage is None or fcluster is None or squareform is None:
        raise ImportError("leading_edge_pathway_clusters requires scipy (cluster.hierarchy + distance).")
    # validate the cut specification before doing any work
    if (threshold is None) == (n_clusters is None):
        raise ValueError("Provide exactly one of: threshold OR n_clusters.")

    term_names, le_sets = _leading_edge_sets(pre_res, term_idx=term_idx, terms=terms)
    dfJ = _jaccard_matrix(term_names, le_sets)
    n = len(term_names)

    min_shared = int(min_shared_genes)
    if min_shared > 0:
        # de-noise: zero similarities whose intersection is too small (symmetric)
        for i in range(n):
            genes_i = le_sets[term_names[i]]
            for j in range(i + 1, n):
                if len(genes_i & le_sets[term_names[j]]) < min_shared:
                    dfJ.iloc[i, j] = 0.0
                    dfJ.iloc[j, i] = 0.0

    if n < 2:
        # scipy's linkage needs at least two observations
        Z = np.empty((0, 4), dtype=float)
        labels = np.ones(n, dtype=int)
    else:
        # clip guards against tiny negative distances from float rounding
        D = np.clip(1.0 - dfJ.to_numpy(dtype=float), 0.0, 1.0)
        np.fill_diagonal(D, 0.0)
        Z = linkage(squareform(D, checks=False), method=method)
        if threshold is not None:
            labels = fcluster(Z, t=float(threshold), criterion="distance")
        else:
            labels = fcluster(Z, t=int(n_clusters), criterion="maxclust")

    clusters = pd.Series(labels, index=term_names, name="cluster").astype(int)
    metrics = _cohesion(dfJ, labels)

    return {
        "term_names": term_names,
        "le_sets": le_sets,
        "dfJ": dfJ,
        "clusters": clusters,
        "metrics": metrics,
        "linkage": Z,
    }


# ---------------------------------------------------------------------
# NEW 2) High-level leading-edge genes per pathway nodule
# ---------------------------------------------------------------------
def leading_edge_cluster_driver_genes(
    le_sets: dict[str, set[str]],
    clusters: pd.Series,
    *,
    top_n: int = 25,
    min_gene_freq_in_cluster: int = 2,
) -> pd.DataFrame:
    """
    Compute 'driver genes' for each pathway cluster (nodule).

    Scoring:
      - freq_in_cluster: in how many pathways (within the cluster) gene appears
      - coverage_in_cluster: freq / n_terms_in_cluster
      - freq_global: how many pathways overall (all of ``le_sets``) contain the gene
      - tfidf_like: coverage_in_cluster * log( (1 + n_terms_total) / (1 + freq_global) )

    Genes with ``freq_in_cluster < min_gene_freq_in_cluster`` are skipped. At most
    ``top_n`` genes are kept per cluster. Rows are sorted by cluster, then by
    tfidf_like and freq_in_cluster (both descending), then by gene.

    Returns
    -------
    DataFrame with columns:
      cluster, gene, freq_in_cluster, coverage_in_cluster, freq_global,
      tfidf_like, n_terms_in_cluster
    When no gene passes the filter, an empty DataFrame with these same columns
    is returned.
    """
    from collections import Counter

    columns = [
        "cluster",
        "gene",
        "freq_in_cluster",
        "coverage_in_cluster",
        "freq_global",
        "tfidf_like",
        "n_terms_in_cluster",
    ]
    term_to_cluster = clusters.astype(int).to_dict()

    # global frequency: number of pathways containing each gene
    global_freq = Counter(g for gs in le_sets.values() for g in gs)
    n_terms_total = len(le_sets)

    # cluster -> member pathways
    cl_terms: dict[int, list[str]] = {}
    for t, c in term_to_cluster.items():
        cl_terms.setdefault(int(c), []).append(t)

    min_freq = int(min_gene_freq_in_cluster)
    rows = []
    for c, terms_in_c in cl_terms.items():
        freq_c = Counter(g for t in terms_in_c for g in le_sets.get(t, set()))
        n_terms_c = max(1, len(terms_in_c))
        for g, f in freq_c.items():
            if f < min_freq:
                continue
            cov = f / n_terms_c
            fg = int(global_freq.get(g, 0))
            tfidf = cov * float(np.log((1.0 + n_terms_total) / (1.0 + fg)))
            rows.append(
                {
                    "cluster": int(c),
                    "gene": str(g),
                    "freq_in_cluster": int(f),
                    "coverage_in_cluster": cov,
                    "freq_global": fg,
                    "tfidf_like": tfidf,
                    "n_terms_in_cluster": int(n_terms_c),
                }
            )

    # explicit columns keep the schema stable even when nothing passes the filter
    df = pd.DataFrame(rows, columns=columns)
    if df.shape[0] == 0:
        return df

    # rank within cluster
    df = df.sort_values(["cluster", "tfidf_like", "freq_in_cluster", "gene"], ascending=[True, False, False, True])
    return df.groupby("cluster", group_keys=False).head(int(top_n)).reset_index(drop=True)




# ---------------------------------------------------------------------
# Plot 1: Leading-edge expression heatmap (adata required)
# ---------------------------------------------------------------------
def gsea_leading_edge_heatmap(
    adata,
    pre_res,
    *,
    term_idx=None,
    terms: Sequence[str] | None = None,
    layer: str | None = "log1p_cpm",
    use: str = "samples",  # "samples" | "group_mean"
    groupby: str | None = None,
    min_gene_freq: int = 1,
    max_genes: int | None = 200,
    z_score: str | None = "row",  # "row" | None
    clip_z: float | None = 3.0,
    row_cluster: bool = True,
    col_cluster: bool = True,
    cmap: str = "vlag",
    figsize: tuple[float, float] | None = None,
    show_labels: bool = False,
    gene_label_fontsize: float = 7.0,
    label_fontsize: float = 9.0,
    dendrogram_ratio: tuple[float, float] = (0.06, 0.12),
    title: str | None = None,
    show_title: bool = True,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Heatmap of expression for the leading-edge genes of selected pathways.

    ``pre_res`` may be a gseapy prerank result (with ``.res2d``) or a res2d-like
    DataFrame. Genes are pooled from the leading edges of the selected terms and
    ranked by how many selected pathways contain them (ties alphabetical). They
    are then filtered by ``min_gene_freq``, truncated to ``max_genes``, and
    intersected with ``adata.var_names``.

    Parameters
    ----------
    adata
        AnnData-like object. ``obs``, ``obs_names``, ``var_names``, ``X`` and
        (optionally) ``layers`` are read.
    layer
        Layer to plot. If ``None`` or absent from ``adata.layers``, the function
        silently falls back to ``adata.X``.
    use
        ``"samples"`` plots one row per observation. ``"group_mean"`` plots one row
        per category of ``adata.obs[groupby]`` (the mean over its observations).
        Categories without any observation are dropped, since an all-NaN row
        cannot be clustered.
    z_score
        ``"row"`` standardises each gene (a heatmap column) across the displayed
        rows. ``None`` plots the values as-is. ``clip_z`` clips z-scores to
        [-clip_z, clip_z].
    row_cluster, col_cluster
        Hierarchical clustering of rows / columns. Disabled automatically for an
        axis with fewer than two entries, because clustering needs at least two.
    dendrogram_ratio, label_fontsize, gene_label_fontsize, show_labels
        Layout and label controls passed to / applied after ``sns.clustermap``.
    title, show_title
        Title text (a default is generated) and whether to draw it.
    save, show
        Optional output path, and whether to call ``plt.show()``.

    Returns
    -------
    (df, g)
        The plotted matrix (rows x genes, input order) and the seaborn ClusterGrid.

    Raises
    ------
    ImportError if seaborn is missing.
    ValueError  for an invalid ``use`` / ``z_score``, a missing ``groupby``, or when
                no genes survive filtering.
    KeyError    if ``groupby`` is not a column of ``adata.obs``.
    """
    set_style()
    if sns is None:
        raise ImportError("gsea_leading_edge_heatmap requires seaborn. Please install seaborn.")

    # validate cheap options up front, before any data extraction
    if use not in {"samples", "group_mean"}:
        raise ValueError("use must be 'samples' or 'group_mean'")
    if z_score not in ("row", None):
        raise ValueError("z_score must be 'row' or None")
    if use == "group_mean":
        if groupby is None:
            raise ValueError("groupby must be provided when use='group_mean'")
        if groupby not in adata.obs.columns:
            raise KeyError(f"groupby='{groupby}' not found in adata.obs")

    term_names, le_sets = _leading_edge_sets(pre_res, term_idx=term_idx, terms=terms)

    # gene frequency across pathways
    freq: dict[str, int] = {}
    for gene_set in le_sets.values():
        for g in gene_set:
            freq[g] = freq.get(g, 0) + 1
    genes = [g for g, f in freq.items() if f >= int(min_gene_freq)]
    if len(genes) == 0:
        raise ValueError(f"No genes pass min_gene_freq={min_gene_freq}.")
    genes = sorted(genes, key=lambda g: (-freq[g], g))
    if max_genes is not None:
        genes = genes[: int(max_genes)]

    # keep only genes present
    genes = [g for g in genes if g in adata.var_names]
    if len(genes) == 0:
        raise ValueError("None of the leading-edge genes are present in adata.var_names.")

    X = adata.layers[layer] if (layer is not None and layer in getattr(adata, "layers", {})) else adata.X
    gidx = [adata.var_names.get_loc(g) for g in genes]
    if sp is not None and sp.issparse(X):
        M = X[:, gidx].toarray()
    else:
        M = np.asarray(X[:, gidx], dtype=float)

    if use == "group_mean":
        groups = adata.obs[groupby].astype("category").cat.remove_unused_categories()
        cats = list(groups.cat.categories)
        if len(cats) == 0:
            raise ValueError(f"groupby='{groupby}' has no non-missing values.")
        out = np.vstack([M[(groups == c).to_numpy(), :].mean(axis=0) for c in cats])
        df = pd.DataFrame(out, index=[str(c) for c in cats], columns=genes)
    else:
        df = pd.DataFrame(M, index=adata.obs_names.astype(str), columns=genes)

    if z_score == "row":
        # standardise each gene column across the displayed rows
        arr = df.to_numpy(dtype=float)
        mu = np.nanmean(arr, axis=0, keepdims=True)
        sd = np.nanstd(arr, axis=0, keepdims=True)
        sd[sd == 0] = 1.0
        arr = (arr - mu) / sd
        if clip_z is not None:
            arr = np.clip(arr, -float(clip_z), float(clip_z))
        df = pd.DataFrame(arr, index=df.index, columns=df.columns)

    if figsize is None:
        w = max(7.0, 0.10 * df.shape[1] + 4.5)
        h = max(5.0, 0.12 * df.shape[0] + 3.0)
        figsize = (w, h)

    g = sns.clustermap(
        df,
        cmap=cmap,
        row_cluster=bool(row_cluster) and df.shape[0] > 1,
        col_cluster=bool(col_cluster) and df.shape[1] > 1,
        xticklabels=bool(show_labels),
        yticklabels=True,
        figsize=figsize,
        dendrogram_ratio=dendrogram_ratio,
        cbar_kws={"label": "Z-score" if z_score == "row" else "Expression"},
    )
    ax = g.ax_heatmap

    if show_title:
        ax.set_title(
            title or f"Leading-edge expression heatmap (terms={len(term_names)}, genes={df.shape[1]})",
            pad=10,
        )
    ax.set_xlabel("Leading-edge genes", fontsize=float(label_fontsize))
    ax.set_ylabel("Samples" if use == "samples" else str(groupby), fontsize=float(label_fontsize))
    for lab in ax.get_yticklabels():
        lab.set_fontsize(float(label_fontsize))
    if show_labels:
        _rotate_gene_labels(ax, fontsize=float(gene_label_fontsize))
    g.fig.subplots_adjust(bottom=0.30)

    if save is not None:
        _savefig(g.fig, save)
    if show:
        plt.show()
    return df, g
# --------------------------------------------------------------------- # Plot 2: Jaccard similarity heatmap # ---------------------------------------------------------------------
def leading_edge_jaccard_heatmap(
    pre_res,
    *,
    term_idx=None,
    terms: Sequence[str] | None = None,
    min_shared_genes: int = 0,
    row_cluster: bool = True,
    col_cluster: bool = True,
    cmap: str = "viridis",
    vmin: float = 0.0,
    vmax: float = 1.0,
    figsize: tuple[float, float] | None = None,
    show_labels: bool = True,
    label_fontsize: float = 9.0,
    dendrogram_ratio: tuple[float, float] = (0.06, 0.12),
    title: str | None = None,
    show_title: bool = True,
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Pathway x pathway heatmap of leading-edge Jaccard similarity.

    Pathways are selected with ``term_idx`` / ``terms``. If
    ``min_shared_genes > 0``, the similarity of any pair of distinct pathways
    sharing fewer than that many genes is set to 0. Clustering is disabled
    automatically when only one pathway is selected, because clustering needs at
    least two rows.

    Returns
    -------
    (dfJ, g)
        The (possibly de-noised) Jaccard matrix and the seaborn ClusterGrid.

    Raises
    ------
    ImportError if seaborn is missing.
    """
    set_style()
    if sns is None:
        raise ImportError("leading_edge_jaccard_heatmap requires seaborn. Please install seaborn.")

    term_names, le_sets = _leading_edge_sets(pre_res, term_idx=term_idx, terms=terms)
    dfJ = _jaccard_matrix(term_names, le_sets)
    n = len(term_names)

    min_shared = int(min_shared_genes)
    if min_shared > 0:
        # intersection size is symmetric, so zero both triangles at once
        for i in range(n):
            genes_i = le_sets[term_names[i]]
            for j in range(i + 1, n):
                if len(genes_i & le_sets[term_names[j]]) < min_shared:
                    dfJ.iloc[i, j] = 0.0
                    dfJ.iloc[j, i] = 0.0

    if figsize is None:
        side = max(6.0, 0.35 * n + 3.0)
        figsize = (side, side)

    can_cluster = n > 1
    g = sns.clustermap(
        dfJ,
        cmap=cmap,
        vmin=float(vmin),
        vmax=float(vmax),
        row_cluster=bool(row_cluster) and can_cluster,
        col_cluster=bool(col_cluster) and can_cluster,
        xticklabels=bool(show_labels),
        yticklabels=bool(show_labels),
        figsize=figsize,
        dendrogram_ratio=dendrogram_ratio,
    )
    ax = g.ax_heatmap

    if show_title:
        ax.set_title(title or "Leading-edge Jaccard similarity (pathway × pathway)", pad=10)
    if show_labels:
        for lab in ax.get_xticklabels():
            lab.set_rotation(90)
            lab.set_ha("center")
            lab.set_va("top")
            lab.set_fontsize(float(label_fontsize))
        for lab in ax.get_yticklabels():
            lab.set_fontsize(float(label_fontsize))
    g.fig.subplots_adjust(bottom=0.30)

    if save is not None:
        _savefig(g.fig, save)
    if show:
        plt.show()
    return dfJ, g
# --------------------------------------------------------------------- # Plot 3: Pathway × gene overlap matrix (binary) # ---------------------------------------------------------------------
def leading_edge_overlap_matrix(
    pre_res,
    *,
    term_idx=None,
    terms: Sequence[str] | None = None,
    min_gene_freq: int = 2,
    sort_genes_by: str = "freq",  # "freq" | "alpha"
    row_cluster: bool = True,
    col_cluster: bool = False,
    cmap: str = "Greys",
    figsize: tuple[float, float] | None = None,
    show_gene_labels: bool = True,
    gene_label_fontsize: float = 8.0,
    show_term_labels: bool = True,
    label_fontsize: float = 9.0,
    dendrogram_ratio: tuple[float, float] = (0.06, 0.12),
    title: str | None = None,
    show_title: bool = True,
    grid_every: int = 5,
    grid_color: str = "0.90",
    save: str | Path | None = None,
    show: bool = True,
):
    """
    Pathway x gene binary matrix of leading-edge membership.

    A cell is 1 when the gene is in the pathway's leading edge. Only genes present
    in at least ``min_gene_freq`` selected pathways are kept. Columns are ordered
    by ``sort_genes_by``: ``"freq"`` sorts by descending frequency with
    alphabetical tie-breaking (stable sort); ``"alpha"`` sorts alphabetically.

    - Grid lines are drawn every ``grid_every`` rows/columns (0 disables them).
    - ``label_fontsize`` also controls the pathway labels.
    - ``dendrogram_ratio`` sizes the dendrograms.
    - Clustering is disabled automatically on an axis with fewer than two
      entries, because clustering needs at least two.

    Returns
    -------
    (df_ord, g)
        The binary matrix *in the displayed (clustered) row/column order* and the
        seaborn ClusterGrid.

    Raises
    ------
    ImportError if seaborn is missing.
    ValueError  for ``min_gene_freq < 1``, an unknown ``sort_genes_by``, or when no
                gene passes the frequency filter.
    """
    set_style()
    if sns is None:
        raise ImportError("leading_edge_overlap_matrix requires seaborn. Please install seaborn.")
    if int(min_gene_freq) < 1:
        raise ValueError("min_gene_freq must be >= 1")
    if sort_genes_by not in ("freq", "alpha"):
        raise ValueError("sort_genes_by must be 'freq' or 'alpha'")

    term_names, le_sets = _leading_edge_sets(pre_res, term_idx=term_idx, terms=terms)

    all_genes = sorted(set().union(*le_sets.values()))
    if len(all_genes) == 0:
        raise ValueError("No genes in leading-edge sets.")

    gene_pos = {g: j for j, g in enumerate(all_genes)}
    mat = np.zeros((len(term_names), len(all_genes)), dtype=int)
    for i, t in enumerate(term_names):
        for g in le_sets[t]:
            mat[i, gene_pos[g]] = 1
    df = pd.DataFrame(mat, index=term_names, columns=all_genes)

    gene_freq = df.sum(axis=0).astype(int)
    keep = gene_freq[gene_freq >= int(min_gene_freq)].index.tolist()
    df = df.loc[:, keep]
    if df.shape[1] == 0:
        raise ValueError(f"No genes pass min_gene_freq={min_gene_freq}.")

    if sort_genes_by == "freq":
        # mergesort is stable: equal frequencies keep alphabetical order
        df = df.loc[:, df.sum(axis=0).sort_values(ascending=False, kind="mergesort").index]
    else:
        df = df.loc[:, sorted(df.columns)]

    if figsize is None:
        w = max(7.0, 0.18 * df.shape[1] + 4.5)
        h = max(4.5, 0.22 * df.shape[0] + 2.5)
        figsize = (w, h)

    # Resolve the colormap via the registry: matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9 (matplotlib.colormaps exists since 3.5).
    if isinstance(cmap, mpl.colors.Colormap):
        cm = cmap
    elif hasattr(mpl, "colormaps"):
        cm = mpl.colormaps[cmap]
    else:  # very old Matplotlib
        cm = mpl.cm.get_cmap(cmap)
    if hasattr(cm, "with_extremes"):
        # only affects values below vmin; with vmin=0 and 0/1 data this is a safeguard
        cm = cm.with_extremes(under="white")

    g = sns.clustermap(
        df,
        cmap=cm,
        vmin=0,
        vmax=1,
        row_cluster=bool(row_cluster) and df.shape[0] > 1,
        col_cluster=bool(col_cluster) and df.shape[1] > 1,
        linewidths=0.0,
        xticklabels=bool(show_gene_labels),
        yticklabels=bool(show_term_labels),
        figsize=figsize,
        dendrogram_ratio=dendrogram_ratio,
        cbar_pos=None,
    )
    ax = g.ax_heatmap

    ax.set_xlabel(f"Leading-edge genes (kept={df.shape[1]}, freq≥{min_gene_freq})", fontsize=float(label_fontsize))
    ax.set_ylabel("Pathways", fontsize=float(label_fontsize))
    if show_title:
        ax.set_title(title or "Leading-edge overlap (pathway × gene)", pad=10)

    # grid every N
    step = int(grid_every)
    if step > 0:
        nrows, ncols = df.shape
        for r in range(0, nrows + 1, step):
            ax.axhline(r, color=grid_color, lw=0.8, zorder=10)
        for c in range(0, ncols + 1, step):
            ax.axvline(c, color=grid_color, lw=0.8, zorder=10)

    # apply fonts
    for lab in ax.get_yticklabels():
        lab.set_fontsize(float(label_fontsize))
    if show_gene_labels:
        _rotate_gene_labels(ax, fontsize=float(gene_label_fontsize))
    g.fig.subplots_adjust(bottom=0.30)

    # return df in the displayed (clustered) order, rows and columns
    row_order = g.dendrogram_row.reordered_ind if g.dendrogram_row is not None else list(range(df.shape[0]))
    col_order = g.dendrogram_col.reordered_ind if g.dendrogram_col is not None else list(range(df.shape[1]))
    df_ord = df.iloc[row_order, :].iloc[:, col_order].copy()

    if save is not None:
        _savefig(g.fig, save)
    if show:
        plt.show()
    return df_ord, g
# ---------------------------------------------------------------------
# Plot 4: Pathway cluster bubbles
# ---------------------------------------------------------------------
def _cluster_unions(
    clusters: pd.Series, le_sets: dict
) -> tuple[list[int], dict[int, set[str]], dict[int, list[str]]]:
    """Group terms by cluster id and union each cluster's leading-edge genes.

    Returns ``(cluster_ids, cluster_gene_sets, cluster_terms)``:
      - cluster_ids: the sorted ids of every cluster in *clusters*
      - cluster_gene_sets: cluster -> union of member genes (as strings)
      - cluster_terms: cluster -> member term names (as strings)
    Clusters whose union is empty are kept; callers decide whether to drop them.
    """
    clusters = clusters.astype(int)
    cluster_ids = sorted(int(c) for c in pd.unique(clusters))
    gene_sets: dict[int, set[str]] = {cid: set() for cid in cluster_ids}
    member_terms: dict[int, list[str]] = {cid: [] for cid in cluster_ids}
    for term, cid in clusters.items():
        genes = le_sets.get(term)
        if genes is None:
            genes = le_sets.get(str(term), ())
        member_terms[int(cid)].append(str(term))
        gene_sets[int(cid)].update(map(str, genes))
    return cluster_ids, gene_sets, member_terms


def _cluster_overlap_jaccard(genesets: list[set[str]]) -> tuple[np.ndarray, np.ndarray]:
    """Return (shared-gene counts, Jaccard similarity) square matrices for *genesets*."""
    n = len(genesets)
    O = np.zeros((n, n), dtype=int)
    J = np.zeros((n, n), dtype=float)
    for i, Ai in enumerate(genesets):
        for j, Aj in enumerate(genesets):
            inter = len(Ai & Aj)
            union = len(Ai | Aj)
            O[i, j] = inter
            J[i, j] = (inter / union) if union > 0 else 0.0
    return O, J


def _embed_cluster_distances(
    D: np.ndarray, J: np.ndarray, cluster_ids: list[int], *, layout: str, seed: int
) -> tuple[np.ndarray, float]:
    """Embed a cluster distance matrix in 2-D; return (XY, stress).

    ``stress`` is only available for ``layout="mds"`` and is NaN otherwise. When
    all distances are (nearly) equal, the points are placed on a unit circle.
    """
    n = D.shape[0]
    stress = np.nan
    finite = np.isfinite(D)
    spread = float(np.nanstd(D[finite])) if finite.any() else 0.0
    if spread < 1e-8:
        ang = np.linspace(0, 2 * np.pi, n, endpoint=False)
        return np.c_[np.cos(ang), np.sin(ang)], stress

    if layout == "pcoa":
        # classical MDS / PCoA: double-centre squared distances, take top-2 eigenpairs
        H = np.eye(n) - np.ones((n, n)) / n
        B = -0.5 * H @ (D**2) @ H
        evals, evecs = np.linalg.eigh(B)
        order = np.argsort(evals)[::-1]
        evals = np.maximum(evals[order][:2], 0.0)
        return evecs[:, order][:, :2] * np.sqrt(evals), stress

    if layout == "mds":
        if MDS is None:
            raise ImportError("layout='mds' requires scikit-learn. Install with: pip install scikit-learn")
        mds = MDS(
            n_components=2,
            dissimilarity="precomputed",
            metric=True,
            random_state=int(seed),
            n_init=20,
            max_iter=3000,
            eps=1e-12,
        )
        XY = mds.fit_transform(D)
        try:
            stress = float(mds.stress_)
        except Exception:
            stress = np.nan
        return XY, stress

    if layout == "spring":
        try:
            import networkx as nx
        except Exception as e:
            raise ImportError(f"layout='spring' requires networkx. Install with: pip install networkx. ({e})")
        G = nx.Graph()
        G.add_nodes_from(cluster_ids)
        for i in range(n):
            for j in range(i + 1, n):
                w = float(J[i, j])
                if w > 0:
                    G.add_edge(cluster_ids[i], cluster_ids[j], weight=w)
        pos = nx.spring_layout(G, seed=int(seed), weight="weight")
        return np.array([[pos[c][0], pos[c][1]] for c in cluster_ids], dtype=float), stress

    raise ValueError("layout must be 'pcoa' or 'mds' or 'spring'")


def _repel_circles(
    x: np.ndarray,
    y: np.ndarray,
    r: np.ndarray,
    *,
    seed: int,
    iters: int,
    strength: float,
    padding: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Iteratively push overlapping circles apart; return new (x, y) arrays.

    A tiny seeded jitter is applied first. For each overlapping pair, both centres
    move along their unit separation vector by ``strength * overlap / 2``. The loop
    stops early once no pair overlaps.
    """
    n = x.size
    rng = np.random.default_rng(int(seed))
    x = x + rng.normal(0, 1e-4, size=n)
    y = y + rng.normal(0, 1e-4, size=n)
    for _ in range(int(iters)):
        moved = False
        for i in range(n):
            for j in range(i + 1, n):
                dx = x[j] - x[i]
                dy = y[j] - y[i]
                dist = float(np.hypot(dx, dy))
                target = float(r[i] + r[j] + padding)
                if dist < 1e-12:
                    # coincident centres: pick a random unit direction. The shift is
                    # computed from the unit vector, so the points move by half the
                    # overlap instead of being flung ~1/dist away.
                    ang = float(rng.uniform(0, 2 * np.pi))
                    ux, uy = float(np.cos(ang)), float(np.sin(ang))
                    dist = 0.0
                else:
                    ux, uy = dx / dist, dy / dist
                if dist < target:
                    shift = float(strength) * 0.5 * (target - dist)
                    x[i] -= ux * shift
                    y[i] -= uy * shift
                    x[j] += ux * shift
                    y[j] += uy * shift
                    moved = True
        if not moved:
            break
    return x, y


def _bubble_pairwise_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Square matrix of Euclidean distances between the 2-D points (x[i], y[i])."""
    return np.hypot(x[:, None] - x[None, :], y[:, None] - y[None, :])


def _bubble_dist_corr(Dtrue: np.ndarray, Dplot: np.ndarray) -> float:
    """Pearson correlation of the upper-triangle entries of two distance matrices.

    Returns NaN with fewer than 3 finite pairs or when either side is constant.
    """
    m = np.triu_indices_from(Dtrue, k=1)
    a = Dtrue[m].astype(float)
    b = Dplot[m].astype(float)
    ok = np.isfinite(a) & np.isfinite(b)
    if ok.sum() < 3:
        return float("nan")
    a = a[ok]
    b = b[ok]
    if np.std(a) == 0 or np.std(b) == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


def leading_edge_cluster_bubbles(
    res: dict,
    *,
    min_overlap_genes: int = 1,
    layout: Literal["mds", "spring", "pcoa"] = "pcoa",
    figsize: tuple[float, float] = (8.0, 7.0),
    bubble_alpha: float = 0.35,
    edge_alpha: float = 0.40,
    edge_width_scale: float = 6.0,
    radius_scale: float = 0.10,
    radius_mode: Literal["sqrt", "log"] = "log",
    dist_scale: float = 4.0,
    normalize_xy: bool = True,  # can be disabled to keep the raw embedding scale
    show_labels: bool = True,
    label_fontsize: float = 10.0,
    label_style: Literal["id", "size"] = "id",
    show_edges: bool = True,
    seed: int = 0,
    title: str | None = "Pathway clusters bubble map",
    repel: bool = True,
    repel_iters: int = 200,
    repel_strength: float = 0.8,
    repel_padding: float = 0.02,
    return_metrics: bool = False,  # optionally return embedding metrics as 8th object
):
    """
    Bubble map for pathway clusters.

    - Bubble size ~ number of unique leading-edge genes in the cluster (the union
      of the member pathways), via ``radius_mode`` ("sqrt" or "log") and
      ``radius_scale``.
    - Bubble positions are a 2-D embedding (``layout``: "pcoa", "mds" or "spring")
      of D = 1 - Jaccard between cluster unions, centred, optionally normalised
      (``normalize_xy``), and scaled by ``dist_scale``.
    - Optional edges connect clusters sharing >= ``min_overlap_genes`` genes, with
      width proportional to the shared-gene count.
    - ``repel`` runs a simple overlap-reduction pass (no extra dependencies).
    - The title is suffixed with the correlation between D and the plotted
      distances, when that correlation is finite.

    Returns
    -------
    (fig, ax, df_xy, dfJc, dfOc, cluster_gene_sets, cluster_terms)
    If ``return_metrics=True``, an extra 8th dict is appended:
    ``{"stress": ..., "dist_corr": ...}`` (stress is only computed for "mds").

    Raises
    ------
    ValueError  if ``res`` lacks "clusters"/"le_sets", for an unknown
                ``layout``/``radius_mode``, or when fewer than 2 non-empty
                clusters exist.
    ImportError for "mds" without scikit-learn, or "spring" without networkx.
    """
    # ---- validate ----
    if "clusters" not in res or "le_sets" not in res:
        raise ValueError("res must contain 'clusters' and 'le_sets' from leading_edge_pathway_clusters().")
    if layout not in ("pcoa", "mds", "spring"):
        raise ValueError("layout must be 'pcoa' or 'mds' or 'spring'")
    if radius_mode not in ("sqrt", "log"):
        raise ValueError("radius_mode must be 'sqrt' or 'log'")

    # ---- cluster unions: cluster_id -> set(genes) ----
    all_ids, cluster_gene_sets, cluster_terms = _cluster_unions(res["clusters"], res["le_sets"])
    cluster_ids = [cid for cid in all_ids if len(cluster_gene_sets[cid]) > 0]
    if len(cluster_ids) < 2:
        raise ValueError("Need at least 2 non-empty clusters for bubble plot.")
    n = len(cluster_ids)

    # ---- overlap + cluster-level Jaccard ----
    O, J = _cluster_overlap_jaccard([cluster_gene_sets[cid] for cid in cluster_ids])
    D = 1.0 - J
    np.fill_diagonal(D, 0.0)

    # ---- layout ----
    XY, stress = _embed_cluster_distances(D, J, cluster_ids, layout=layout, seed=seed)

    # ---- center/scale coordinates ----
    x = XY[:, 0].astype(float)
    y = XY[:, 1].astype(float)
    x -= np.nanmean(x)
    y -= np.nanmean(y)
    if normalize_xy:
        s = float(np.nanstd(np.r_[x, y]))
        if not np.isfinite(s) or s == 0:
            s = 1.0
        x /= s
        y /= s
    x *= float(dist_scale)
    y *= float(dist_scale)

    # ---- radii ----
    n_genes = np.array([len(cluster_gene_sets[cid]) for cid in cluster_ids], dtype=float)
    base = np.maximum(n_genes, 1.0)
    r = float(radius_scale) * (np.sqrt(base) if radius_mode == "sqrt" else np.log1p(base))

    # ---- repel (push circles apart) ----
    if repel:
        x, y = _repel_circles(
            x, y, r, seed=seed, iters=repel_iters, strength=repel_strength, padding=repel_padding
        )

    # ---- fidelity metric: true cluster distances vs plotted distances ----
    dist_corr = _bubble_dist_corr(D, _bubble_pairwise_dist(x, y))

    # ---- plot ----
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_aspect("equal", adjustable="datalim")

    if show_edges:
        maxO = max(1, int(O.max()))
        for i in range(n):
            for j in range(i + 1, n):
                shared = int(O[i, j])
                if shared < int(min_overlap_genes):
                    continue
                lw = float(edge_width_scale) * (shared / maxO)
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], linewidth=lw, alpha=float(edge_alpha), color="0.40", zorder=1
                )

    for i in range(n):
        circ = plt.Circle((x[i], y[i]), float(r[i]), alpha=float(bubble_alpha), ec="0.25", lw=1.0, zorder=2)
        ax.add_patch(circ)

    if show_labels:
        for i, cid in enumerate(cluster_ids):
            lab = f"C{cid}\n{int(n_genes[i])} genes" if label_style == "size" else f"C{cid}"
            ax.text(x[i], y[i], lab, ha="center", va="center", fontsize=float(label_fontsize), zorder=3)

    pad = float(np.max(r)) * 1.6 if np.isfinite(r).any() else 0.5
    ax.set_xlim(float(np.min(x) - pad), float(np.max(x) + pad))
    ax.set_ylim(float(np.min(y) - pad), float(np.max(y) + pad))
    ax.set_xticks([])
    ax.set_yticks([])

    if title:
        extra = f" (dist corr={dist_corr:.2f})" if np.isfinite(dist_corr) else ""
        ax.set_title(str(title) + extra, pad=12)

    # ---- tables ----
    df_xy = pd.DataFrame(
        {
            "cluster": cluster_ids,
            "x": x,
            "y": y,
            "n_le_genes": n_genes,
            "radius": r,
            "n_terms": [len(cluster_terms[cid]) for cid in cluster_ids],
        }
    ).set_index("cluster")
    dfJc = pd.DataFrame(J, index=cluster_ids, columns=cluster_ids)
    dfOc = pd.DataFrame(O, index=cluster_ids, columns=cluster_ids)

    if return_metrics:
        metrics = {
            "stress": float(stress) if np.isfinite(stress) else np.nan,
            "dist_corr": float(dist_corr) if np.isfinite(dist_corr) else np.nan,
        }
        return fig, ax, df_xy, dfJc, dfOc, cluster_gene_sets, cluster_terms, metrics
    return fig, ax, df_xy, dfJc, dfOc, cluster_gene_sets, cluster_terms


# ---------------------------------------------------------------------
# Export for Cytoscape
# ---------------------------------------------------------------------
def export_leading_edge_clusters_cytoscape(
    res: dict,
    *,
    min_overlap_genes: int = 1,
    min_jaccard: float | None = None,
    include_genes: bool = True,
) -> dict[str, pd.DataFrame]:
    """
    Export pathway clusters and their overlaps as tables for Cytoscape.

    Clusters whose leading-edge union is empty are omitted. A cluster pair
    becomes an edge when it shares >= ``min_overlap_genes`` genes and, if given,
    its Jaccard >= ``min_jaccard``.

    Returns a dict of DataFrames:
      - nodes_clusters: id ("C{cluster_id}"), cluster, n_terms, n_le_genes,
        terms (";"-joined)
      - edges_clusters: source, target, shared_genes, jaccard, union_genes
        (these columns are present even when there are no edges)
      - if include_genes:
          - nodes_all: id + type ("cluster" or "gene") for every cluster and gene
          - edges_cluster_gene: bipartite cluster -> gene edges
            (source, target, type="cluster_gene"), genes sorted per cluster

    Notes
    -----
    Cluster nodes are labeled "C{cluster_id}" and gene nodes by their symbol. A
    gene literally named like a cluster id (e.g. "C1") would collide with that
    cluster's node id.

    Raises
    ------
    ValueError if ``res`` lacks "clusters"/"le_sets" or no non-empty cluster exists.
    """
    if "clusters" not in res or "le_sets" not in res:
        raise ValueError("res must contain 'clusters' and 'le_sets'.")

    all_ids, cluster_gene_sets, cluster_terms = _cluster_unions(res["clusters"], res["le_sets"])
    cluster_ids = [cid for cid in all_ids if len(cluster_gene_sets[cid]) > 0]
    if len(cluster_ids) == 0:
        raise ValueError("No non-empty clusters found.")

    # nodes: clusters
    nodes_clusters = pd.DataFrame(
        {
            "id": [f"C{cid}" for cid in cluster_ids],
            "cluster": cluster_ids,
            "n_terms": [len(cluster_terms[cid]) for cid in cluster_ids],
            "n_le_genes": [len(cluster_gene_sets[cid]) for cid in cluster_ids],
            "terms": [";".join(cluster_terms[cid]) for cid in cluster_ids],
        }
    )

    # edges: cluster-cluster
    rows = []
    for i, ci in enumerate(cluster_ids):
        Ai = cluster_gene_sets[ci]
        for cj in cluster_ids[i + 1:]:
            Aj = cluster_gene_sets[cj]
            inter = len(Ai & Aj)
            union = len(Ai | Aj)
            jac = (inter / union) if union > 0 else 0.0
            if inter < int(min_overlap_genes):
                continue
            if (min_jaccard is not None) and (jac < float(min_jaccard)):
                continue
            rows.append(
                {
                    "source": f"C{ci}",
                    "target": f"C{cj}",
                    "shared_genes": int(inter),
                    "jaccard": float(jac),
                    "union_genes": int(union),
                }
            )
    # explicit columns so an edgeless export still has the expected schema
    edges_clusters = pd.DataFrame(rows, columns=["source", "target", "shared_genes", "jaccard", "union_genes"])
    out = {"nodes_clusters": nodes_clusters, "edges_clusters": edges_clusters}

    # optional bipartite network cluster-gene
    if include_genes:
        all_genes = sorted(set().union(*(cluster_gene_sets[c] for c in cluster_ids)))
        nodes_genes = pd.DataFrame({"id": all_genes, "type": "gene"})
        nodes_clusters2 = nodes_clusters.copy()
        nodes_clusters2["type"] = "cluster"
        edges_cluster_gene = pd.DataFrame(
            [
                {"source": f"C{cid}", "target": str(g), "type": "cluster_gene"}
                for cid in cluster_ids
                for g in sorted(cluster_gene_sets[cid])
            ],
            columns=["source", "target", "type"],
        )
        out["nodes_all"] = pd.concat([nodes_clusters2[["id", "type"]], nodes_genes], axis=0).reset_index(drop=True)
        out["edges_cluster_gene"] = edges_cluster_gene
    return out