"""AnnData I/O helpers for SpatialPerturb."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import anndata as ad
from anndata import AnnData
from matplotlib.path import Path as MplPath
import numpy as np
import pandas as pd
from scipy import sparse
from .schema import ensure_spatialperturb_schema, validate_spatialperturb_schema
def _as_dataframe(
    table: pd.DataFrame | None,
    *,
    index: pd.Index,
    default_prefix: str,
) -> pd.DataFrame:
    """Return a copy of ``table`` aligned to ``index``.

    Parameters
    ----------
    table:
        Annotation table, or ``None`` to get an empty frame carrying ``index``.
    index:
        Target row labels.
    default_prefix:
        Prefix for the positional fallback labels ``"<prefix>_<position>"``
        that replace the index when it contains null labels.

    Returns
    -------
    A new frame; the caller's ``table`` is never modified.
    """
    if table is None:
        return pd.DataFrame(index=index)

    frame = table.copy()
    frame.index = frame.index.astype(str)

    # Only pay for a reindex when the string labels differ from the target.
    aligned = frame if frame.index.equals(index) else frame.reindex(index)

    if aligned.index.isnull().any():
        fallback_labels = [f"{default_prefix}_{idx}" for idx in range(len(aligned))]
        aligned.index = pd.Index(fallback_labels, dtype="object")
    return aligned
def _infer_spatial(obs: pd.DataFrame, spatial: Any | None) -> np.ndarray | None:
    """Resolve 2-D coordinates from an explicit value or from ``obs`` columns.

    An explicit ``spatial`` value always wins: it is converted to a float
    array and must be two-dimensional with exactly two columns. Without it,
    the first column pair present in ``obs`` is used, tried in the order
    ``x``/``y``, ``X``/``Y``, ``x_centroid``/``y_centroid``.

    Returns
    -------
    A float array with two columns, or ``None`` when no coordinates are found.

    Raises
    ------
    ValueError
        If ``spatial`` is given but is not a 2-D array with two columns.
    """
    if spatial is None:
        column_pairs = (("x", "y"), ("X", "Y"), ("x_centroid", "y_centroid"))
        for x_column, y_column in column_pairs:
            if {x_column, y_column}.issubset(obs.columns):
                return obs.loc[:, [x_column, y_column]].to_numpy(dtype=float)
        return None

    explicit = np.asarray(spatial, dtype=float)
    if explicit.ndim == 2 and explicit.shape[1] == 2:
        return explicit
    raise ValueError("spatial coordinates must have shape (n_obs, 2).")
[docs]
def from_tables(
    expression: pd.DataFrame | np.ndarray | sparse.spmatrix,
    *,
    obs: pd.DataFrame | None = None,
    var: pd.DataFrame | None = None,
    spatial: pd.DataFrame | np.ndarray | None = None,
    layers: dict[str, pd.DataFrame | np.ndarray | sparse.spmatrix] | None = None,
    metadata: dict[str, Any] | None = None,
) -> AnnData:
    """Assemble an AnnData object from an expression matrix and optional tables.

    Parameters
    ----------
    expression:
        A DataFrame (its row/column labels become the obs/var names and its
        values are cast to float), or a dense/sparse matrix.
    obs, var:
        Optional annotation tables. Both are required when ``expression`` is
        not a DataFrame, because they are then the only source of labels.
    spatial:
        Optional explicit coordinates; when omitted, coordinate columns are
        looked up in the aligned ``obs`` table.
    layers:
        Extra matrices stored under ``adata.layers``; DataFrames are converted
        to arrays, other values are stored as given.
    metadata:
        Passed to ``ensure_spatialperturb_schema`` (an empty dict when omitted).

    Returns
    -------
    The AnnData object after ``ensure_spatialperturb_schema`` and
    ``validate_spatialperturb_schema`` have been applied to it.

    Raises
    ------
    ValueError
        If ``expression`` is not a DataFrame and ``obs`` or ``var`` is missing.
    """
    if not isinstance(expression, pd.DataFrame):
        if obs is None or var is None:
            raise ValueError("obs and var must be provided when expression is not a DataFrame.")
        values = expression
        row_labels = pd.Index(obs.index.astype(str), dtype="object")
        column_labels = pd.Index(var.index.astype(str), dtype="object")
    else:
        row_labels = pd.Index(expression.index.astype(str), dtype="object")
        column_labels = pd.Index(expression.columns.astype(str), dtype="object")
        values = expression.to_numpy(dtype=float)

    obs_table = _as_dataframe(obs, index=row_labels, default_prefix="cell")
    var_table = _as_dataframe(var, index=column_labels, default_prefix="gene")

    # Sparse input is normalised to CSR; anything else becomes a dense array.
    x_matrix = values.tocsr() if sparse.issparse(values) else np.asarray(values)
    adata = ad.AnnData(X=x_matrix, obs=obs_table, var=var_table)

    for layer_name, layer_values in (layers or {}).items():
        if isinstance(layer_values, pd.DataFrame):
            layer_values = layer_values.to_numpy()
        adata.layers[str(layer_name)] = layer_values

    coordinates = _infer_spatial(obs_table, spatial)
    if coordinates is not None:
        adata.obsm["spatial"] = coordinates

    ensure_spatialperturb_schema(adata, metadata=metadata or {})
    validate_spatialperturb_schema(adata)
    return adata
def _read_delimited_table(path: Path) -> pd.DataFrame:
    """Load a CSV/TSV file, optionally gzip-compressed, with column 0 as index.

    The delimiter and compression are derived from the file name's suffixes,
    compared case-insensitively and as whole suffixes: a ``.tsv`` suffix
    selects tab (comma otherwise) and a trailing ``.gz`` selects gzip.
    """
    suffixes = [suffix.lower() for suffix in path.suffixes]
    sep = "\t" if ".tsv" in suffixes else ","
    # Lower-casing here keeps "COUNTS.TSV.GZ" from being read as plain text.
    compression = "gzip" if suffixes and suffixes[-1] == ".gz" else None
    return pd.read_csv(path, sep=sep, index_col=0, compression=compression)


def _read_expression_table(path: Path) -> pd.DataFrame:
    """Read a counts table whose first column holds the row labels.

    Parameters
    ----------
    path:
        ``.csv``/``.tsv`` file, optionally with a trailing ``.gz``.

    Returns
    -------
    The parsed table, indexed by its first column.
    """
    return _read_delimited_table(path)


def _read_cells_table(path: Path) -> pd.DataFrame:
    """Read a per-cell metadata table and cast its index labels to ``str``.

    Parameters
    ----------
    path:
        ``.csv``/``.tsv`` file, optionally with a trailing ``.gz``.

    Returns
    -------
    The parsed table, indexed by its first column with string labels so that
    it can be aligned against string observation names.
    """
    table = _read_delimited_table(path)
    table.index = table.index.astype(str)
    return table
def _decode_h5_strings(values: Any) -> list[str]:
    """Convert an array-like of HDF5 string values to a list of ``str``.

    ``bytes`` elements are decoded as UTF-8; every other element is passed
    through ``str()``.
    """
    return [
        item.decode("utf-8") if isinstance(item, bytes) else str(item)
        for item in np.asarray(values)
    ]
def _read_10x_feature_matrix_h5(matrix_path: Path, *, cells_path: Path | None, platform: str) -> AnnData:
    """Read a 10x-style feature-matrix HDF5 file into an AnnData object.

    Parameters
    ----------
    matrix_path:
        HDF5 file with a ``matrix`` group holding ``data``/``indices``/
        ``indptr``/``shape``, ``barcodes`` and a ``features`` sub-group.
    cells_path:
        Optional per-cell table; when given and present on disk it is
        reindexed to the barcodes and used as ``obs``.
    platform:
        Platform label recorded in the metadata handed to ``from_tables``.

    Returns
    -------
    The object built by ``from_tables``, with variable names made unique.

    Raises
    ------
    ImportError
        If ``h5py`` is not installed.
    """
    try:
        import h5py
    except ImportError as exc:
        raise ImportError("Reading 10x/Xenium H5 matrices requires the h5py package.") from exc

    with h5py.File(matrix_path, "r") as h5_file:
        group = h5_file["matrix"]
        stored_shape = tuple(int(dim) for dim in group["shape"][()])
        csc_parts = (
            group["data"][()],
            group["indices"][()],
            group["indptr"][()],
        )
        # The stored CSC matrix is transposed before being paired with the
        # barcodes as observations, then converted to CSR.
        counts = sparse.csc_matrix(csc_parts, shape=stored_shape).transpose().tocsr()

        barcodes = pd.Index(_decode_h5_strings(group["barcodes"][()]), dtype="object")

        feature_group = group["features"]
        symbols = pd.Index(_decode_h5_strings(feature_group["name"][()]), dtype="object")
        var_columns = {
            "feature_id": _decode_h5_strings(feature_group["id"][()]),
            "feature_type": _decode_h5_strings(feature_group["feature_type"][()]),
        }
        var = pd.DataFrame(var_columns, index=symbols)
        if "genome" in feature_group:
            var["genome"] = _decode_h5_strings(feature_group["genome"][()])

    if cells_path is not None and cells_path.exists():
        obs = _read_cells_table(cells_path).reindex(barcodes)
    else:
        obs = pd.DataFrame(index=barcodes)

    reader_metadata = {
        "platform": platform,
        "source_path": str(matrix_path.parent),
        "matrix_path": str(matrix_path),
        "cells_path": None if cells_path is None else str(cells_path),
        "reader": "10x_feature_matrix_h5",
    }
    adata = from_tables(counts, obs=obs, var=var, metadata=reader_metadata)
    adata.var_names_make_unique()
    return adata
def _read_platform_directory(path: str | Path, *, platform: str) -> AnnData:
    """Load a cell-level export from a file or a directory.

    The path is expanded and resolved, then handled as follows:

    * a ``.h5ad`` file is read directly and passed through the schema helpers;
    * a ``.h5`` file is read as a feature matrix, with a sibling
      ``cells.csv.gz`` or ``cells.csv`` used when present;
    * a directory containing ``cell_feature_matrix.h5`` is read the same way;
    * any other directory must contain a counts/matrix table and a
      cells/metadata table (``.csv``/``.tsv``, optionally ``.gz``).

    When a ``molecules`` table sits next to paired tables, only its path is
    recorded under ``adata.uns["spatialperturb"]["molecules"]``.

    Raises
    ------
    ValueError
        If ``path`` is a file of an unsupported type.
    FileNotFoundError
        If ``path`` does not exist or the paired tables cannot be located.
    """

    def first_existing(candidates: list[Path]) -> Path | None:
        # Candidates are probed in list order; the first hit wins.
        for candidate in candidates:
            if candidate.exists():
                return candidate
        return None

    directory = Path(path).expanduser().resolve()

    if directory.is_file():
        if directory.suffix == ".h5ad":
            adata = ad.read_h5ad(directory)
            ensure_spatialperturb_schema(adata, metadata={"platform": platform, "source_path": str(directory)})
            validate_spatialperturb_schema(adata)
            return adata
        if directory.name == "cell_feature_matrix.h5" or directory.suffix == ".h5":
            sibling_cells = first_existing(
                [
                    directory.parent / "cells.csv.gz",
                    directory.parent / "cells.csv",
                ]
            )
            return _read_10x_feature_matrix_h5(directory, cells_path=sibling_cells, platform=platform)
        raise ValueError(f"Unsupported file input for {platform}: {directory}")

    if not directory.exists():
        raise FileNotFoundError(directory)

    feature_matrix_path = directory / "cell_feature_matrix.h5"
    if feature_matrix_path.exists():
        bundled_cells = first_existing(
            [
                directory / "cells.csv.gz",
                directory / "cells.csv",
            ]
        )
        return _read_10x_feature_matrix_h5(feature_matrix_path, cells_path=bundled_cells, platform=platform)

    counts_path = first_existing(
        [
            directory / "counts.csv",
            directory / "counts.tsv",
            directory / "counts.csv.gz",
            directory / "counts.tsv.gz",
            directory / "matrix.csv",
            directory / "matrix.tsv",
            directory / "matrix.csv.gz",
            directory / "matrix.tsv.gz",
        ]
    )
    cells_path = first_existing(
        [
            directory / "cells.csv",
            directory / "cells.tsv",
            directory / "cells.csv.gz",
            directory / "cells.tsv.gz",
            directory / "metadata.csv",
            directory / "metadata.tsv",
            directory / "metadata.csv.gz",
            directory / "metadata.tsv.gz",
        ]
    )
    if counts_path is None or cells_path is None:
        raise FileNotFoundError(
            f"Could not locate paired counts/cells tables in {directory}. Expected files like counts.csv and cells.csv."
        )

    expression = _read_expression_table(counts_path)
    # Cell metadata follows the row order of the counts table.
    cells = _read_cells_table(cells_path).reindex(expression.index.astype(str))
    adata = from_tables(expression, obs=cells, metadata={"platform": platform, "source_path": str(directory)})

    molecules_path = first_existing(
        [
            directory / "molecules.csv",
            directory / "molecules.tsv",
            directory / "molecules.csv.gz",
            directory / "molecules.tsv.gz",
        ]
    )
    if molecules_path is not None:
        adata.uns.setdefault("spatialperturb", {})
        adata.uns["spatialperturb"]["molecules"] = {"path": str(molecules_path)}
    return adata
def _annotate_cell_groups(adata: AnnData, path: Path) -> None:
table = pd.read_csv(path)
if "cell_id" not in table.columns or "group" not in table.columns:
raise ValueError("cell_group_path must contain at least 'cell_id' and 'group' columns.")
mapping = dict(zip(table["cell_id"].astype(str), table["group"].astype(str), strict=False))
matched = 0
cell_types = []
for cell_id, current in zip(adata.obs_names.astype(str), adata.obs["cell_type"].astype(str), strict=False):
group = mapping.get(cell_id)
if group is None:
cell_types.append(current)
else:
cell_types.append(group)
matched += 1
adata.obs["cell_type"] = cell_types
adata.uns.setdefault("spatialperturb", {})
adata.uns["spatialperturb"]["cell_group_annotation"] = {"path": str(path), "matched_cells": matched}
def _roi_feature_name(feature: dict[str, Any], default: str) -> str:
    """Pick a display name for a GeoJSON feature from its properties.

    The keys ``assigned_structure``, ``name``, ``label``, ``group`` and
    ``roi`` are tried in that order; the first value that is not ``None`` and
    is not blank after stripping is returned as a string.

    Parameters
    ----------
    feature:
        A GeoJSON feature mapping. Its ``properties`` member may be missing
        or ``null``.
    default:
        Name returned when no usable property is found.
    """
    # GeoJSON permits `"properties": null`, so .get()'s default is not enough.
    properties = feature.get("properties") or {}
    for key in ("assigned_structure", "name", "label", "group", "roi"):
        value = properties.get(key)
        if value is not None and str(value).strip():
            return str(value)
    return default
def _polygon_outer_rings(geometry: dict[str, Any] | None) -> list[np.ndarray]:
    """Return the outer ring of every polygon in a GeoJSON geometry.

    Parameters
    ----------
    geometry:
        A GeoJSON geometry mapping, or ``None`` for features without one.

    Returns
    -------
    One float array per outer ring: a single entry for a ``Polygon``, one per
    non-empty member for a ``MultiPolygon``, and an empty list for any other
    geometry type or when no coordinates are present. Interior rings (holes)
    are not returned.
    """
    # GeoJSON allows `"geometry": null`; treat it like an empty geometry.
    if not geometry:
        return []
    geometry_type = str(geometry.get("type", ""))
    coordinates = geometry.get("coordinates") or []
    if geometry_type == "Polygon":
        return [np.asarray(coordinates[0], dtype=float)] if coordinates else []
    if geometry_type == "MultiPolygon":
        return [np.asarray(polygon[0], dtype=float) for polygon in coordinates if polygon]
    return []
def _annotate_roi_geojson(adata: AnnData, path: Path) -> None:
    """Label cells in ``adata.obs["roi"]`` by the GeoJSON polygon containing them.

    Each feature's outer rings are tested against ``adata.obsm["spatial"]``.
    Features are processed in file order and a cell keeps the first ROI it
    falls into; cells outside every ring keep their existing ``roi`` value.
    A summary (path, matched/total cell counts and per-ROI counts) is stored
    under ``adata.uns["spatialperturb"]["roi_annotation"]``.

    Parameters
    ----------
    adata:
        Object providing ``obsm["spatial"]`` and an ``obs["roi"]`` column.
    path:
        UTF-8 encoded GeoJSON file with a ``features`` list.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    # `"features": null` is treated the same as a missing key.
    features = payload.get("features") or []
    coords = np.asarray(adata.obsm["spatial"], dtype=float)
    rois = np.asarray(adata.obs["roi"].astype(str).tolist(), dtype=object)
    assigned = np.zeros(adata.n_obs, dtype=bool)

    for feature_idx, feature in enumerate(features):
        if not isinstance(feature, dict):
            continue
        roi_name = _roi_feature_name(feature, default=f"roi_{feature_idx + 1}")
        # GeoJSON allows `"geometry": null`; fall back to an empty geometry.
        geometry = feature.get("geometry") or {}
        for ring in _polygon_outer_rings(geometry):
            if ring.size == 0:
                continue
            if ring.ndim == 2 and ring.shape[1] > 2:
                # Positions may carry an optional elevation; only x/y are used.
                ring = ring[:, :2]
            # `~assigned` makes the first matching ROI win for each cell.
            mask = MplPath(ring, closed=True).contains_points(coords) & ~assigned
            rois[mask] = roi_name
            assigned[mask] = True

    adata.obs["roi"] = rois.astype(str)
    adata.uns.setdefault("spatialperturb", {})
    roi_counts = pd.Series(rois, dtype="object").value_counts().sort_index()
    adata.uns["spatialperturb"]["roi_annotation"] = {
        "path": str(path),
        "matched_cells": int(assigned.sum()),
        "total_cells": int(adata.n_obs),
        "roi_counts": {str(name): int(count) for name, count in roi_counts.items()},
    }
[docs]
def read_xenium(
    path: str | Path,
    *,
    cell_group_path: str | Path | None = None,
    roi_geojson_path: str | Path | None = None,
    sample_name: str | None = None,
    load_molecules: bool = False,
) -> AnnData:
    """Load a Xenium-style cell-level export as an AnnData object.

    Parameters
    ----------
    path:
        Export directory, or a single ``.h5ad``/``.h5`` file.
    cell_group_path:
        Optional CSV with ``cell_id`` and ``group`` columns used to overwrite
        ``obs["cell_type"]``.
    roi_geojson_path:
        Optional GeoJSON file whose polygons are used to fill ``obs["roi"]``.
    sample_name:
        When given, written to ``obs["sample"]`` and to
        ``uns["spatialperturb"]["sample_name"]``.
    load_molecules:
        When false, any ``molecules`` entry under ``uns["spatialperturb"]`` is
        removed. When true the entry is kept; this function itself does not
        read molecule data.

    Returns
    -------
    The annotated object after the schema helpers have been re-applied.
    """
    adata = _read_platform_directory(path, platform="xenium")

    if sample_name is not None:
        sample_label = str(sample_name)
        adata.obs["sample"] = sample_label
        adata.uns.setdefault("spatialperturb", {})
        adata.uns["spatialperturb"]["sample_name"] = sample_label

    # Optional annotation passes, applied in a fixed order: groups, then ROIs.
    annotation_steps = (
        (cell_group_path, _annotate_cell_groups),
        (roi_geojson_path, _annotate_roi_geojson),
    )
    for annotation_path, annotate in annotation_steps:
        if annotation_path is not None:
            annotate(adata, Path(annotation_path))

    if not load_molecules:
        spatialperturb_meta = adata.uns.get("spatialperturb", {})
        spatialperturb_meta.pop("molecules", None)

    ensure_spatialperturb_schema(adata, metadata={"platform": "xenium"})
    validate_spatialperturb_schema(adata)
    return adata
[docs]
def read_stereoseq(path: str | Path, **_: Any) -> AnnData:
    """Load a Stereo-seq-style cell-level export as an AnnData object.

    Parameters
    ----------
    path:
        Export directory or file, handled by ``_read_platform_directory``
        with the platform set to ``"stereoseq"``.
    **_:
        Extra keyword arguments are accepted and ignored.
    """
    adata = _read_platform_directory(path, platform="stereoseq")
    return adata