# Source code for bullkpy.tl.correlations

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Iterable, Sequence

import numpy as np
import pandas as pd
import scipy.sparse as sp

from scipy.stats import t as t_dist
from scipy.stats import spearmanr

import anndata as ad

try:
    import seaborn as sns  # optional
except Exception:  # pragma: no cover
    sns = None

import matplotlib.pyplot as plt

from ..logging import info, warn


# Correlation estimators accepted by the `method` arguments in this module.
Method = Literal["pearson", "spearman"]
# Batch-handling strategies accepted by the `batch_mode` arguments in this module.
BatchMode = Literal["none", "residualize", "within_batch", "meta"]


# ---------------------------
# Core helpers
# ---------------------------

def _as_list(x: str | Sequence[str] | None) -> list[str]:
    """Normalize ``x`` to a list: ``None`` -> ``[]``, a single string -> ``[x]``, otherwise ``list(x)``."""
    if isinstance(x, str):
        # A bare string is one item, not a sequence of characters.
        return [x]
    return [] if x is None else list(x)


def _get_layer_matrix(adata: ad.AnnData, layer: str | None) -> np.ndarray:
    """
    Return the requested expression matrix as a dense float array.

    ``adata.layers[layer]`` is used when ``layer`` is given and present;
    otherwise the function falls back to ``adata.X``. Sparse input is densified.
    """
    has_layer = layer is not None and layer in adata.layers
    mat = adata.layers[layer] if has_layer else adata.X
    dense = mat.toarray() if sp.issparse(mat) else mat
    return np.asarray(dense, dtype=float)


def _get_gene_vector(adata: ad.AnnData, gene: str, *, layer: str | None) -> np.ndarray:
    """
    Return the values of one gene across observations as a 1-D float array.

    The matrix is ``adata.layers[layer]`` when ``layer`` is given and present,
    otherwise ``adata.X``. Only the requested column is densified, so a sparse
    matrix is never expanded in full just to read one gene.

    Raises
    ------
    KeyError
        If ``gene`` is not in ``adata.var_names``.
    """
    if gene not in adata.var_names:
        raise KeyError(f"Gene '{gene}' not in adata.var_names")
    j = adata.var_names.get_loc(gene)

    X = adata.layers[layer] if (layer is not None and layer in adata.layers) else adata.X
    if sp.issparse(X):
        # Some sparse formats (e.g. COO) do not support indexing; convert those first.
        if X.format not in ("csr", "csc"):
            X = X.tocsc()
        col = X[:, j].toarray()
    else:
        col = np.asarray(X)[:, j]
    return np.asarray(col, dtype=float).ravel()


def _numeric_obs_df(adata: ad.AnnData, *, obs_keys: list[str] | None = None) -> pd.DataFrame:
    """
    Return a copy of the numeric part of ``adata.obs``.

    With ``obs_keys=None`` every column whose dtype pandas considers numeric is
    kept. With explicit ``obs_keys`` exactly those columns are returned.

    Raises
    ------
    KeyError
        If any requested key is not a column of ``adata.obs``.
    TypeError
        If any requested column is not numeric.
    """
    is_numeric = pd.api.types.is_numeric_dtype

    if obs_keys is not None:
        absent = [k for k in obs_keys if k not in adata.obs.columns]
        if absent:
            raise KeyError(f"Obs keys not found: {absent}")
        picked = adata.obs[obs_keys].copy()
        for name in picked.columns:
            if not is_numeric(picked[name]):
                raise TypeError(f"Obs column '{name}' is not numeric (dtype={picked[name].dtype})")
        return picked

    # No explicit selection: keep whatever pandas reports as numeric.
    full = adata.obs.copy()
    numeric_names = [name for name in full.columns if is_numeric(full[name])]
    return full[numeric_names].copy()


def _bh_fdr(p: np.ndarray) -> np.ndarray:
    """
    Benjamini–Hochberg FDR correction.

    Non-finite p-values are excluded from the correction (they do not count
    towards the number of tests) and come back as NaN.
    """
    pvals = np.asarray(p, dtype=float)
    adjusted = np.full_like(pvals, np.nan, dtype=float)

    finite = np.isfinite(pvals)
    if not finite.any():
        return adjusted

    valid = pvals[finite]
    count = valid.size
    ascending = np.argsort(valid)

    # p_(i) * n / i, evaluated in ascending p order
    scaled = valid[ascending] * (count / np.arange(1, count + 1))
    # running minimum from the largest p-value downwards keeps q-values monotone
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]

    corrected = np.empty(count, dtype=float)
    corrected[ascending] = np.clip(monotone, 0.0, 1.0)
    adjusted[finite] = corrected
    return adjusted


def _corr_pearson(x: np.ndarray, y: np.ndarray) -> tuple[float, float, int]:
    """
    Pearson correlation on the pairwise-finite entries of ``x`` and ``y``.

    Returns
    -------
    (r, p, n)
        ``r`` is the correlation, ``p`` the two-sided p-value from the
        t-distribution with ``n - 2`` degrees of freedom, and ``n`` the number
        of pairs used. ``r`` and ``p`` are NaN when fewer than 3 finite pairs
        exist or when either vector has zero variance.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    m = np.isfinite(x) & np.isfinite(y)
    n = int(m.sum())
    if n < 3:
        return np.nan, np.nan, n

    xa = x[m] - x[m].mean()
    ya = y[m] - y[m].mean()

    denom = np.sqrt(np.sum(xa * xa) * np.sum(ya * ya))
    if denom == 0:
        return np.nan, np.nan, n

    r = float(np.sum(xa * ya) / denom)
    if not np.isfinite(r):
        return np.nan, np.nan, n

    # Round-off can push |r| marginally above 1 for (near-)collinear data;
    # clamp so r stays a valid correlation and the p-value stays defined.
    r = min(1.0, max(-1.0, r))

    df = n - 2
    if abs(r) == 1.0:
        p = 0.0
    else:
        t = r * np.sqrt(df / (1.0 - r * r))
        p = float(2.0 * t_dist.sf(np.abs(t), df))
    return r, p, n


def _corr_spearman(x: np.ndarray, y: np.ndarray) -> tuple[float, float, int]:
    """
    Spearman rank correlation on the pairwise-finite entries of ``x`` and ``y``.

    Returns ``(r, p, n)`` as computed by ``scipy.stats.spearmanr``, where ``n``
    is the number of pairs used. ``r`` and ``p`` are NaN when fewer than 3
    finite pairs exist.
    """
    xv = np.asarray(x, dtype=float).ravel()
    yv = np.asarray(y, dtype=float).ravel()
    both = np.isfinite(xv) & np.isfinite(yv)
    n_used = int(both.sum())
    if n_used < 3:
        return np.nan, np.nan, n_used

    rho, pval = spearmanr(xv[both], yv[both], nan_policy="omit")
    return float(rho), float(pval), n_used


def _corr(x: np.ndarray, y: np.ndarray, *, method: Method) -> tuple[float, float, int]:
    """
    Dispatch to the Pearson or Spearman helper and return its ``(r, p, n)``.

    Raises
    ------
    ValueError
        If ``method`` is neither ``"pearson"`` nor ``"spearman"``.
    """
    if method == "spearman":
        return _corr_spearman(x, y)
    if method == "pearson":
        return _corr_pearson(x, y)
    raise ValueError(f"Unknown method: {method}")


def _design_matrix_from_obs(
    adata: ad.AnnData,
    covariates: list[str],
    *,
    add_intercept: bool = True,
) -> np.ndarray:
    """
    Assemble a regression design matrix from columns of ``adata.obs``.

    Numeric columns contribute a single float column. Any other column is cast
    to string and one-hot encoded with the first level dropped. When
    ``add_intercept`` is true a leading column of ones is prepended.

    Raises
    ------
    KeyError
        If a covariate is not a column of ``adata.obs``.
    """
    if len(covariates) == 0:
        if add_intercept:
            return np.ones((adata.n_obs, 1), dtype=float)
        return np.empty((adata.n_obs, 0), dtype=float)

    blocks = []
    for name in covariates:
        if name not in adata.obs.columns:
            raise KeyError(f"Covariate '{name}' not in adata.obs")
        column = adata.obs[name]

        if not pd.api.types.is_numeric_dtype(column):
            # categorical / object → one-hot; dropping one level avoids collinearity
            levels = pd.Categorical(column.astype(str))
            onehot = pd.get_dummies(levels, drop_first=True)
            blocks.append(onehot.to_numpy(dtype=float))
        else:
            numeric = pd.to_numeric(column, errors="coerce").to_numpy(dtype=float)
            blocks.append(numeric.reshape(-1, 1))

    if add_intercept:
        blocks.insert(0, np.ones((adata.n_obs, 1), dtype=float))
    return np.concatenate(blocks, axis=1)


def _residualize(y: np.ndarray, X: np.ndarray) -> np.ndarray:
    """
    Return the least-squares residuals of ``y`` regressed on design matrix ``X``.

    The fit uses only rows where ``y`` and every entry of ``X`` are finite;
    all other positions are NaN. An empty ``X`` returns a copy of ``y``. If
    fewer than ``X.shape[1] + 2`` usable rows exist, the result is all NaN.
    """
    target = np.asarray(y, dtype=float).ravel()
    if X.size == 0:
        return target.copy()

    usable = np.isfinite(target) & np.isfinite(X).all(axis=1)
    resid = np.full_like(target, np.nan, dtype=float)

    if usable.sum() >= X.shape[1] + 2:
        design = X[usable, :]
        observed = target[usable]
        coef = np.linalg.lstsq(design, observed, rcond=None)[0]
        resid[usable] = observed - design @ coef
    return resid


def _batch_groups(adata: ad.AnnData, batch_key: str) -> pd.Series:
    """
    Return ``adata.obs[batch_key]`` cast to a categorical Series.

    Raises
    ------
    KeyError
        If ``batch_key`` is not a column of ``adata.obs``.
    """
    obs = adata.obs
    if batch_key in obs.columns:
        return obs[batch_key].astype("category")
    raise KeyError(f"batch_key='{batch_key}' not in adata.obs")


def _within_batch_zscore(v: np.ndarray, batch: pd.Series) -> np.ndarray:
    """
    Z-score ``v`` separately inside each category of ``batch``.

    Uses the sample standard deviation (``ddof=1``). A batch with fewer than
    two finite values, or with a zero / non-finite standard deviation, is left
    as NaN, as is any position not assigned to a category.
    """
    values = np.asarray(v, dtype=float).ravel()
    scaled = np.full_like(values, np.nan, dtype=float)

    for level in batch.cat.categories:
        members = (batch == level).to_numpy()
        group = values[members]
        if np.isfinite(group).sum() < 2:
            continue
        center = np.nanmean(group)
        spread = np.nanstd(group, ddof=1)
        if np.isfinite(spread) and spread != 0:
            scaled[members] = (group - center) / spread
    return scaled


def _fisher_meta_cor(r_list: list[float], n_list: list[int]) -> tuple[float, float]:
    """
    Pool correlations with a Fisher-z average weighted by ``n - 3``.

    Entries are used only when both ``r`` and ``n`` are finite, ``n >= 4`` and
    ``|r| < 1``. Returns ``(r_meta, z_se)`` where ``z_se`` is the standard
    error in Fisher-z space, or ``(nan, nan)`` when no entry qualifies.
    """
    r_arr = np.asarray(r_list, dtype=float)
    n_arr = np.asarray(n_list, dtype=float)

    usable = np.isfinite(r_arr) & np.isfinite(n_arr) & (n_arr >= 4) & (np.abs(r_arr) < 1)
    if not usable.any():
        return np.nan, np.nan

    weights = n_arr[usable] - 3.0
    total_weight = np.sum(weights)
    fisher_z = np.arctanh(r_arr[usable])

    pooled_z = np.sum(weights * fisher_z) / total_weight
    z_se = np.sqrt(1.0 / total_weight)
    return float(np.tanh(pooled_z)), float(z_se)


def _corr_batch_aware(
    x: np.ndarray,
    y: np.ndarray,
    *,
    adata: ad.AnnData,
    method: Method,
    batch_key: str | None,
    batch_mode: BatchMode,
    covariates: list[str] | None = None,
) -> tuple[float, float, int]:
    """
    Compute a correlation under the requested batch-handling mode.

    - none: standard correlation (also used whenever ``batch_key`` is None)
    - residualize: regress out batch_key (+covariates) from x and y, then
      correlate the residuals
    - within_batch: z-score x and y within each batch, then correlate pooled
    - meta: correlate within each batch, Fisher-z combine, and derive a
      two-sided normal p-value from z / se

    Returns
    -------
    (r, p, n)
        For ``meta`` mode ``n`` is the sum of the per-batch pair counts and
        ``r`` / ``p`` are NaN when no batch yields a usable correlation.

    Raises
    ------
    KeyError
        If ``batch_key`` is not a column of ``adata.obs``.
    ValueError
        If ``batch_mode`` is not one of the supported modes.
    """
    # Local import: the stdlib ``math`` module replaces ``np.math``, which was
    # deprecated in NumPy 1.25 and removed in NumPy 2.0.
    import math

    covariates = covariates or []

    if batch_mode == "none" or batch_key is None:
        return _corr(x, y, method=method)

    batch = _batch_groups(adata, batch_key)

    if batch_mode == "residualize":
        X = _design_matrix_from_obs(adata, covariates=[batch_key] + covariates, add_intercept=True)
        xr = _residualize(x, X)
        yr = _residualize(y, X)
        return _corr(xr, yr, method=method)

    if batch_mode == "within_batch":
        xz = _within_batch_zscore(x, batch)
        yz = _within_batch_zscore(y, batch)
        return _corr(xz, yz, method=method)

    if batch_mode == "meta":
        x_arr = np.asarray(x)
        y_arr = np.asarray(y)
        r_list: list[float] = []
        n_list: list[int] = []
        for b in batch.cat.categories:
            m = (batch == b).to_numpy()
            # per-batch p-values are not needed; only r and n feed the meta-estimate
            rb, _, nb = _corr(x_arr[m], y_arr[m], method=method)
            r_list.append(rb)
            n_list.append(nb)
        n_total = int(np.nansum(n_list))

        r_meta, se = _fisher_meta_cor(r_list, n_list)
        if not np.isfinite(r_meta) or not np.isfinite(se) or se <= 0:
            return np.nan, np.nan, n_total

        # z-stat in Fisher space; erfc(|z| / sqrt(2)) == 2 * (1 - Phi(|z|))
        z = float(np.arctanh(r_meta) / se)
        p = float(math.erfc(abs(z) / math.sqrt(2.0)))
        return r_meta, p, n_total

    raise ValueError(f"Unknown batch_mode='{batch_mode}'")


# ---------------------------
# A) Strongest gene–gene pairs
# ---------------------------

def top_gene_gene_correlations(
    adata: ad.AnnData,
    *,
    genes: list[str] | None = None,
    layer: str | None = "log1p_cpm",
    method: Method = "pearson",
    top_n: int = 200,
    min_abs_r: float | None = None,
    use_abs: bool = True,
    batch_key: str | None = None,
    batch_mode: BatchMode = "none",
    covariates: list[str] | None = None,
) -> pd.DataFrame:
    """
    Find the strongest gene–gene correlations within a gene panel.

    Every unordered pair of ``genes`` is correlated (O(G^2) pairs, so this is
    meant for moderate panels; ``genes`` is mandatory for that reason).

    Parameters
    ----------
    genes
        Gene names to compare pairwise. Required.
    layer
        Layer to read; falls back to ``adata.X`` when absent.
    top_n
        Maximum number of pairs returned.
    min_abs_r
        If given, pairs with ``|r|`` below this value are dropped.
    use_abs
        Rank by ``|r|`` when true, by signed ``r`` (most positive first) otherwise.
    batch_key, batch_mode, covariates
        Forwarded to the batch-aware correlation helper.

    Returns
    -------
    Long table with columns (gene1, gene2, r, pval, n, qval, method,
    batch_key, batch_mode). q-values are Benjamini–Hochberg adjusted across
    all tested pairs, before any ``min_abs_r`` / ``top_n`` filtering. An empty
    result is returned without the trailing method/batch columns.

    Raises
    ------
    ValueError
        If ``genes`` is None.
    KeyError
        If any gene is missing from ``adata.var_names``.
    """
    # Validate arguments before densifying the expression matrix.
    if genes is None:
        raise ValueError("For safety, provide genes=... for top_gene_gene_correlations (all genes is too big).")
    missing = [g for g in genes if g not in adata.var_names]
    if missing:
        raise KeyError(f"Genes not found: {missing[:10]}")

    X = _get_layer_matrix(adata, layer)
    idx = [adata.var_names.get_loc(g) for g in genes]
    M = np.asarray(X[:, idx], dtype=float)  # n_obs x n_genes
    names = [adata.var_names[i] for i in idx]

    rows = []
    n_genes = len(names)
    for i in range(n_genes):
        xi = M[:, i]
        for j in range(i + 1, n_genes):
            r, p, n = _corr_batch_aware(
                xi,
                M[:, j],
                adata=adata,
                method=method,
                batch_key=batch_key,
                batch_mode=batch_mode,
                covariates=covariates,
            )
            rows.append((names[i], names[j], r, p, n))

    df = pd.DataFrame(rows, columns=["gene1", "gene2", "r", "pval", "n"])
    # Correct over every hypothesis tested, not only those passing the r filter.
    df["qval"] = _bh_fdr(df["pval"].to_numpy(dtype=float))

    r_vals = df["r"].to_numpy(dtype=float)
    keep = np.isfinite(r_vals)
    if min_abs_r is not None:
        keep &= np.abs(r_vals) >= float(min_abs_r)
    df = df[keep].copy()
    if df.empty:
        return df

    score = df["r"].abs() if use_abs else df["r"]
    order = score.sort_values(ascending=False, kind="stable").index
    df = df.loc[order].head(int(top_n)).reset_index(drop=True)

    df["method"] = method
    df["batch_key"] = batch_key
    df["batch_mode"] = batch_mode
    return df
# ---------------------------
# D) One gene vs genes
# ---------------------------
def gene_gene_correlations(
    adata: ad.AnnData,
    *,
    gene: str,
    genes: list[str] | None = None,
    layer: str | None = "log1p_cpm",
    method: Method = "pearson",
    top_n: int = 50,
    min_abs_r: float | None = None,
    use_abs: bool = True,
    batch_key: str | None = None,
    batch_mode: BatchMode = "none",
    covariates: list[str] | None = None,
) -> pd.DataFrame:
    """
    Correlate one query gene with all other genes or with a subset.

    Parameters
    ----------
    gene
        Query gene.
    genes
        Genes to compare against; all genes except ``gene`` when None. The
        query gene itself is skipped if listed.
    layer
        Layer to read; falls back to ``adata.X`` when absent.
    top_n
        Maximum number of rows returned.
    min_abs_r
        If given, rows with ``|r|`` below this value are dropped.
    use_abs
        Rank by ``|r|`` when true, by signed ``r`` (most positive first) otherwise.
    batch_key, batch_mode, covariates
        Forwarded to the batch-aware correlation helper.

    Returns
    -------
    Table with columns (query_gene, gene, r, pval, n, qval, method, batch_key,
    batch_mode). q-values are Benjamini–Hochberg adjusted across all tested
    genes. An empty result is returned without the query_gene/method/batch
    columns.

    Raises
    ------
    KeyError
        If ``gene`` or any entry of ``genes`` is missing from ``adata.var_names``.
    """
    if gene not in adata.var_names:
        raise KeyError(f"Gene '{gene}' not in adata.var_names")

    # Densify once and read both the query and the target columns from it.
    X = _get_layer_matrix(adata, layer)
    x0 = np.asarray(X[:, adata.var_names.get_loc(gene)], dtype=float).ravel()

    if genes is None:
        idx = [i for i, nm in enumerate(adata.var_names) if nm != gene]
    else:
        missing = [g for g in genes if g not in adata.var_names]
        if missing:
            raise KeyError(f"Genes not found: {missing[:10]}")
        idx = [adata.var_names.get_loc(g) for g in genes if g != gene]
    names = [adata.var_names[i] for i in idx]

    rows = []
    for i, nm in zip(idx, names):
        r, p, n = _corr_batch_aware(
            x0,
            X[:, i],
            adata=adata,
            method=method,
            batch_key=batch_key,
            batch_mode=batch_mode,
            covariates=covariates,
        )
        rows.append((nm, r, p, n))

    df = pd.DataFrame(rows, columns=["gene", "r", "pval", "n"])
    df["qval"] = _bh_fdr(df["pval"].to_numpy(dtype=float))

    r_vals = df["r"].to_numpy(dtype=float)
    keep = np.isfinite(r_vals)
    if min_abs_r is not None:
        keep &= np.abs(r_vals) >= float(min_abs_r)
    df = df[keep].copy()
    if df.empty:
        return df

    score = df["r"].abs() if use_abs else df["r"]
    order = score.sort_values(ascending=False, kind="stable").index
    df = df.loc[order].head(int(top_n)).reset_index(drop=True)

    df.insert(0, "query_gene", gene)
    df["method"] = method
    df["batch_key"] = batch_key
    df["batch_mode"] = batch_mode
    return df
# ---------------------------
# B) Gene vs numeric obs
# ---------------------------
def top_gene_obs_correlations(
    adata: ad.AnnData,
    *,
    gene: str | list[str],
    obs: str | list[str] | None = None,
    obs_keys: list[str] | None = None,
    layer: str | None = "log1p_cpm",
    method: Method = "pearson",
    top_n: int = 50,
    min_abs_r: float | None = None,
    use_abs: bool = True,
    batch_key: str | None = None,
    batch_mode: BatchMode = "none",
    covariates: list[str] | None = None,
) -> pd.DataFrame:
    """
    Correlate one or more genes with numeric obs columns.

    Parameters
    ----------
    gene
        One gene name or a list of gene names.
    obs
        Restrict the comparison to these obs columns.
    obs_keys
        Alternative explicit numeric obs selection; when None all numeric obs
        columns are candidates.
    layer
        Layer to read; falls back to ``adata.X`` when absent.
    top_n
        Maximum number of rows returned.
    min_abs_r
        If given, rows with ``|r|`` below this value are dropped.
    use_abs
        Rank by ``|r|`` when true, by signed ``r`` (most positive first) otherwise.
    batch_key, batch_mode, covariates
        Forwarded to the batch-aware correlation helper.

    Returns
    -------
    Table with columns (gene, obs, r, pval, n, qval, method, batch_key,
    batch_mode). q-values are Benjamini–Hochberg adjusted across all tested
    gene/obs pairs, before any ``min_abs_r`` / ``top_n`` filtering. An empty
    result is returned without the trailing method/batch columns.

    Raises
    ------
    KeyError
        If a gene is missing from ``adata.var_names`` or an obs key is missing
        or not numeric.
    """
    genes = _as_list(gene)
    for g in genes:
        if g not in adata.var_names:
            raise KeyError(f"Gene '{g}' not in adata.var_names")

    obs_df = _numeric_obs_df(adata, obs_keys=obs_keys)
    if obs is not None:
        obs_list = _as_list(obs)
        missing = [k for k in obs_list if k not in obs_df.columns]
        if missing:
            raise KeyError(f"Obs keys not found or not numeric: {missing}")
        obs_df = obs_df[obs_list].copy()

    # Densify the expression matrix once instead of once per gene.
    X = _get_layer_matrix(adata, layer)
    obs_vectors = {col: obs_df[col].to_numpy(dtype=float) for col in obs_df.columns}

    rows = []
    for g in genes:
        xg = np.asarray(X[:, adata.var_names.get_loc(g)], dtype=float).ravel()
        for col, yv in obs_vectors.items():
            r, p, n = _corr_batch_aware(
                xg,
                yv,
                adata=adata,
                method=method,
                batch_key=batch_key,
                batch_mode=batch_mode,
                covariates=covariates,
            )
            rows.append((g, col, r, p, n))

    df = pd.DataFrame(rows, columns=["gene", "obs", "r", "pval", "n"])
    # Correct over every hypothesis tested, not only those passing the r filter.
    df["qval"] = _bh_fdr(df["pval"].to_numpy(dtype=float))

    r_vals = df["r"].to_numpy(dtype=float)
    keep = np.isfinite(r_vals)
    if min_abs_r is not None:
        keep &= np.abs(r_vals) >= float(min_abs_r)
    df = df[keep].copy()
    if df.empty:
        return df

    score = df["r"].abs() if use_abs else df["r"]
    order = score.sort_values(ascending=False, kind="stable").index
    df = df.loc[order].head(int(top_n)).reset_index(drop=True)

    df["method"] = method
    df["batch_key"] = batch_key
    df["batch_mode"] = batch_mode
    return df
# ---------------------------
# C) Numeric obs × numeric obs
# ---------------------------
def top_obs_obs_correlations(
    adata: ad.AnnData,
    *,
    focus: str | list[str],
    against: list[str] | None = None,
    obs_keys: list[str] | None = None,
    method: Method = "pearson",
    top_n: int = 100,
    min_abs_r: float | None = None,
    use_abs: bool = True,
    batch_key: str | None = None,
    batch_mode: BatchMode = "none",
    covariates: list[str] | None = None,
) -> pd.DataFrame:
    """
    Strongest correlations between numeric obs columns.

    Parameters
    ----------
    focus
        One obs column or a list of obs columns to correlate from.
    against
        Obs columns to correlate with; when None, every other numeric obs
        column not listed in ``focus``.
    obs_keys
        Explicit numeric obs selection; when None all numeric obs columns are
        candidates.
    top_n
        Maximum number of rows returned.
    min_abs_r
        If given, rows with ``|r|`` below this value are dropped.
    use_abs
        Rank by ``|r|`` when true, by signed ``r`` (most positive first) otherwise.
    batch_key, batch_mode, covariates
        Forwarded to the batch-aware correlation helper.

    Returns
    -------
    Long table with columns (focus, against, r, pval, n, qval, method,
    batch_key, batch_mode). q-values are Benjamini–Hochberg adjusted across
    all tested pairs, before any ``min_abs_r`` / ``top_n`` filtering. An empty
    result is returned without the trailing method/batch columns.

    Raises
    ------
    KeyError
        If a ``focus`` or ``against`` key is missing or not numeric.
    """
    obs_df = _numeric_obs_df(adata, obs_keys=obs_keys)

    focus_l = _as_list(focus)
    for k in focus_l:
        if k not in obs_df.columns:
            raise KeyError(f"focus '{k}' not numeric / not found in obs")

    if against is None:
        against_l = [c for c in obs_df.columns if c not in focus_l]
    else:
        against_l = _as_list(against)
        missing = [k for k in against_l if k not in obs_df.columns]
        if missing:
            raise KeyError(f"against keys not numeric / not found: {missing}")

    rows = []
    for f in focus_l:
        xv = obs_df[f].to_numpy(dtype=float)
        for a in against_l:
            yv = obs_df[a].to_numpy(dtype=float)
            r, p, n = _corr_batch_aware(
                xv,
                yv,
                adata=adata,
                method=method,
                batch_key=batch_key,
                batch_mode=batch_mode,
                covariates=covariates,
            )
            rows.append((f, a, r, p, n))

    df = pd.DataFrame(rows, columns=["focus", "against", "r", "pval", "n"])
    # Correct over every hypothesis tested, not only those passing the r filter.
    df["qval"] = _bh_fdr(df["pval"].to_numpy(dtype=float))

    r_vals = df["r"].to_numpy(dtype=float)
    keep = np.isfinite(r_vals)
    if min_abs_r is not None:
        keep &= np.abs(r_vals) >= float(min_abs_r)
    df = df[keep].copy()
    if df.empty:
        return df

    score = df["r"].abs() if use_abs else df["r"]
    order = score.sort_values(ascending=False, kind="stable").index
    df = df.loc[order].head(int(top_n)).reset_index(drop=True)

    df["method"] = method
    df["batch_key"] = batch_key
    df["batch_mode"] = batch_mode
    return df
def obs_obs_corr_matrix(
    adata: ad.AnnData,
    *,
    focus: str | list[str],
    against: list[str] | None = None,
    obs_keys: list[str] | None = None,
    method: Method = "pearson",
    drop_na: bool = True,
    batch_key: str | None = None,
    batch_mode: BatchMode = "none",
    covariates: list[str] | None = None,
    return_p: bool = False,
    return_q: bool = False,
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame] | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Correlation matrix between numeric obs columns (focus × against).

    Returns a wide matrix ``R`` (rows=focus, cols=against). Depending on
    ``return_p`` / ``return_q`` the result is ``R``, ``(R, P)``, ``(R, Q)`` or
    ``(R, P, Q)``, where ``P`` holds p-values and ``Q`` the Benjamini–Hochberg
    q-values computed across all matrix entries.

    When ``against`` is None, every numeric obs column not in ``focus`` is
    used. ``drop_na`` is accepted but not read by this function.

    Raises
    ------
    KeyError
        If a requested key is not among the numeric obs columns.
    """
    numeric = _numeric_obs_df(adata, obs_keys=obs_keys)
    row_keys = _as_list(focus)
    if against is not None:
        col_keys = _as_list(against)
    else:
        col_keys = [c for c in numeric.columns if c not in row_keys]

    for k in row_keys + col_keys:
        if k not in numeric.columns:
            raise KeyError(f"'{k}' not found among numeric obs")

    # compute pairwise into plain arrays, then wrap with labels
    shape = (len(row_keys), len(col_keys))
    r_arr = np.full(shape, np.nan, dtype=float)
    p_arr = np.full(shape, np.nan, dtype=float)
    for i, f in enumerate(row_keys):
        xv = numeric[f].to_numpy(dtype=float)
        for j, a in enumerate(col_keys):
            yv = numeric[a].to_numpy(dtype=float)
            r, p, _ = _corr_batch_aware(
                xv,
                yv,
                adata=adata,
                method=method,
                batch_key=batch_key,
                batch_mode=batch_mode,
                covariates=covariates,
            )
            r_arr[i, j] = r
            p_arr[i, j] = p

    R = pd.DataFrame(r_arr, index=row_keys, columns=col_keys)
    P = pd.DataFrame(p_arr, index=row_keys, columns=col_keys)

    if not (return_p or return_q):
        return R

    # q-values across all tested entries
    qvals = _bh_fdr(P.to_numpy().ravel())
    Q = pd.DataFrame(qvals.reshape(P.shape), index=P.index, columns=P.columns)

    if return_p and return_q:
        return R, P, Q
    return (R, P) if return_p else (R, Q)
# ---------------------------
# Partial correlation convenience
# ---------------------------
def partial_corr(
    adata: ad.AnnData,
    *,
    x: np.ndarray,
    y: np.ndarray,
    covariates: list[str],
    method: Method = "pearson",
) -> tuple[float, float, int]:
    """
    Partial correlation of ``x`` and ``y`` given obs covariates.

    Both vectors are residualized on a design matrix (intercept plus the
    ``covariates`` columns of ``adata.obs``) and the residuals are correlated
    with ``method``. Returns ``(r, p, n)`` from the correlation helper.
    """
    design = _design_matrix_from_obs(adata, covariates=covariates, add_intercept=True)
    x_resid, y_resid = _residualize(x, design), _residualize(y, design)
    return _corr(x_resid, y_resid, method=method)
# ---------------------------
# Plotting wrappers
# ---------------------------
def plot_corr_scatter(
    adata: ad.AnnData,
    *,
    x: str,
    y: str,
    hue: str | None = None,
    kind: Literal["obs_obs", "gene_obs"] = "obs_obs",
    layer: str | None = "log1p_cpm",
    method: Method = "pearson",
    batch_key: str | None = None,
    batch_mode: BatchMode = "none",
    covariates: list[str] | None = None,
    figsize: tuple[float, float] = (5.5, 4.5),
    s: float = 18.0,
    alpha: float = 0.8,
    cmap: str = "viridis",
    palette: str = "tab20",
    title: str | None = None,
    show: bool = True,
):
    """
    Scatter plot of two variables with their correlation in the title.

    kind="obs_obs": x and y are obs columns (coerced to numeric).
    kind="gene_obs": one of x / y is a gene and the other an obs column.

    The points drawn are the raw values; the batch-aware options only affect
    the correlation statistics shown in the default title. A numeric ``hue``
    is drawn with ``cmap`` and a colorbar; any other ``hue`` is drawn per
    category with ``palette`` and a legend.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    KeyError
        If x / y cannot be resolved for the given ``kind``, or ``hue`` is not
        an obs column.
    ValueError
        If ``kind`` is unknown.
    """

    def _obs_numeric(key: str) -> np.ndarray:
        # Coerce an obs column to a float vector; unparseable entries become NaN.
        return pd.to_numeric(adata.obs[key], errors="coerce").to_numpy(dtype=float)

    if kind == "obs_obs":
        if x not in adata.obs.columns or y not in adata.obs.columns:
            raise KeyError("x and y must be obs columns for kind='obs_obs'")
        xv = _obs_numeric(x)
        yv = _obs_numeric(y)
    elif kind == "gene_obs":
        # allow x=gene, y=obs or vice versa
        if x in adata.var_names and y in adata.obs.columns:
            xv = _get_gene_vector(adata, x, layer=layer)
            yv = _obs_numeric(y)
        elif y in adata.var_names and x in adata.obs.columns:
            xv = _obs_numeric(x)
            yv = _get_gene_vector(adata, y, layer=layer)
        else:
            raise KeyError("For kind='gene_obs', one of (x,y) must be a gene and the other a numeric obs.")
    else:
        raise ValueError(f"Unknown kind={kind}")

    r, p, n = _corr_batch_aware(
        xv,
        yv,
        adata=adata,
        method=method,
        batch_key=batch_key,
        batch_mode=batch_mode,
        covariates=covariates,
    )

    # Validate hue before creating the figure so a bad key does not leave an
    # orphan open figure behind.
    if hue is not None and hue not in adata.obs.columns:
        raise KeyError(f"hue='{hue}' not in adata.obs")

    fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)

    if hue is None:
        ax.scatter(xv, yv, s=s, alpha=alpha, edgecolors="none")
    else:
        h = adata.obs[hue]
        if pd.api.types.is_numeric_dtype(h):
            vals = pd.to_numeric(h, errors="coerce").to_numpy(dtype=float)
            sc = ax.scatter(xv, yv, c=vals, s=s, alpha=alpha, cmap=cmap, edgecolors="none")
            cb = fig.colorbar(sc, ax=ax, pad=0.01)
            cb.set_label(hue)
        else:
            cat = pd.Categorical(h.astype(str))
            cats = list(cat.categories)
            if sns is not None:
                pal = sns.color_palette(palette, n_colors=len(cats))
                colors = {c: pal[i] for i, c in enumerate(cats)}
            else:
                # matplotlib fallback: cycle through the colormap's entries
                cm = plt.get_cmap(palette)
                colors = {c: cm(i % cm.N) for i, c in enumerate(cats)}
            for c in cats:
                m = (cat == c)
                ax.scatter(xv[m], yv[m], s=s, alpha=alpha, edgecolors="none", label=str(c), color=colors[c])
            ax.legend(title=hue, bbox_to_anchor=(1.02, 1), loc="upper left", frameon=False)

    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ttl = title or f"{method} r={r:.3f}, p={p:.2e}, n={n}"
    ax.set_title(ttl)

    if show:
        plt.show()
    return fig, ax
def plot_corr_heatmap(
    mat: pd.DataFrame,
    *,
    figsize: tuple[float, float] = (7.5, 4.5),
    cmap: str = "vlag",
    center: float = 0.0,
    vmin: float | None = None,
    vmax: float | None = None,
    annot: bool = False,
    fmt: str = ".2f",
    title: str | None = None,
    show: bool = True,
):
    """
    Heatmap for a wide correlation matrix (e.g., from obs_obs_corr_matrix).

    With seaborn available the matrix is drawn with ``sns.heatmap`` and all
    styling arguments apply. Without seaborn a plain ``imshow`` fallback is
    used: ``vmin`` / ``vmax`` are honored, while ``cmap``, ``center``,
    ``annot`` and ``fmt`` are not applied (matplotlib's default colormap is
    used).

    Returns
    -------
    (fig, ax)
    """
    fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)

    if sns is not None:
        sns.heatmap(mat, ax=ax, cmap=cmap, center=center, vmin=vmin, vmax=vmax, annot=annot, fmt=fmt)
    else:
        # Forward the colour limits so fallback plots share a scale with each other.
        im = ax.imshow(mat.to_numpy(dtype=float), aspect="auto", vmin=vmin, vmax=vmax)
        fig.colorbar(im, ax=ax)
        ax.set_xticks(np.arange(mat.shape[1]))
        ax.set_yticks(np.arange(mat.shape[0]))
        ax.set_xticklabels(mat.columns, rotation=90)
        ax.set_yticklabels(mat.index)

    if title is not None:
        ax.set_title(title)
    if show:
        plt.show()
    return fig, ax