Source code for spatialperturb.schema
"""Schema helpers for AnnData objects consumed by SpatialPerturb."""
from __future__ import annotations
from typing import Any
import numpy as np
import pandas as pd
from anndata import AnnData
from ._utils import merge_uns_dict
# Fallback values written into each required obs column when it is missing.
DEFAULT_OBS_VALUES = {
    "perturbation": "unassigned",
    "perturbation_status": "unassigned",
    "cell_type": "unknown",
    "roi": "global",
}
# Required obs columns, listed in the same order as DEFAULT_OBS_VALUES.
REQUIRED_OBS_COLUMNS = tuple(DEFAULT_OBS_VALUES)
# obsm key holding the per-observation 2-D coordinates.
SPATIAL_KEY = "spatial"
# obsp keys recognised as spatial graphs.
GRAPH_KEYS = ("sp_knn", "sp_radius")
# uns key under which SpatialPerturb stores its metadata dictionary.
SPATIALPERTURB_UNS_KEY = "spatialperturb"
# Version recorded in uns[SPATIALPERTURB_UNS_KEY]["schema_version"].
SCHEMA_VERSION = "0.3.0"
[docs]
class SpatialPerturbSchemaError(ValueError):
    """Signal that an AnnData object violates the SpatialPerturb schema.

    Derives from ``ValueError``, so code that already catches ``ValueError``
    also catches schema violations.
    """
def _default_spatial_coords(n_obs: int) -> np.ndarray:
x = np.arange(n_obs, dtype=float)
y = np.zeros(n_obs, dtype=float)
return np.column_stack([x, y])
[docs]
def ensure_spatialperturb_schema(
    adata: AnnData,
    *,
    metadata: dict[str, Any] | None = None,
    copy: bool = False,
) -> AnnData:
    """Ensure required schema fields exist, filling defaults when needed.

    For every column in ``DEFAULT_OBS_VALUES``:

    - a missing ``obs`` column is created with its default value;
    - missing entries (NaN/None/NA) in an existing column are replaced by
      that default;
    - the column is then cast to ``str``.

    If ``obsm['spatial']`` is absent, placeholder coordinates are created
    (x = 0..n_obs-1, y = 0). The ``uns['spatialperturb']`` dictionary is
    created when absent or not a dict. It is then updated via
    ``merge_uns_dict`` with the schema version, the graph keys currently
    present in ``obsp``, and any caller-supplied ``metadata``.

    Parameters
    ----------
    adata
        Object to update.
    metadata
        Extra entries merged into ``uns['spatialperturb']``; ``None`` skips
        this merge.
    copy
        If True, work on and return a copy; otherwise ``adata`` is modified
        in place and returned.

    Returns
    -------
    AnnData
        The updated object (``adata`` itself unless ``copy=True``).

    Raises
    ------
    SpatialPerturbSchemaError
        If ``obsm['spatial']`` cannot be read as numeric coordinates or is
        not an ``(n_obs, 2)`` array.
    """
    target = adata.copy() if copy else adata

    for column, default_value in DEFAULT_OBS_VALUES.items():
        if column not in target.obs:
            target.obs[column] = default_value
        series = target.obs[column]
        # Fill missing entries before casting: a bare astype(str) would turn
        # NaN/None into the literal strings "nan"/"None". Going through
        # ``object`` first lets categorical columns accept a default that is
        # not one of their categories.
        filled = series.astype(object).where(series.notna(), default_value)
        target.obs[column] = filled.astype(str)

    if SPATIAL_KEY not in target.obsm:
        target.obsm[SPATIAL_KEY] = _default_spatial_coords(target.n_obs)
    try:
        spatial = np.asarray(target.obsm[SPATIAL_KEY], dtype=float)
    except (TypeError, ValueError) as err:
        # Surface non-numeric coordinates as a schema error (still a ValueError).
        raise SpatialPerturbSchemaError(
            f"obsm['{SPATIAL_KEY}'] must contain numeric coordinates."
        ) from err
    if spatial.ndim != 2 or spatial.shape[0] != target.n_obs or spatial.shape[1] != 2:
        raise SpatialPerturbSchemaError(
            f"obsm['{SPATIAL_KEY}'] must be an (n_obs, 2) array, found shape {spatial.shape}."
        )

    if SPATIALPERTURB_UNS_KEY not in target.uns or not isinstance(target.uns[SPATIALPERTURB_UNS_KEY], dict):
        target.uns[SPATIALPERTURB_UNS_KEY] = {}
    merge_uns_dict(
        target.uns[SPATIALPERTURB_UNS_KEY],
        {
            "schema_version": SCHEMA_VERSION,
            "graph_keys": [key for key in GRAPH_KEYS if key in target.obsp],
        },
    )
    if metadata is not None:
        merge_uns_dict(target.uns[SPATIALPERTURB_UNS_KEY], metadata)
    return target
[docs]
def validate_spatialperturb_schema(
    adata: AnnData,
    *,
    require_graph: bool = False,
) -> None:
    """Validate that an AnnData object follows the SpatialPerturb schema.

    Checks, in order:

    1. every column in ``REQUIRED_OBS_COLUMNS`` exists in ``obs``;
    2. ``obsm['spatial']`` exists, is numeric and has shape ``(n_obs, 2)``;
    3. ``uns['spatialperturb']`` exists and is a dict;
    4. when ``require_graph`` is True, at least one key from ``GRAPH_KEYS``
       is present in ``obsp``.

    Parameters
    ----------
    adata
        Object to check; it is not modified.
    require_graph
        Whether a spatial graph in ``obsp`` is mandatory.

    Raises
    ------
    SpatialPerturbSchemaError
        On the first failed check.
    """
    missing_obs = [column for column in REQUIRED_OBS_COLUMNS if column not in adata.obs]
    if missing_obs:
        raise SpatialPerturbSchemaError(f"Missing required obs columns: {missing_obs}")

    if SPATIAL_KEY not in adata.obsm:
        raise SpatialPerturbSchemaError(f"Missing required obsm['{SPATIAL_KEY}'] coordinates.")
    try:
        # Same float conversion ensure_spatialperturb_schema applies, so that
        # validation rejects coordinates that function would fail on.
        spatial = np.asarray(adata.obsm[SPATIAL_KEY], dtype=float)
    except (TypeError, ValueError) as err:
        raise SpatialPerturbSchemaError(
            f"obsm['{SPATIAL_KEY}'] must contain numeric coordinates."
        ) from err
    if spatial.ndim != 2 or spatial.shape[0] != adata.n_obs or spatial.shape[1] != 2:
        raise SpatialPerturbSchemaError(
            f"obsm['{SPATIAL_KEY}'] must have shape (n_obs, 2); found {spatial.shape}."
        )

    if SPATIALPERTURB_UNS_KEY not in adata.uns or not isinstance(adata.uns[SPATIALPERTURB_UNS_KEY], dict):
        raise SpatialPerturbSchemaError(
            f"Missing required uns['{SPATIALPERTURB_UNS_KEY}'] metadata dictionary."
        )

    if require_graph and not any(key in adata.obsp for key in GRAPH_KEYS):
        # Built from GRAPH_KEYS so the message tracks the accepted keys.
        candidates = " nor ".join(f"obsp['{key}']" for key in GRAPH_KEYS)
        raise SpatialPerturbSchemaError(
            f"A spatial graph is required but neither {candidates} is available."
        )
[docs]
def schema_summary(adata: AnnData) -> pd.Series:
    """Summarize the SpatialPerturb schema fields of ``adata``.

    The object is validated first (without requiring a graph), so an invalid
    object raises ``SpatialPerturbSchemaError`` instead of being summarized.

    Returns
    -------
    pd.Series
        Indexed by ``n_obs``, ``n_vars``, ``has_sp_knn``, ``has_sp_radius``,
        ``perturbations``, ``cell_types`` and ``rois``. The last three are
        ``Series.nunique()`` counts of the matching ``obs`` column.
    """
    validate_spatialperturb_schema(adata, require_graph=False)

    summary: dict[str, Any] = {"n_obs": adata.n_obs, "n_vars": adata.n_vars}
    for graph_key in ("sp_knn", "sp_radius"):
        summary[f"has_{graph_key}"] = graph_key in adata.obsp

    # Summary label -> obs column whose distinct values are counted.
    counted_columns = {
        "perturbations": "perturbation",
        "cell_types": "cell_type",
        "rois": "roi",
    }
    for label, column in counted_columns.items():
        summary[label] = int(adata.obs[column].nunique())

    return pd.Series(summary)