"""Core analysis tools for spatial perturbation inference."""
from __future__ import annotations
from collections.abc import Sequence
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from anndata import AnnData
from scipy import sparse
from scipy.stats import mannwhitneyu, pearsonr, spearmanr, ttest_ind
import statsmodels.api as sm
from ._utils import benjamini_hochberg, extract_matrix, safe_log2_fold_change
from .gr import collect_neighbors
from .schema import ensure_spatialperturb_schema
from .signatures import compare_program_concordance, derive_perturbation_programs
# Small built-in ligand-receptor table returned by ``_load_lr_network`` when no
# network (or "fallback") is requested. Every row is tagged resource="fallback".
# NOTE(review): LIG1/REC1 look like placeholder symbols (perhaps for toy/test
# data) rather than real gene names — confirm they are intended here.
_FALLBACK_LR_NETWORK = pd.DataFrame(
    {
        "ligand": ["CXCL10", "GAS6", "NRG1", "TGFB1", "LIG1"],
        "receptor": ["CXCR3", "AXL", "ERBB4", "TGFBR1", "REC1"],
        "resource": "fallback",
    }
)
def _default_feature_names(adata: AnnData, features: Sequence[str] | None) -> list[str]:
    """Resolve which gene names an analysis should use.

    With explicit ``features``, keep (as strings, in the given order) only the
    names present in ``adata.var_names``. Without them, return every
    ``var_name`` except the barcode columns listed under
    ``adata.uns["spatialperturb"]["barcode_columns"]``.
    """
    if features is None:
        barcodes = set(adata.uns.get("spatialperturb", {}).get("barcode_columns", []))
        return [name for name in map(str, adata.var_names) if name not in barcodes]
    known = adata.var_names
    return [name for name in map(str, features) if name in known]
def _restrict_mask(
    adata: AnnData,
    *,
    cell_type: str | Sequence[str] | None = None,
    roi: str | Sequence[str] | None = None,
) -> np.ndarray:
    """Boolean mask over ``adata.obs`` selecting cells in the requested context.

    ``cell_type`` and ``roi`` each accept one label or a collection of labels
    and are matched, as strings, against the ``"cell_type"`` and ``"roi"`` obs
    columns. A filter left as ``None`` is not applied, so with no filters every
    cell is selected.
    """
    keep = np.ones(adata.n_obs, dtype=bool)
    for column, wanted in (("cell_type", cell_type), ("roi", roi)):
        if wanted is None:
            continue
        labels = {wanted} if isinstance(wanted, str) else {str(item) for item in wanted}
        keep &= adata.obs[column].astype(str).isin(labels).to_numpy()
    return keep
def _group_mask(
    adata: AnnData,
    *,
    perturbation: str,
    groupby: str,
    status_col: str,
) -> np.ndarray:
    """Boolean mask of single-perturbation cells carrying a given label.

    A cell is selected when ``obs[groupby]`` equals ``perturbation`` and
    ``obs[status_col]`` equals ``"single"`` (both compared as strings). If an
    ``include_for_inference`` column exists, it is cast with ``astype(bool)``
    and cells where it is False are excluded as well.
    """
    obs = adata.obs
    if "include_for_inference" in obs:
        eligible = obs["include_for_inference"].astype(bool).to_numpy()
    else:
        eligible = np.ones(adata.n_obs, dtype=bool)
    has_label = obs[groupby].astype(str) == str(perturbation)
    is_single = obs[status_col].astype(str) == "single"
    return (has_label & is_single & eligible).to_numpy()
def _mode_or_first(series: pd.Series) -> object:
    """Most frequent non-missing value of ``series``.

    When several values tie, the first one reported by ``Series.mode`` wins.
    Without any non-missing value, fall back to the first element, or to
    ``NaN`` when the series is empty.
    """
    modes = series.mode(dropna=True)
    if len(modes):
        return modes.iloc[0]
    if series.empty:
        return np.nan
    return series.iloc[0]
def _build_de_output(
    *,
    perturbation: str,
    control: str,
    effect_type: str,
    features: list[str],
    case_mean: np.ndarray,
    control_mean: np.ndarray,
    statistics: np.ndarray,
    pvalues: np.ndarray,
    case_n: int,
    control_n: int,
    method: str,
    case_sample_n: int | None = None,
    control_sample_n: int | None = None,
) -> pd.DataFrame:
    """Assemble and rank a differential-expression result table.

    Scalar arguments are broadcast to every row; array arguments are aligned
    with ``features``. ``fdr`` comes from ``benjamini_hochberg(pvalues)`` and
    ``log2fc`` from ``safe_log2_fold_change(case_mean, control_mean)``. Rows are
    ordered by ``fdr`` and ``pvalue`` ascending, then ``log2fc`` descending,
    and re-indexed from 0.
    """
    adjusted = benjamini_hochberg(pvalues)
    fold_change = safe_log2_fold_change(case_mean, control_mean)
    columns = {
        "perturbation": perturbation,
        "control": control,
        "effect_type": effect_type,
        "method": method,
        "gene": features,
        "case_n": case_n,
        "control_n": control_n,
        "case_sample_n": case_sample_n,
        "control_sample_n": control_sample_n,
        "mean_case": case_mean,
        "mean_control": control_mean,
        "log2fc": fold_change,
        "statistic": statistics,
        "pvalue": pvalues,
        "fdr": adjusted,
    }
    ranked = pd.DataFrame(columns).sort_values(
        by=["fdr", "pvalue", "log2fc"],
        ascending=[True, True, False],
    )
    return ranked.reset_index(drop=True)
def _run_simple_de_matrix(
    *,
    case_matrix: np.ndarray,
    control_matrix: np.ndarray,
    features: list[str],
    perturbation: str,
    control: str,
    effect_type: str,
    effect_size_only: bool = False,
) -> pd.DataFrame:
    """Cell-level differential expression with a two-sided Mann-Whitney U test.

    Column ``j`` of both matrices is treated as ``features[j]``. A gene whose
    case and control columns have the same shape and are element-wise
    ``allclose`` is not tested and keeps statistic 0 / p-value 1. With
    ``effect_size_only`` no test is run and only the group means are informative.
    """
    mean_in_case = case_matrix.mean(axis=0)
    mean_in_control = control_matrix.mean(axis=0)
    n_genes = len(features)
    u_stats = np.zeros(n_genes, dtype=float)
    p_vals = np.ones(n_genes, dtype=float)
    if not effect_size_only:
        for gene_idx, _gene in enumerate(features):
            case_col = case_matrix[:, gene_idx]
            control_col = control_matrix[:, gene_idx]
            # The shape comparison short-circuits before allclose can fail to broadcast.
            identical = case_col.shape == control_col.shape and np.allclose(case_col, control_col)
            if identical:
                continue
            outcome = mannwhitneyu(case_col, control_col, alternative="two-sided")
            u_stats[gene_idx] = float(outcome.statistic)
            p_vals[gene_idx] = float(outcome.pvalue)
    return _build_de_output(
        perturbation=perturbation,
        control=control,
        effect_type=effect_type,
        features=features,
        case_mean=mean_in_case,
        control_mean=mean_in_control,
        statistics=u_stats,
        pvalues=p_vals,
        case_n=case_matrix.shape[0],
        control_n=control_matrix.shape[0],
        method="simple",
    )
def _mean_expression_by_mask(
    adata: AnnData,
    mask: np.ndarray,
    *,
    features: Sequence[str],
    layer: str | None = None,
) -> np.ndarray:
    """Per-gene mean expression over the cells selected by ``mask``.

    Reads ``adata.layers[layer]`` when ``layer`` is given, otherwise ``adata.X``,
    restricted to ``features``. Sparse and dense matrices are both supported.
    The result is a 1-D array aligned with ``features``.
    """
    subset = adata[mask, list(features)]
    values = subset.X if layer is None else subset.layers[layer]
    if not sparse.issparse(values):
        return np.asarray(values, dtype=float).mean(axis=0)
    # Sparse ``.mean(axis=0)`` yields a 1 x n matrix; flatten it to 1-D.
    return np.asarray(values.mean(axis=0)).ravel()
def _aggregate_pseudobulk(
    matrix: np.ndarray,
    obs: pd.DataFrame,
    *,
    sample_col: str,
    features: list[str],
    label: str,
    covariates: Sequence[str] | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Sum expression rows per sample and collapse their metadata.

    Returns ``(counts, meta)``, both indexed by ``"<label>:<sample>"``, with one
    row per distinct ``obs[sample_col]`` value (as string):

    * ``counts`` holds the column-wise sums of ``matrix`` (columns = ``features``);
    * ``meta`` holds ``sample_col`` and any present ``covariates``, each reduced
      per sample with ``_mode_or_first``, plus ``sample`` and ``label`` columns.
    """
    sample_keys = obs[sample_col].astype(str).rename("sample")
    summed = pd.DataFrame(matrix, index=obs.index, columns=features).groupby(sample_keys).sum()
    row_labels = pd.Index([f"{label}:{name}" for name in summed.index])
    summed.index = row_labels

    requested = [sample_col]
    requested.extend(covariates or [])
    present = [column for column in requested if column in obs.columns]
    cell_meta = obs.loc[:, present].copy()
    cell_meta["_sample"] = sample_keys.values
    # Both groupbys key on the same string values, so their sorted group
    # order matches and the rows line up with ``summed``.
    per_sample = cell_meta.groupby("_sample").agg(_mode_or_first)
    per_sample.index = row_labels
    per_sample["sample"] = [row.split(":", 1)[1] for row in row_labels]
    per_sample["label"] = label
    return summed, per_sample
def _run_pseudobulk_de_matrix(
    *,
    case_matrix: np.ndarray,
    control_matrix: np.ndarray,
    case_obs: pd.DataFrame,
    control_obs: pd.DataFrame,
    features: list[str],
    perturbation: str,
    control: str,
    effect_type: str,
    sample_col: str,
    covariates: Sequence[str] | None = None,
    effect_size_only: bool = False,
) -> pd.DataFrame:
    """Sample-level (pseudobulk) differential expression.

    Procedure:

    1. Rows of each matrix are summed per ``sample_col`` value.
    2. The combined sample profiles are scaled to counts-per-million over
       ``features`` and ``log1p``-transformed.
    3. Each gene is fit with OLS on an intercept, a ``condition`` indicator
       (1 = case sample, 0 = control sample) and ``covariates`` (encoded with
       ``pd.get_dummies(drop_first=True)``).
    4. The ``condition`` t-statistic and p-value are reported. If the fit
       raises, a Welch t-test on the same normalised values is used instead.

    ``mean_case``/``mean_control`` (and therefore ``log2fc``) are per-row means
    of the input matrices, the same definition the simple method uses.
    Averaging the per-sample *sums* instead would scale each group mean by its
    number of rows per sample. That would shift ``log2fc`` by ``log2`` of the
    two groups' rows-per-sample ratio for every gene, regardless of expression.
    """
    case_counts, case_meta = _aggregate_pseudobulk(
        case_matrix,
        case_obs,
        sample_col=sample_col,
        features=features,
        label="case",
        covariates=covariates,
    )
    control_counts, control_meta = _aggregate_pseudobulk(
        control_matrix,
        control_obs,
        sample_col=sample_col,
        features=features,
        label="control",
        covariates=covariates,
    )
    combined_counts = pd.concat([case_counts, control_counts], axis=0)
    # Empty libraries are given size 1 so the CPM scaling never divides by zero.
    lib_size = combined_counts.sum(axis=1).replace(0, 1.0)
    normalized = np.log1p(combined_counts.div(lib_size, axis=0) * 1_000_000.0)

    design_meta = pd.concat([case_meta, control_meta], axis=0)
    design_meta["condition"] = [1] * len(case_meta) + [0] * len(control_meta)
    design = pd.DataFrame({"condition": design_meta["condition"].astype(float)}, index=normalized.index)
    if covariates:
        encoded = pd.get_dummies(design_meta.loc[:, list(covariates)], drop_first=True, dtype=float)
        design = pd.concat([design, encoded], axis=1)
    design = sm.add_constant(design, has_constant="add")

    statistics = np.zeros(len(features), dtype=float)
    pvalues = np.ones(len(features), dtype=float)
    if not effect_size_only:
        for idx, gene in enumerate(features):
            y = normalized[gene].astype(float)
            try:
                fit = sm.OLS(y, design).fit()
                statistics[idx] = float(fit.tvalues.get("condition", np.nan))
                pvalues[idx] = float(fit.pvalues.get("condition", np.nan))
            except Exception:
                # Deliberate best-effort: any OLS failure falls back to a
                # Welch t-test on the same normalised sample values.
                fallback = ttest_ind(
                    normalized.loc[case_counts.index, gene],
                    normalized.loc[control_counts.index, gene],
                    equal_var=False,
                    nan_policy="omit",
                )
                statistics[idx] = float(fallback.statistic)
                pvalues[idx] = float(fallback.pvalue)

    # Per-row means, independent of how many rows each sample contributes.
    case_mean = np.asarray(case_matrix, dtype=float).mean(axis=0)
    control_mean = np.asarray(control_matrix, dtype=float).mean(axis=0)
    return _build_de_output(
        perturbation=perturbation,
        control=control,
        effect_type=effect_type,
        features=features,
        case_mean=case_mean,
        control_mean=control_mean,
        statistics=statistics,
        pvalues=pvalues,
        case_n=case_matrix.shape[0],
        control_n=control_matrix.shape[0],
        case_sample_n=len(case_counts),
        control_sample_n=len(control_counts),
        method="pseudobulk",
    )
def _run_de(
    adata: AnnData,
    *,
    case_cells: Sequence[str] | None = None,
    control_cells: Sequence[str] | None = None,
    case_matrix: np.ndarray | None = None,
    control_matrix: np.ndarray | None = None,
    case_obs: pd.DataFrame | None = None,
    control_obs: pd.DataFrame | None = None,
    perturbation: str,
    control: str,
    effect_type: str,
    features: Sequence[str] | None = None,
    layer: str | None = None,
    method: str = "simple",
    sample_col: str | None = None,
    covariates: Sequence[str] | None = None,
    effect_size_only: bool = False,
) -> pd.DataFrame:
    """Dispatch a case-vs-control comparison to the simple or pseudobulk engine.

    Inputs are either cell names (``case_cells``/``control_cells``), from which
    matrices and obs are extracted, or precomputed matrices together with their
    ``case_obs``/``control_obs``. Explicit matrices are used only when both are
    given.

    Raises ``ValueError`` when no features resolve, when neither input form is
    complete, when ``sample_col`` is missing for pseudobulk, or for an unknown
    ``method``.
    """
    feature_names = _default_feature_names(adata, features)
    if not feature_names:
        raise ValueError("No features available for differential expression.")

    if case_matrix is not None and control_matrix is not None:
        if case_obs is None or control_obs is None:
            raise ValueError("case_obs and control_obs are required when explicit matrices are provided.")
        case_obs, control_obs = case_obs.copy(), control_obs.copy()
    elif case_cells is None or control_cells is None:
        raise ValueError("Either cells or explicit matrices must be provided.")
    else:
        case_matrix = extract_matrix(adata, layer=layer, obs_names=case_cells, var_names=feature_names)
        control_matrix = extract_matrix(adata, layer=layer, obs_names=control_cells, var_names=feature_names)
        case_obs = adata.obs.loc[list(case_cells)].copy()
        control_obs = adata.obs.loc[list(control_cells)].copy()

    common = {
        "case_matrix": case_matrix,
        "control_matrix": control_matrix,
        "features": feature_names,
        "perturbation": perturbation,
        "control": control,
        "effect_type": effect_type,
        "effect_size_only": effect_size_only,
    }
    if method == "simple":
        return _run_simple_de_matrix(**common)
    if method != "pseudobulk":
        raise ValueError("method must be 'simple' or 'pseudobulk'.")
    if sample_col is None:
        raise ValueError("sample_col is required when method='pseudobulk'.")
    return _run_pseudobulk_de_matrix(
        **common,
        case_obs=case_obs,
        control_obs=control_obs,
        sample_col=sample_col,
        covariates=covariates,
    )
[docs]
def intrinsic_de(
    adata: AnnData,
    *,
    perturbation: str,
    control: str,
    groupby: str = "perturbation",
    status_col: str = "perturbation_status",
    features: Sequence[str] | None = None,
    layer: str | None = None,
    cell_type: str | Sequence[str] | None = None,
    roi: str | Sequence[str] | None = None,
    min_cells_per_group: int = 2,
    min_samples_per_group: int = 2,
    method: str = "simple",
    sample_col: str | None = None,
    covariates: Sequence[str] | None = None,
    effect_size_only: bool = False,
    min_cells: int | None = None,
) -> pd.DataFrame:
    """Compare perturbed cells against control cells in the same ROI/cell type context.

    Case cells have ``obs[groupby] == perturbation`` and
    ``obs[status_col] == "single"`` (and ``include_for_inference`` when that
    column exists). Control cells are selected the same way using ``control``.
    Both groups are restricted to ``cell_type``/``roi`` when given.

    Parameters
    ----------
    adata
        Annotated data; validated with ``ensure_spatialperturb_schema``.
    perturbation, control
        Labels in ``obs[groupby]`` for the case and control groups.
    features, layer
        Genes to compare (default: all non-barcode ``var_names``) and the layer
        to read (default: ``X``).
    min_cells_per_group, min_samples_per_group
        Minimum number of cells, and for pseudobulk of distinct ``sample_col``
        values, required in each group. ``min_cells`` is an alias that
        overrides ``min_cells_per_group`` when given.
    method, sample_col, covariates
        ``"simple"`` (cell-level Mann-Whitney U) or ``"pseudobulk"``
        (per-sample OLS; requires ``sample_col``; ``covariates`` join the design).
    effect_size_only
        Skip statistical testing: statistics are 0 and p-values 1.

    Returns
    -------
    pandas.DataFrame
        One row per gene with group means, ``log2fc``, statistic, p-value and
        FDR, ordered by FDR, p-value and descending ``log2fc``.

    Raises
    ------
    ValueError
        If a group has too few cells (or samples for pseudobulk), ``sample_col``
        is missing for pseudobulk, no features are available, or ``method`` is
        unknown.
    """
    ensure_spatialperturb_schema(adata)
    if min_cells is not None:
        min_cells_per_group = int(min_cells)
    base_mask = _restrict_mask(adata, cell_type=cell_type, roi=roi)
    case_mask = _group_mask(adata, perturbation=perturbation, groupby=groupby, status_col=status_col) & base_mask
    control_mask = _group_mask(adata, perturbation=control, groupby=groupby, status_col=status_col) & base_mask
    n_case = int(case_mask.sum())
    n_control = int(control_mask.sum())
    if n_case < min_cells_per_group or n_control < min_cells_per_group:
        raise ValueError(
            "Need at least "
            f"{min_cells_per_group} cells in both case and control groups, found "
            f"{n_case} and {n_control}."
        )
    if method == "pseudobulk":
        if sample_col is None:
            raise ValueError("sample_col is required when method='pseudobulk'.")
        case_sample_n = adata.obs.loc[case_mask, sample_col].astype(str).nunique()
        control_sample_n = adata.obs.loc[control_mask, sample_col].astype(str).nunique()
        if case_sample_n < min_samples_per_group or control_sample_n < min_samples_per_group:
            raise ValueError(
                "Need at least "
                f"{min_samples_per_group} samples in both case and control groups for pseudobulk, found "
                f"{case_sample_n} and {control_sample_n}."
            )
    if method == "simple" and effect_size_only:
        # Fast path: group means straight from the (possibly sparse) matrix,
        # without densifying per-cell data for a test that will not run.
        resolved_features = _default_feature_names(adata, features)
        if not resolved_features:
            # Same error the tested path raises, instead of an empty table.
            raise ValueError("No features available for differential expression.")
        case_mean = _mean_expression_by_mask(adata, case_mask, features=resolved_features, layer=layer)
        control_mean = _mean_expression_by_mask(adata, control_mask, features=resolved_features, layer=layer)
        return _build_de_output(
            perturbation=perturbation,
            control=control,
            effect_type="intrinsic",
            features=resolved_features,
            case_mean=case_mean,
            control_mean=control_mean,
            statistics=np.zeros(len(resolved_features), dtype=float),
            pvalues=np.ones(len(resolved_features), dtype=float),
            case_n=n_case,
            control_n=n_control,
            method="simple",
        )
    return _run_de(
        adata,
        case_cells=adata.obs_names[case_mask],
        control_cells=adata.obs_names[control_mask],
        perturbation=perturbation,
        control=control,
        effect_type="intrinsic",
        features=features,
        layer=layer,
        method=method,
        sample_col=sample_col,
        covariates=covariates,
        effect_size_only=effect_size_only,
    )
[docs]
def neighbor_de(
    adata: AnnData,
    *,
    perturbation: str,
    control: str,
    graph_key: str | None = None,
    groupby: str = "perturbation",
    status_col: str = "perturbation_status",
    exclude_perturbed: bool = True,
    drop_shared_neighbors: bool = False,
    weight_by_distance: bool = False,
    aggregate: str = "mean",
    features: Sequence[str] | None = None,
    layer: str | None = None,
    cell_type: str | Sequence[str] | None = None,
    roi: str | Sequence[str] | None = None,
    min_cells_per_group: int = 2,
    min_samples_per_group: int = 2,
    method: str = "simple",
    sample_col: str | None = None,
    covariates: Sequence[str] | None = None,
    effect_size_only: bool = False,
    min_cells: int | None = None,
) -> pd.DataFrame:
    """Compare neighbors around perturbed cells against neighbors around control cells.

    Source cells are chosen by ``groupby`` label with ``status_col == "single"``
    and are restricted by ``cell_type``/``roi``. Their neighbors come from
    ``collect_neighbors`` and are restricted by ``roi`` only. The neighbors are
    turned into expression profiles:

    * ``aggregate="mean"`` / ``"sum"``: one profile per source cell, the mean or
      sum of its neighbors. Neighbors are weighted by inverse distance when
      ``weight_by_distance`` is set.
    * ``aggregate="pseudobulk"``: one profile per ``sample_col`` value, the sum
      over the unique neighbors of that sample's sources. When
      ``weight_by_distance`` is set, each neighbor is weighted by the inverse of
      its smallest distance.

    Case and control profiles are then compared with ``method``. With
    ``drop_shared_neighbors``, cells that neighbor both a case and a control
    source are removed from both sides.

    Returns
    -------
    pandas.DataFrame
        The differential-expression table plus ``aggregate``,
        ``weight_by_distance``, source/neighbor counts and
        ``shared_neighbors_removed`` columns.

    Raises
    ------
    ValueError
        In any of these cases:

        * ``aggregate`` is unknown;
        * ``sample_col`` is missing while ``aggregate`` or ``method`` is
          ``"pseudobulk"``;
        * no features are available;
        * a group has too few aggregated profiles or samples.
    """
    ensure_spatialperturb_schema(adata)
    if min_cells is not None:
        min_cells_per_group = int(min_cells)
    if aggregate not in {"mean", "sum", "pseudobulk"}:
        raise ValueError("aggregate must be one of {'mean', 'sum', 'pseudobulk'}.")
    # Validate up front. Otherwise a missing sample_col is only noticed once a
    # neighborhood is non-empty, and empty neighborhoods surface as a
    # misleading "Need at least ... profiles" error instead.
    if aggregate == "pseudobulk" and sample_col is None:
        raise ValueError("sample_col is required when aggregate='pseudobulk'.")
    base_mask = _restrict_mask(adata, cell_type=cell_type, roi=roi)
    # Neighbors are restricted by ROI only, not by cell type.
    neighbor_mask = _restrict_mask(adata, roi=roi) if roi is not None else np.ones(adata.n_obs, dtype=bool)
    allowed_neighbors = set(adata.obs_names[neighbor_mask])
    case_sources = adata.obs_names[
        _group_mask(adata, perturbation=perturbation, groupby=groupby, status_col=status_col) & base_mask
    ]
    control_sources = adata.obs_names[
        _group_mask(adata, perturbation=control, groupby=groupby, status_col=status_col) & base_mask
    ]
    case_neighbors = collect_neighbors(
        adata,
        cells=case_sources,
        graph_key=graph_key,
        exclude_perturbed=exclude_perturbed,
        as_frame=True,
    )
    control_neighbors = collect_neighbors(
        adata,
        cells=control_sources,
        graph_key=graph_key,
        exclude_perturbed=exclude_perturbed,
        as_frame=True,
    )
    if not case_neighbors.empty:
        case_neighbors = case_neighbors[case_neighbors["neighbor"].isin(allowed_neighbors)].copy()
    if not control_neighbors.empty:
        control_neighbors = control_neighbors[control_neighbors["neighbor"].isin(allowed_neighbors)].copy()
    case_cells = set(case_neighbors["neighbor"]) if not case_neighbors.empty else set()
    control_cells = set(control_neighbors["neighbor"]) if not control_neighbors.empty else set()
    shared = case_cells & control_cells
    # ``shared`` is non-empty only if both edge frames are non-empty, so the
    # "neighbor" column is guaranteed to exist when we filter.
    if drop_shared_neighbors and shared:
        case_neighbors = case_neighbors[~case_neighbors["neighbor"].isin(shared)].copy()
        control_neighbors = control_neighbors[~control_neighbors["neighbor"].isin(shared)].copy()
    resolved_features = _default_feature_names(adata, features)
    if not resolved_features:
        raise ValueError("No features available for differential expression.")

    def _empty_profiles() -> tuple[np.ndarray, pd.DataFrame]:
        """Zero-row profile matrix plus zero-row obs frame with adata's columns."""
        return np.empty((0, len(resolved_features)), dtype=float), adata.obs.iloc[0:0].copy()

    def _aggregate_frame(frame: pd.DataFrame, sources: Sequence[str]) -> tuple[np.ndarray, pd.DataFrame]:
        """Turn an edge table into (profile matrix, per-profile obs metadata)."""
        if frame.empty:
            return _empty_profiles()
        records: list[np.ndarray] = []
        obs_records: list[pd.Series] = []
        if aggregate == "pseudobulk":
            source_obs = adata.obs.loc[list(sources)].copy()
            # observed=True skips unused categories of a categorical sample column;
            # such groups have no sources and would be skipped below anyway.
            for sample_id, sample_sources in source_obs.groupby(sample_col, observed=True):
                source_names = set(sample_sources.index.astype(str))
                sample_edges = frame[frame["source"].isin(source_names)]
                if sample_edges.empty:
                    continue
                neighbor_names = list(dict.fromkeys(sample_edges["neighbor"].astype(str)))
                expr = extract_matrix(adata, layer=layer, obs_names=neighbor_names, var_names=resolved_features)
                if weight_by_distance:
                    # A neighbor reached from several sources uses its closest distance.
                    min_distance = sample_edges.groupby("neighbor")["distance"].min().reindex(neighbor_names)
                    weights = 1.0 / np.maximum(min_distance.to_numpy(dtype=float), 1e-6)
                    row = (expr * weights[:, None]).sum(axis=0)
                else:
                    row = expr.sum(axis=0)
                records.append(np.asarray(row, dtype=float))
                meta = sample_sources.iloc[0].copy()
                meta[sample_col] = str(sample_id)
                meta["n_neighbors"] = int(len(neighbor_names))
                meta["n_sources"] = int(len(source_names))
                obs_records.append(meta)
        else:
            for source, edges in frame.groupby("source"):
                neighbor_names = edges["neighbor"].astype(str).tolist()
                expr = extract_matrix(adata, layer=layer, obs_names=neighbor_names, var_names=resolved_features)
                if weight_by_distance:
                    # The 1e-6 floor avoids dividing by zero for coincident cells.
                    weights = 1.0 / np.maximum(edges["distance"].to_numpy(dtype=float), 1e-6)
                    if aggregate == "mean":
                        row = np.average(expr, axis=0, weights=weights)
                    else:
                        row = (expr * weights[:, None]).sum(axis=0)
                else:
                    row = expr.mean(axis=0) if aggregate == "mean" else expr.sum(axis=0)
                meta = adata.obs.loc[str(source)].copy()
                meta["n_neighbors"] = int(len(neighbor_names))
                obs_records.append(meta)
                records.append(np.asarray(row, dtype=float))
        if not records:
            return _empty_profiles()
        return np.vstack(records), pd.DataFrame(obs_records)

    case_matrix, case_obs = _aggregate_frame(case_neighbors, case_sources)
    control_matrix, control_obs = _aggregate_frame(control_neighbors, control_sources)
    if case_matrix.shape[0] < min_cells_per_group or control_matrix.shape[0] < min_cells_per_group:
        raise ValueError(
            "Need at least "
            f"{min_cells_per_group} aggregated neighbor profiles in both case and control groups, found "
            f"{case_matrix.shape[0]} and {control_matrix.shape[0]}."
        )
    if method == "pseudobulk":
        if sample_col is None:
            raise ValueError("sample_col is required when method='pseudobulk'.")
        case_sample_n = case_obs[sample_col].astype(str).nunique()
        control_sample_n = control_obs[sample_col].astype(str).nunique()
        if case_sample_n < min_samples_per_group or control_sample_n < min_samples_per_group:
            raise ValueError(
                "Need at least "
                f"{min_samples_per_group} samples in both case and control neighbor groups for pseudobulk, found "
                f"{case_sample_n} and {control_sample_n}."
            )
    output = _run_de(
        adata,
        case_matrix=case_matrix,
        control_matrix=control_matrix,
        case_obs=case_obs,
        control_obs=control_obs,
        perturbation=perturbation,
        control=control,
        effect_type="neighbor",
        features=resolved_features,
        layer=layer,
        method=method,
        sample_col=sample_col,
        covariates=covariates,
        effect_size_only=effect_size_only,
    )
    output["aggregate"] = aggregate
    output["weight_by_distance"] = weight_by_distance
    output["case_source_n"] = len(case_sources)
    output["control_source_n"] = len(control_sources)
    output["case_neighbor_n"] = len(set(case_neighbors["neighbor"])) if not case_neighbors.empty else 0
    output["control_neighbor_n"] = len(set(control_neighbors["neighbor"])) if not control_neighbors.empty else 0
    output["shared_neighbors_removed"] = len(shared) if drop_shared_neighbors else 0
    return output
def _load_lr_network(lr_network: str | pd.DataFrame | None = None) -> tuple[pd.DataFrame, str]:
    """Resolve a ligand-receptor network specification into ``(table, source)``.

    * ``DataFrame``: a copy is returned. ``source`` is the sorted,
      comma-joined set of non-null ``resource`` values, or ``"custom"`` when
      there are none.
    * ``None`` / ``"fallback"``: the built-in ``_FALLBACK_LR_NETWORK``.
    * Any other value: a file path. Tab-separated when its extension
      (ignoring a trailing compression suffix such as ``.gz``) is ``.tsv`` or
      ``.txt``; comma-separated otherwise. ``source`` is the path.

    Raises
    ------
    ValueError
        If the table lacks ``ligand``/``receptor`` columns, or for the bare
        string ``"custom"``.
    FileNotFoundError
        If the given path does not exist.
    """
    required = {"ligand", "receptor"}

    def _check_columns(table: pd.DataFrame) -> None:
        missing = required.difference(table.columns)
        if missing:
            raise ValueError(f"lr_network must contain columns {required}, missing {missing}")

    if isinstance(lr_network, pd.DataFrame):
        _check_columns(lr_network)
        source = "custom"
        if "resource" in lr_network.columns:
            resources = sorted({str(value) for value in lr_network["resource"].dropna().astype(str)})
            if resources:
                source = ",".join(resources)
        return lr_network.copy(), source
    if lr_network is None or lr_network == "fallback":
        return _FALLBACK_LR_NETWORK.copy(), "fallback"
    if lr_network == "custom":
        raise ValueError("When lr_network='custom', pass a DataFrame or a file path.")
    path = Path(str(lr_network)).expanduser()
    if not path.exists():
        raise FileNotFoundError(f"Ligand-receptor network file not found: {path}")
    # Look past a compression suffix so e.g. "network.tsv.gz" is still read as
    # TSV; pandas infers the decompression itself from the extension.
    suffixes = [suffix.lower() for suffix in path.suffixes]
    if suffixes and suffixes[-1] in {".gz", ".bz2", ".xz", ".zip", ".zst"}:
        suffixes = suffixes[:-1]
    separator = "\t" if suffixes and suffixes[-1] in {".tsv", ".txt"} else ","
    table = pd.read_csv(path, sep=separator)
    _check_columns(table)
    return table, str(path)
[docs]
def differential_lr(
    adata: AnnData,
    *,
    perturbation: str,
    control: str,
    graph_key: str | None = None,
    lr_network: str | pd.DataFrame | None = None,
    groupby: str = "perturbation",
    status_col: str = "perturbation_status",
    source_groupby: str | None = None,
    target_groupby: str | None = None,
    layer: str | None = None,
    cell_type: str | Sequence[str] | None = None,
    roi: str | Sequence[str] | None = None,
) -> pd.DataFrame:
    """Score ligand-receptor interactions between perturbed cells and their neighbors.

    Sources are single-perturbation ``perturbation`` (case) or ``control`` cells,
    restricted by ``cell_type``/``roi``. Their neighbors are collected with
    ``exclude_perturbed=True`` and restricted by ``roi`` only. Sources may be
    split by ``source_groupby`` and neighbors by ``target_groupby``.

    For every (source group, target group) block and every measured
    ligand-receptor pair, the score is
    ``mean(ligand in sources) * mean(receptor in neighbors)``. It is computed
    around case and control sources separately, and
    ``diff_score = case_score - control_score``.

    Returns
    -------
    pandas.DataFrame
        Columns are ``perturbation``, ``control``, ``network_source``,
        ``source_group``, ``target_group``, ``ligand``, ``receptor``, the four
        source/neighbor counts, ``case_score``, ``control_score``,
        ``diff_score`` and ``rank``. The same columns are present even when the
        table is empty. Rows are sorted by group and descending
        ``diff_score``. ``rank`` restarts at 1 within each
        (source_group, target_group) block.

    Raises
    ------
    KeyError
        If ``source_groupby``/``target_groupby`` is not an ``adata.obs`` column.
    """
    output_columns = [
        "perturbation",
        "control",
        "network_source",
        "source_group",
        "target_group",
        "ligand",
        "receptor",
        "case_source_n",
        "control_source_n",
        "case_neighbor_n",
        "control_neighbor_n",
        "case_score",
        "control_score",
        "diff_score",
        "rank",
    ]
    ensure_spatialperturb_schema(adata)
    base_mask = _restrict_mask(adata, cell_type=cell_type, roi=roi)
    neighbor_mask = _restrict_mask(adata, roi=roi) if roi is not None else np.ones(adata.n_obs, dtype=bool)
    allowed_neighbors = set(adata.obs_names[neighbor_mask])
    case_sources = adata.obs_names[
        _group_mask(adata, perturbation=perturbation, groupby=groupby, status_col=status_col) & base_mask
    ]
    control_sources = adata.obs_names[
        _group_mask(adata, perturbation=control, groupby=groupby, status_col=status_col) & base_mask
    ]

    def _filtered_edges(sources) -> pd.DataFrame:
        """Neighbor edges of ``sources`` whose neighbor passes the ROI restriction."""
        edges = collect_neighbors(adata, cells=sources, graph_key=graph_key, exclude_perturbed=True, as_frame=True)
        if edges.empty:
            return edges
        return edges[edges["neighbor"].isin(allowed_neighbors)].copy()

    def _edges_of(edges: pd.DataFrame, sources: list) -> pd.DataFrame:
        """Edges whose source is in ``sources`` (empty frames pass through untouched)."""
        return edges if edges.empty else edges[edges["source"].isin(sources)]

    def _unique_neighbors(edges: pd.DataFrame) -> list[str]:
        """Distinct neighbor names of ``edges`` in first-seen order."""
        return [] if edges.empty else list(dict.fromkeys(edges["neighbor"].astype(str)))

    def _select(cells, column: str | None, value) -> list:
        """Cells whose ``obs[column]`` equals ``value`` as strings; all cells if ``column`` is None."""
        if column is None:
            return list(cells)
        obs = adata.obs.loc[list(cells)]
        return obs[obs[column].astype(str) == str(value)].index.tolist()

    def _mean_profile(cells: list, genes: list[str]) -> pd.Series:
        """Mean expression of ``genes`` over ``cells``, indexed by gene."""
        return pd.Series(
            extract_matrix(adata, layer=layer, obs_names=cells, var_names=genes).mean(axis=0),
            index=genes,
        )

    case_neighbors = _filtered_edges(case_sources)
    control_neighbors = _filtered_edges(control_sources)
    for column in (source_groupby, target_groupby):
        if column is not None and column not in adata.obs.columns:
            raise KeyError(f"{column!r} not found in adata.obs.")

    network, network_source = _load_lr_network(lr_network)
    ligands = sorted(set(network["ligand"]).intersection(adata.var_names))
    receptors = sorted(set(network["receptor"]).intersection(adata.var_names))
    if not ligands or not receptors:
        return pd.DataFrame(columns=output_columns)
    ligand_set, receptor_set = set(ligands), set(receptors)
    # Only pairs with both partners measured can be scored; this list does not
    # depend on the groups, so build it once.
    scorable_pairs = [
        (ligand, receptor)
        for ligand, receptor in network[["ligand", "receptor"]].drop_duplicates().itertuples(index=False)
        if ligand in ligand_set and receptor in receptor_set
    ]

    source_groups = [None]
    if source_groupby is not None:
        source_groups = sorted(
            {
                *adata.obs.loc[list(case_sources), source_groupby].astype(str),
                *adata.obs.loc[list(control_sources), source_groupby].astype(str),
            }
        )
    target_groups = [None]
    if target_groupby is not None:
        target_groups = sorted(
            {
                *adata.obs.loc[_unique_neighbors(case_neighbors), target_groupby].astype(str),
                *adata.obs.loc[_unique_neighbors(control_neighbors), target_groupby].astype(str),
            }
        )

    records = []
    for source_group in source_groups:
        case_source_subset = _select(case_sources, source_groupby, source_group)
        control_source_subset = _select(control_sources, source_groupby, source_group)
        if not case_source_subset or not control_source_subset:
            continue
        case_reached = _unique_neighbors(_edges_of(case_neighbors, case_source_subset))
        control_reached = _unique_neighbors(_edges_of(control_neighbors, control_source_subset))
        if not case_reached or not control_reached:
            # No target subset can be non-empty, so nothing to score.
            continue
        # Ligand means depend only on the source subset: compute once per source group.
        case_ligand_expr = _mean_profile(case_source_subset, ligands)
        control_ligand_expr = _mean_profile(control_source_subset, ligands)
        for target_group in target_groups:
            case_neighbor_subset = _select(case_reached, target_groupby, target_group)
            control_neighbor_subset = _select(control_reached, target_groupby, target_group)
            if not case_neighbor_subset or not control_neighbor_subset:
                continue
            case_receptor_expr = _mean_profile(case_neighbor_subset, receptors)
            control_receptor_expr = _mean_profile(control_neighbor_subset, receptors)
            for ligand, receptor in scorable_pairs:
                case_score = float(case_ligand_expr[ligand] * case_receptor_expr[receptor])
                control_score = float(control_ligand_expr[ligand] * control_receptor_expr[receptor])
                records.append(
                    {
                        "perturbation": perturbation,
                        "control": control,
                        "network_source": network_source,
                        "source_group": "all" if source_group is None else str(source_group),
                        "target_group": "all" if target_group is None else str(target_group),
                        "ligand": ligand,
                        "receptor": receptor,
                        "case_source_n": len(case_source_subset),
                        "control_source_n": len(control_source_subset),
                        "case_neighbor_n": len(case_neighbor_subset),
                        "control_neighbor_n": len(control_neighbor_subset),
                        "case_score": case_score,
                        "control_score": control_score,
                        "diff_score": case_score - control_score,
                    }
                )

    if not records:
        return pd.DataFrame(columns=output_columns)
    output = pd.DataFrame.from_records(records)
    output = output.sort_values(
        ["source_group", "target_group", "diff_score"],
        ascending=[True, True, False],
    ).reset_index(drop=True)
    # Rows are already in descending diff_score order inside every block, so a
    # running count per block is the within-block rank.
    output["rank"] = output.groupby(["source_group", "target_group"], sort=False).cumcount() + 1
    return output
[docs]
def power_curve(
    adata: AnnData,
    *,
    perturbation: str,
    control: str,
    mode: str = "intrinsic",
    feature: str | None = None,
    sample_sizes: Sequence[int] = (5, 10, 20, 30),
    n_boot: int = 100,
    alpha: float = 0.05,
    groupby: str = "perturbation",
    status_col: str = "perturbation_status",
    graph_key: str | None = None,
    layer: str | None = None,
    method: str = "simple",
    sample_col: str | None = None,
    cell_type: str | Sequence[str] | None = None,
    roi: str | Sequence[str] | None = None,
    random_state: int = 0,
) -> pd.DataFrame:
    """Estimate detection power across sample sizes by repeated subsampling.

    For each requested size ``n``, draw ``min(n, #case values, #control values)``
    values *without replacement* from each group, ``n_boot`` times. Each draw is
    tested with a Welch t-test for ``method="pseudobulk"`` and a two-sided
    Mann-Whitney U test otherwise. Power is the fraction of draws with
    p-value < ``alpha``.

    Pools:

    * ``mode="intrinsic"``: the case and control cells themselves.
    * ``mode="neighbor"``: the unique ROI-restricted graph neighbors of those
      cells.

    With ``method="pseudobulk"``, each pool is first reduced to one mean value
    per ``sample_col`` value.

    When ``feature`` is None, the gene in the first row of an
    effect-size-only ``intrinsic_de`` run is used. That run uses the same
    ``layer``, ``method`` and context filters.

    Returns
    -------
    pandas.DataFrame
        One row per entry of ``sample_sizes`` with ``effective_n`` and
        ``power``. ``power`` is NaN when fewer than two values per group are
        available.

    Raises
    ------
    ValueError
        In any of these cases:

        * ``mode`` is unknown;
        * ``n_boot < 1``;
        * ``sample_col`` is missing with pseudobulk;
        * no reference feature can be chosen.
    KeyError
        If ``feature`` is not in ``adata.var_names``.
    """
    ensure_spatialperturb_schema(adata)
    if int(n_boot) < 1:
        # Power is hits / n_boot; zero resamples would divide by zero.
        raise ValueError("n_boot must be at least 1.")
    rng = np.random.default_rng(random_state)
    base_mask = _restrict_mask(adata, cell_type=cell_type, roi=roi)
    neighbor_mask = _restrict_mask(adata, roi=roi) if roi is not None else np.ones(adata.n_obs, dtype=bool)
    if mode not in {"intrinsic", "neighbor"}:
        raise ValueError("mode must be 'intrinsic' or 'neighbor'.")
    case_sources = adata.obs_names[
        _group_mask(adata, perturbation=perturbation, groupby=groupby, status_col=status_col) & base_mask
    ]
    control_sources = adata.obs_names[
        _group_mask(adata, perturbation=control, groupby=groupby, status_col=status_col) & base_mask
    ]
    if mode == "intrinsic":
        case_pool, control_pool = case_sources, control_sources
    else:
        allowed_neighbors = set(adata.obs_names[neighbor_mask])

        def _neighbor_pool(sources) -> pd.Index:
            """Sorted unique ROI-restricted neighbors of ``sources``."""
            edges = collect_neighbors(adata, cells=sources, graph_key=graph_key, exclude_perturbed=True, as_frame=True)
            if edges.empty:
                return pd.Index([])
            edges = edges[edges["neighbor"].isin(allowed_neighbors)]
            return pd.Index(sorted(set(edges["neighbor"])))

        case_pool = _neighbor_pool(case_sources)
        control_pool = _neighbor_pool(control_sources)

    if feature is None:
        # Rank genes on the same layer that is resampled below.
        reference = intrinsic_de(
            adata,
            perturbation=perturbation,
            control=control,
            groupby=groupby,
            status_col=status_col,
            layer=layer,
            method=method,
            sample_col=sample_col,
            cell_type=cell_type,
            roi=roi,
            effect_size_only=True,
        )
        if reference.empty:
            raise ValueError("No features available to choose a reference feature for power_curve.")
        feature = str(reference.iloc[0]["gene"])
    if feature not in adata.var_names:
        raise KeyError(f"Feature {feature!r} not found in AnnData.var_names.")
    if method == "pseudobulk" and sample_col is None:
        raise ValueError("sample_col is required when method='pseudobulk'.")

    def _pool_values(pool) -> np.ndarray:
        """Feature values of ``pool``; per-sample means for pseudobulk (sorted by sample)."""
        values = extract_matrix(adata, layer=layer, obs_names=pool, var_names=[feature]).ravel()
        if method != "pseudobulk":
            return values
        samples = adata.obs.loc[list(pool), sample_col].astype(str).to_numpy()
        frame = pd.DataFrame({"sample": samples, "value": values})
        return frame.groupby("sample")["value"].mean().to_numpy(dtype=float)

    def _record(size, effective_n: int, power: float) -> dict:
        return {
            "perturbation": perturbation,
            "control": control,
            "mode": mode,
            "feature": feature,
            "sample_size": int(size),
            "effective_n": effective_n,
            "method": method,
            "power": power,
        }

    case_values = _pool_values(case_pool)
    control_values = _pool_values(control_pool)
    records = []
    for size in sample_sizes:
        effective_n = min(int(size), len(case_values), len(control_values))
        if effective_n < 2:
            records.append(_record(size, effective_n, np.nan))
            continue
        hits = 0
        for _ in range(int(n_boot)):
            case_sample = rng.choice(case_values, size=effective_n, replace=False)
            control_sample = rng.choice(control_values, size=effective_n, replace=False)
            if method == "pseudobulk":
                pvalue = ttest_ind(case_sample, control_sample, equal_var=False, nan_policy="omit").pvalue
            else:
                pvalue = mannwhitneyu(case_sample, control_sample, alternative="two-sided").pvalue
            # A NaN p-value compares False, so it never counts as a detection.
            if pvalue < alpha:
                hits += 1
        records.append(_record(size, effective_n, hits / float(n_boot)))
    return pd.DataFrame.from_records(records)