# Source code for bullkpy.pp.batch

from __future__ import annotations

from typing import Sequence

import numpy as np
import pandas as pd
import scipy.sparse as sp
import anndata as ad

from ..logging import info, warn


# -------------------------
# Helpers: ComBat EB priors
# -------------------------
def _aprior(delta_hat: np.ndarray) -> float:
    """
    Moment-based ``a`` hyperparameter used as ``a_prior`` in the ComBat shrinkage.

    With ``m`` the mean and ``s2`` the sample variance (``ddof=1``) of
    ``delta_hat``, the returned value is ``(2 * s2 + m**2) / (s2 + 1e-12)``.
    The small constant keeps the division finite when all entries are equal.
    """
    sample_var = np.var(delta_hat, ddof=1)
    sample_mean = np.mean(delta_hat)
    numerator = 2.0 * sample_var + sample_mean**2
    denominator = sample_var + 1e-12
    return numerator / denominator


def _bprior(delta_hat: np.ndarray) -> float:
    """
    Moment-based ``b`` hyperparameter used as ``b_prior`` in the ComBat shrinkage.

    With ``m`` the mean and ``s2`` the sample variance (``ddof=1``) of
    ``delta_hat``, the returned value is ``(m * s2 + m**3) / (s2 + 1e-12)``.
    The small constant keeps the division finite when all entries are equal.
    """
    sample_var = np.var(delta_hat, ddof=1)
    sample_mean = np.mean(delta_hat)
    numerator = sample_mean * sample_var + sample_mean**3
    denominator = sample_var + 1e-12
    return numerator / denominator


def _it_sol(
    sdat: np.ndarray,
    gamma_hat: np.ndarray,
    delta_hat: np.ndarray,
    gamma_bar: float,
    t2: float,
    a_prior: float,
    b_prior: float,
    *,
    conv: float = 1e-4,
    max_iter: int = 100,
) -> tuple[float, float]:
    """
    Alternate the parametric ComBat updates for gamma* and delta* until stable.

    Parameters
    ----------
    sdat
        Standardized values the estimates are fitted to; ``sdat.size`` is used
        as the observation count.
    gamma_hat, delta_hat
        Starting (unshrunk) location and scale estimates.
    gamma_bar, t2
        Prior mean and variance for the location parameter.
    a_prior, b_prior
        Prior hyperparameters for the scale parameter.
    conv
        Threshold on the relative change of both estimates.
    max_iter
        Upper bound on the number of update rounds.

    Returns
    -------
    (gamma_star, delta_star) as Python floats. If ``max_iter`` is exhausted the
    most recent estimates are returned (the inputs themselves when
    ``max_iter == 0``).
    """
    eps = 1e-12  # guards every division against a zero denominator
    n_obs = sdat.size
    gamma_prev, delta_prev = gamma_hat, delta_hat

    for _ in range(max_iter):
        # Location update: precision-weighted blend of the estimate and the prior mean.
        gamma_next = (t2 * n_obs * gamma_hat + delta_prev * gamma_bar) / (
            t2 * n_obs + delta_prev + eps
        )

        # Scale update given the refreshed location.
        sq_dev = np.sum((sdat - gamma_next) ** 2)
        delta_next = (0.5 * sq_dev + b_prior) / (0.5 * n_obs + a_prior - 1.0 + eps)

        gamma_change = np.max(np.abs(gamma_next - gamma_prev) / (np.abs(gamma_prev) + eps))
        delta_change = np.max(np.abs(delta_next - delta_prev) / (np.abs(delta_prev) + eps))
        if gamma_change < conv and delta_change < conv:
            return float(gamma_next), float(delta_next)

        gamma_prev, delta_prev = gamma_next, delta_next

    return float(gamma_prev), float(delta_prev)


# -------------------------
# Main: ComBat for bulk
# -------------------------
def batch_correct_combat(
    adata: ad.AnnData,
    *,
    batch_key: str,
    layer: str | None = "log1p_cpm",
    covariates: Sequence[str] | None = None,
    key_added: str = "combat",
    overwrite: bool = False,
    inplace: bool = True,
) -> np.ndarray | None:
    """
    ComBat batch correction (Johnson et al.) for bulk expression.

    Notes
    -----
    - ComBat is intended for approximately Gaussian data: use log-transformed
      normalized expression (e.g., log1p_cpm), not raw counts.
    - Writes corrected matrix to `adata.layers[key_added]` by default.
    - Run parameters and the batch labels are recorded in `adata.uns["combat"]`
      when `inplace=True`.
    - Batches with a single sample receive a location-only adjustment (their
      scale estimate is fixed at 1), because a variance cannot be estimated
      from one observation.
    - With fewer than two genes the empirical Bayes priors cannot be
      estimated, so the unshrunk per-batch estimates are used.

    Parameters
    ----------
    batch_key
        adata.obs column with batch labels (categorical recommended).
    layer
        Which matrix to correct. If None, or if the name is not present in
        `adata.layers`, adata.X is used (a warning is issued in the latter case).
    covariates
        Optional adata.obs columns to include in the design (biological
        covariates to preserve). Numeric columns enter as-is; other columns
        are dummy-encoded with the first level dropped.
    key_added
        Layer name to store corrected values (if overwrite=False).
    overwrite
        If True, write corrected values back into the selected layer / X.
    inplace
        If True, store results in adata; if False, return corrected matrix
        (samples x genes).

    Returns
    -------
    None when `inplace=True`; otherwise the corrected samples x genes array
    (the uncorrected matrix if fewer than two batches are present).

    Raises
    ------
    KeyError
        If `batch_key` or any entry of `covariates` is not a column of adata.obs.
    """
    if batch_key not in adata.obs.columns:
        raise KeyError(f"batch_key='{batch_key}' not found in adata.obs")

    covariates = list(covariates) if covariates is not None else []
    for c in covariates:
        if c not in adata.obs.columns:
            raise KeyError(f"covariate '{c}' not found in adata.obs")

    # -------- data matrix --------
    if layer is not None and layer in adata.layers:
        X = adata.layers[layer]
    else:
        if layer is not None:
            # Do not fall back silently: adata.X may hold raw counts.
            warn(f"ComBat: layer='{layer}' not found in adata.layers; using adata.X instead.")
        X = adata.X
    if sp.issparse(X):
        X = X.toarray()
    X = np.asarray(X, dtype=float)  # samples x genes
    n_samples, n_genes = X.shape

    # -------- batch variable --------
    # Drop categories with no samples (common after subsetting); they would
    # otherwise be counted as batches and yield NaN estimates.
    batch = adata.obs[batch_key].astype("category").cat.remove_unused_categories()
    batches = list(batch.cat.categories)
    n_batch = len(batches)
    if n_batch < 2:
        warn(f"ComBat: batch_key='{batch_key}' has <2 batches; skipping correction.")
        if inplace:
            return None
        return X

    info(
        f"ComBat: correcting layer='{layer}' with batch_key='{batch_key}' "
        f"({n_batch} batches), covariates={covariates}"
    )

    batch_codes = batch.cat.codes.to_numpy()
    batch_indices = [np.where(batch_codes == i)[0] for i in range(n_batch)]
    batch_sizes = np.array([idx.size for idx in batch_indices], dtype=float)

    # -------- design matrix: batch indicators + covariates --------
    # No intercept column: it would be collinear with the full set of batch
    # indicators, leaving the overall mean unidentifiable (a pseudo-inverse
    # fit then shrinks it toward zero). The grand mean is instead derived
    # from the batch coefficients below.
    batch_design = (batch_codes[:, None] == np.arange(n_batch)[None, :]).astype(float)

    cov_parts = []
    for c in covariates:
        s = adata.obs[c]
        if pd.api.types.is_numeric_dtype(s):
            cov_parts.append(s.astype(float).to_numpy().reshape(-1, 1))
        else:
            dummies = pd.get_dummies(s.astype("category"), prefix=c, drop_first=True)
            cov_parts.append(dummies.to_numpy(dtype=float))
    cov_design = np.hstack(cov_parts) if cov_parts else np.zeros((n_samples, 0), dtype=float)

    design_mat = np.hstack([batch_design, cov_design])  # n x p

    # -------- standardize genes --------
    # Least-squares fit X = design * B + E; pinv tolerates collinear covariates.
    B_hat = np.linalg.pinv(design_mat) @ X  # p x genes

    # Grand mean per gene: batch-size-weighted average of the batch coefficients.
    grand_mean = (batch_sizes / batch_sizes.sum()) @ B_hat[:n_batch, :]  # (genes,)
    # Biological part to preserve: grand mean plus covariate effects.
    stand_mean = grand_mean[None, :] + cov_design @ B_hat[n_batch:, :]  # samples x genes

    # Pooled variance of residuals
    resid = X - design_mat @ B_hat
    var_pooled = np.var(resid, axis=0, ddof=1)
    var_pooled[var_pooled == 0] = 1.0
    sd_pooled = np.sqrt(var_pooled)

    sdat = (X - stand_mean) / sd_pooled  # standardized data

    # -------- estimate batch effects on standardized data --------
    gamma_hat = np.zeros((n_batch, n_genes), dtype=float)
    delta_hat = np.ones((n_batch, n_genes), dtype=float)
    singleton_batches = []
    for i, idx in enumerate(batch_indices):
        gamma_hat[i, :] = np.mean(sdat[idx, :], axis=0)
        if idx.size > 1:
            delta_hat[i, :] = np.var(sdat[idx, :], axis=0, ddof=1)
            delta_hat[i, delta_hat[i, :] == 0] = 1.0
        else:
            # Variance is undefined for one sample; keep the scale at 1.
            singleton_batches.append(batches[i])
    if singleton_batches:
        warn(
            f"ComBat: batches {singleton_batches} contain a single sample; "
            "adjusting their location only."
        )

    # -------- empirical Bayes shrinkage (parametric) --------
    if n_genes < 2:
        # Priors are estimated across genes and are undefined for a single gene.
        warn("ComBat: fewer than 2 genes; skipping empirical Bayes shrinkage.")
        gamma_star = gamma_hat.copy()
        delta_star = delta_hat.copy()
    else:
        gamma_star = np.zeros_like(gamma_hat)
        delta_star = np.zeros_like(delta_hat)
        for i, idx in enumerate(batch_indices):
            g_i = gamma_hat[i, :]
            d_i = delta_hat[i, :]

            gamma_bar = float(np.mean(g_i))
            t2 = float(np.var(g_i, ddof=1))
            if t2 == 0:
                t2 = 1.0

            a_prior = float(_aprior(d_i))
            b_prior = float(_bprior(d_i))

            for j in range(n_genes):
                g, d = _it_sol(
                    sdat[idx, j],
                    g_i[j],
                    d_i[j],
                    gamma_bar,
                    t2,
                    a_prior,
                    b_prior,
                )
                gamma_star[i, j] = g
                delta_star[i, j] = d

    # -------- adjust data --------
    bayesdata = sdat.copy()
    for i, idx in enumerate(batch_indices):
        bayesdata[idx, :] = (bayesdata[idx, :] - gamma_star[i, :]) / np.sqrt(delta_star[i, :])

    # de-standardize
    corrected = bayesdata * sd_pooled + stand_mean  # samples x genes

    # -------- store / return --------
    if not inplace:
        return corrected

    if overwrite:
        if layer is None:
            adata.X = corrected
            target = "adata.X"
        else:
            if layer not in adata.layers:
                warn(f"overwrite=True but layer='{layer}' not in adata.layers; writing to adata.layers['{layer}'].")
            adata.layers[layer] = corrected
            target = f"adata.layers[{layer!r}]"
    else:
        adata.layers[key_added] = corrected
        target = f"adata.layers[{key_added!r}]"

    adata.uns["combat"] = {
        "params": {
            "batch_key": batch_key,
            "layer": layer,
            "covariates": covariates,
            "key_added": key_added,
            "overwrite": overwrite,
        },
        "batches": batches,
    }
    info(f"ComBat: stored corrected matrix in {target}")
    return None