"""Consistent-Trees reader.
Inputs (the Rockstar / Consistent-Trees pipeline):
- ``input_path`` : directory containing the ``tree_*.dat`` files
- ``forests_path`` : ``forests.list`` (tree_root_id forest_id [weight])
- ``locations_path`` : ``locations.dat`` (tree_root_id file_id offset filename)
Column lookup is by *name* (parsed from the ``#scale(0) id(1) ...`` header
line) rather than by hardcoded index — this fixes a latent fragility in the
legacy C tool, which silently broke if a column was added or reordered.
"""
from __future__ import annotations
import re
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar, TextIO
import numpy as np
from astrosylva.exceptions import ReaderError
from astrosylva.readers.base import ReaderSource, TreeReader
from astrosylva.schema import DEFAULT_UNITS, HALO_DTYPE, Forest, Metadata
@dataclass
class _ForestEntry:
forest_id: int
weight: float
@dataclass
class _TreeRef:
tree_root_id: int
forest_id: int
forest_weight: float
file_id: int
offset: int
file_path: Path
@dataclass
class _ForestRef:
forest_id: int
weight: float
trees: list[_TreeRef] = field(default_factory=list)
# Column name -> dtype field name. Mass / scale-radius source keys are
# resolved at runtime based on reader options.
_FIXED_COLUMN_MAP: dict[str, str] = {
"scale": "expansionFactor",
"id": "nodeIndex",
"desc_id": "descendantIndex",
"mvir": "nodeMass",
"x": "position.0",
"y": "position.1",
"z": "position.2",
"vx": "velocity.0",
"vy": "velocity.1",
"vz": "velocity.2",
"Jx": "angularMomentum.0",
"Jy": "angularMomentum.1",
"Jz": "angularMomentum.2",
"Spin": "spin",
}
# Consistent-Trees emits scale radius in *kpc/h* even though all other length
# quantities are Mpc/h. Galacticus expects Mpc/h throughout.
_RS_KPC_TO_MPC = 1.0 / 1000.0
_HEADER_COL_RE = re.compile(r"^([A-Za-z_][A-Za-z0-9_]*)\(\d+\)$")
def _parse_header_columns(line: str) -> dict[str, int]:
"""Parse a ``#scale(0) id(1) ... TrailingName ...`` CT header line.
Two token shapes are recognised:
- ``name(N)`` — explicit 0-based position; stored at index ``N``.
- Plain identifier (no parens) appearing *after* the first indexed
token — inherits the next sequential index. CT builds often tack
extra columns like ``Rs_Klypin``, ``Mvir_all``, ``M200b`` onto
the end of the header without ``(N)`` annotations; this keeps
them addressable by name.
Tokens with parens but not matching ``name(N)`` (e.g.
``b_to_a(500c)``, ``mmp?(14)``) still advance the position
counter — they occupy a slot in the data row — but their names are
not stored.
"""
tokens = line.lstrip("#").split()
cols: dict[str, int] = {}
next_index: int | None = None
for tok in tokens:
m = _HEADER_COL_RE.match(tok)
if m is not None:
idx = int(tok[tok.index("(") + 1 : tok.index(")")])
cols[m.group(1)] = idx
next_index = idx + 1
continue
if next_index is None:
# Pre-indexed garbage (shouldn't happen in well-formed CT).
continue
if "(" not in tok and ")" not in tok:
cols[tok] = next_index
next_index += 1
return cols
def _find_col(cols: dict[str, int], name: str) -> int:
"""Look up a column position by case-insensitive name match.
Different CT versions vary on capitalisation (``Mvir`` vs ``mvir``,
``Rs_Klypin`` vs ``rs_klypin``); the reader's canonical lookups
use lowercase and resolve against whatever the file has.
"""
lower = name.lower()
for key, value in cols.items():
if key.lower() == lower:
return value
raise KeyError(name)
def _parse_cosmology_header(lines: list[str]) -> dict[str, float]:
"""Extract cosmological parameters from CT header comments.
Recognised tokens (case-insensitive variants accepted): ``Omega_M``,
``Omega_L``, ``Omega_b`` / ``Omega_B``, ``h0`` / ``h``, ``sigma_8``
/ ``sigma8``. Mapped to the Galacticus attribute names ``Omega0``,
``OmegaLambda``, ``OmegaBaryon``, ``HubbleParam``, ``sigma_8``.
Anything we don't recognise is silently ignored; users can still
fill those gaps via ``metadata.cosmology``.
"""
out: dict[str, float] = {}
pattern = re.compile(r"\b(Omega_M|Omega_L|Omega_[bB]|h0|h|sigma_?8)\s*=\s*([0-9.eE+-]+)")
key_map = {
"Omega_M": "Omega0",
"Omega_L": "OmegaLambda",
"Omega_b": "OmegaBaryon",
"Omega_B": "OmegaBaryon",
"h0": "HubbleParam",
"h": "HubbleParam",
"sigma_8": "sigma_8",
"sigma8": "sigma_8",
}
for line in lines:
for key, value in pattern.findall(line):
out[key_map[key]] = float(value)
return out
def _parse_box_size(lines: list[str]) -> float | None:
pattern = re.compile(r"Full box size\s*=\s*([0-9.eE+-]+)\s*Mpc")
for line in lines:
m = pattern.search(line)
if m:
return float(m.group(1))
return None
class ConsistentTreesReader(TreeReader):
    """Reader for the Consistent-Trees output of the Rockstar pipeline.

    The source must provide ``input_path``, ``forests_path`` and
    ``locations_path``. Two options are understood:

    - ``host_source``: ``"pid"`` (default) or ``"upid"`` — the column
      copied into ``hostIndex``.
    - ``scale_radius_source``: ``"rs"`` (default) or ``"rs_klypin"`` — the
      column copied into ``scaleRadius``.

    Indexing of ``forests.list`` / ``locations.dat`` is lazy: it happens on
    the first ``len()`` or iteration.
    """

    name: ClassVar[str] = "consistent_trees"
    aliases: ClassVar[tuple[str, ...]] = ("consistent-trees", "ctrees")

    def __init__(self, source: ReaderSource, options: dict[str, Any] | None = None) -> None:
        """Validate the options and record the input paths; no file is read.

        Raises:
            ReaderError: If ``host_source`` or ``scale_radius_source`` has an
                unsupported value.
        """
        super().__init__(source, options)
        self._input_path = Path(source.require("input_path"))
        self._forests_path = Path(source.require("forests_path"))
        self._locations_path = Path(source.require("locations_path"))
        self._host_source: str = self.options.get("host_source", "pid")
        self._scale_radius_source: str = self.options.get("scale_radius_source", "rs")
        if self._host_source not in ("pid", "upid"):
            raise ReaderError(f"host_source must be 'pid' or 'upid', got {self._host_source!r}")
        if self._scale_radius_source not in ("rs", "rs_klypin"):
            raise ReaderError(
                "scale_radius_source must be 'rs' or 'rs_klypin', got "
                f"{self._scale_radius_source!r}"
            )
        # Lazily-built indexes and per-file caches.
        self._tree_index: list[_TreeRef] | None = None
        self._forest_index: list[_ForestRef] | None = None
        self._header_cache: dict[Path, dict[str, int]] = {}
        self._cosmo_cache: dict[str, float] | None = None
        self._box_size_cache: float | None = None

    # ------------------------------------------------------------ public API
    def defaults(self) -> Metadata:
        """The four /forestHalos flags the legacy C tool always emitted.

        These match the values in the original parameter.cfg shipped with
        rockstar2galacticus. They're true for any standard Rockstar /
        Consistent-Trees run; users with different conventions can
        override per-key via ``metadata.haloTrees`` in their YAML.
        """
        return Metadata(
            halo_trees={
                "haloMassesIncludeSubhalos": 1,
                "forestsAreSelfContained": 1,
                "treesHaveSubhalos": 1,
                "velocitiesIncludeHubbleFlow": 0,
            }
        )

    def __len__(self) -> int:
        """Number of forests (distinct forest ids among the located trees)."""
        return len(self._ensure_indexed())

    def __iter__(self) -> Iterator[Forest]:
        """Yield forests one at a time, loading each from disk on demand."""
        for forest_ref in self._ensure_indexed():
            yield self._load_forest(forest_ref)

    # ----------------------------------------------------------- indexing
    def _ensure_indexed(self) -> list[_ForestRef]:
        """Build the tree / forest indexes on first use and return the forests."""
        if self._forest_index is None:
            tree_to_forest = self._read_forests_list()
            trees = self._read_locations(tree_to_forest)
            self._tree_index = trees
            self._forest_index = self._group_trees_into_forests(trees)
            self._cache_header_metadata(trees)
        return self._forest_index

    def _read_forests_list(self) -> dict[int, _ForestEntry]:
        """Parse ``forests.list`` into ``tree_root_id -> (forest_id, weight)``.

        Blank lines, ``#`` comments and lines whose first field is not an
        integer (column headers) are skipped. A missing weight column means
        a weight of 1.0.

        Raises:
            ReaderError: If a data line lacks a forest id or has an
                unparsable forest id / weight.
        """
        out: dict[int, _ForestEntry] = {}
        with self._forests_path.open() as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                try:
                    tree_root_id = int(parts[0])
                except ValueError:
                    continue  # header line
                try:
                    forest_id = int(parts[1])
                    weight = float(parts[2]) if len(parts) >= 3 else 1.0
                except (IndexError, ValueError) as exc:
                    raise ReaderError(
                        f"Malformed entry at {self._forests_path}:{lineno}: {line!r}"
                    ) from exc
                out[tree_root_id] = _ForestEntry(forest_id, weight)
        return out

    def _read_locations(self, tree_to_forest: dict[int, _ForestEntry]) -> list[_TreeRef]:
        """Parse ``locations.dat`` and attach each tree's forest assignment.

        Blank lines, ``#`` comments, lines with fewer than four fields and
        lines whose first field is not an integer are skipped.

        Raises:
            ReaderError: If a tree has no entry in ``tree_to_forest``, or its
                file id / offset is not an integer.
        """
        trees: list[_TreeRef] = []
        with self._locations_path.open() as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) < 4:
                    continue
                try:
                    tree_root_id = int(parts[0])
                except ValueError:
                    continue  # header line
                entry = tree_to_forest.get(tree_root_id)
                if entry is None:
                    raise ReaderError(
                        f"tree_root_id {tree_root_id} in locations.dat has no "
                        "matching entry in forests.list"
                    )
                try:
                    file_id = int(parts[1])
                    offset = int(parts[2])
                except ValueError as exc:
                    raise ReaderError(
                        f"Malformed entry at {self._locations_path}:{lineno}: {line!r}"
                    ) from exc
                trees.append(
                    _TreeRef(
                        tree_root_id=tree_root_id,
                        forest_id=entry.forest_id,
                        forest_weight=entry.weight,
                        file_id=file_id,
                        offset=offset,
                        file_path=self._input_path / parts[3],
                    )
                )
        return trees

    @staticmethod
    def _group_trees_into_forests(trees: list[_TreeRef]) -> list[_ForestRef]:
        """Group trees by forest id, keeping first-seen order of both forests
        and the trees inside each forest. The forest weight is taken from the
        first tree seen for that forest."""
        forests: dict[int, _ForestRef] = {}
        for t in trees:
            fr = forests.get(t.forest_id)
            if fr is None:
                fr = _ForestRef(forest_id=t.forest_id, weight=t.forest_weight, trees=[])
                forests[t.forest_id] = fr
            fr.trees.append(t)
        return list(forests.values())

    def _cache_header_metadata(self, trees: list[_TreeRef]) -> None:
        """Parse cosmology and box size from the first tree file's preamble.

        Only the leading run of ``#`` lines (at most 80) is examined. Does
        nothing when there are no trees.
        """
        if not trees:
            return
        first_file = trees[0].file_path
        header_lines: list[str] = []
        with first_file.open() as fh:
            for _ in range(80):
                line = fh.readline()
                if not line or not line.startswith("#"):
                    break
                header_lines.append(line)
        self._cosmo_cache = _parse_cosmology_header(header_lines)
        self._box_size_cache = _parse_box_size(header_lines)

    def _column_map(self, file_path: Path) -> dict[str, int]:
        """Return the ``name -> index`` column map of ``file_path`` (cached).

        The header is the first ``#`` line containing both ``(0)`` and
        ``id(``. The search stops at the first non-blank line that is not a
        comment: the column header belongs to the leading comment preamble,
        so there is no point scanning a whole tree file for it.

        Raises:
            ReaderError: If no column header is found in the preamble.
        """
        cached = self._header_cache.get(file_path)
        if cached is not None:
            return cached
        header_line: str | None = None
        with file_path.open() as fh:
            for line in fh:
                if not line.startswith("#"):
                    if line.strip():
                        break  # first data line: the preamble is over
                    continue
                if "(0)" in line and "id(" in line:
                    header_line = line
                    break
        if header_line is None:
            raise ReaderError(f"Could not find column header in {file_path}")
        cols = _parse_header_columns(header_line)
        self._header_cache[file_path] = cols
        return cols

    # ----------------------------------------------------------- loading
    def _load_forest(self, forest_ref: _ForestRef) -> Forest:
        """Materialise one forest, opening each tree file at most once.

        Trees in a forest commonly live in the same ``tree_*.dat``
        file; iterating per-tree would reopen it N times. We group
        trees by file, open each file once, read its trees in offset
        order (sequential I/O), then restore the original tree order in
        the concatenated output.
        """
        if not forest_ref.trees:
            return Forest(
                forest_id=forest_ref.forest_id,
                halos=np.empty(0, dtype=HALO_DTYPE),
                weight=forest_ref.weight,
            )
        trees_by_file: dict[Path, list[tuple[int, _TreeRef]]] = {}
        for i, tree in enumerate(forest_ref.trees):
            trees_by_file.setdefault(tree.file_path, []).append((i, tree))
        indexed_halos: list[tuple[int, np.ndarray]] = []
        for file_path, indexed_trees in trees_by_file.items():
            # Process in offset order to keep file reads sequential.
            indexed_trees.sort(key=lambda it: it[1].offset)
            with file_path.open() as fh:
                for i, tree in indexed_trees:
                    indexed_halos.append((i, self._load_tree(fh, tree)))
        # Restore original tree order before concatenating.
        indexed_halos.sort(key=lambda ih: ih[0])
        halos = np.concatenate([tree_halos for _, tree_halos in indexed_halos])
        return Forest(
            forest_id=forest_ref.forest_id,
            halos=halos,
            weight=forest_ref.weight,
        )

    def _resolve_columns_for(self, tree: _TreeRef) -> dict[str, int]:
        """Map every column the reader needs to its 0-based index in the
        tree file.

        Besides the fixed columns, the result has a ``"host"`` entry (the
        configured ``host_source`` column) and an ``"rs"`` entry (the
        configured ``scale_radius_source`` column).

        Raises:
            ReaderError: On any missing column, with a message naming it.
        """
        cols = self._column_map(tree.file_path)
        required = (
            "scale",
            "id",
            "desc_id",
            "mvir",
            "x",
            "y",
            "z",
            "vx",
            "vy",
            "vz",
            "Jx",
            "Jy",
            "Jz",
            "Spin",
        )
        out: dict[str, int] = {}
        for name in required:
            try:
                out[name] = _find_col(cols, name)
            except KeyError as exc:
                raise ReaderError(
                    f"Tree file {tree.file_path} is missing required column {name!r}"
                ) from exc
        try:
            out["host"] = _find_col(cols, self._host_source)
        except KeyError as exc:
            raise ReaderError(
                f"Tree file {tree.file_path} has no {self._host_source!r} column "
                "for host_source; pick a different host_source."
            ) from exc
        # Already validated in __init__ to be "rs" or "rs_klypin".
        rs_key = self._scale_radius_source
        try:
            out["rs"] = _find_col(cols, rs_key)
        except KeyError as exc:
            raise ReaderError(
                f"Tree file {tree.file_path} has no {rs_key!r} column; "
                "set options.scale_radius_source accordingly."
            ) from exc
        return out

    def _load_tree(self, fh: TextIO, tree: _TreeRef) -> np.ndarray:
        """Parse one tree's halos from an already-open file handle.

        Caller is responsible for opening / closing ``fh``; this method
        seeks to ``tree.offset`` and reads forward until the next ``#``
        line or EOF, skipping blank lines.

        Returns:
            A ``HALO_DTYPE`` array with one entry per data row (empty when
            the tree has no rows). ``halfMassRadius`` is filled with NaN and
            ``scaleRadius`` is converted from kpc/h to Mpc/h.

        Raises:
            ReaderError: If a required column is missing from the file, or a
                data row is too short or holds an unparsable value.
        """
        ci = self._resolve_columns_for(tree)

        # NOTE(review): ``tree.offset`` comes straight from locations.dat and
        # is handed to a text-mode seek — presumably a byte offset; confirm
        # this is safe for the encodings / platforms in use.
        rows: list[list[str]] = []
        fh.seek(tree.offset)
        for line in fh:
            if not line.strip():
                continue
            if line.startswith("#"):
                break
            rows.append(line.split())
        if not rows:
            return np.empty(0, dtype=HALO_DTYPE)

        halos = np.empty(len(rows), dtype=HALO_DTYPE)
        # Bind every field view once; indexing ``halos[name]`` inside the row
        # loop would rebuild the view for each assignment.
        scalar_targets = (
            (halos["nodeIndex"], ci["id"], int),
            (halos["descendantIndex"], ci["desc_id"], int),
            (halos["hostIndex"], ci["host"], int),
            (halos["expansionFactor"], ci["scale"], float),
            (halos["nodeMass"], ci["mvir"], float),
            (halos["spin"], ci["Spin"], float),
        )
        vector_targets = (
            (halos["position"], (ci["x"], ci["y"], ci["z"])),
            (halos["velocity"], (ci["vx"], ci["vy"], ci["vz"])),
            (halos["angularMomentum"], (ci["Jx"], ci["Jy"], ci["Jz"])),
        )
        scale_radius = halos["scaleRadius"]
        i_rs = ci["rs"]
        halos["halfMassRadius"] = np.nan  # not provided by this reader
        try:
            for k, row in enumerate(rows):
                for dest, col, convert in scalar_targets:
                    dest[k] = convert(row[col])
                scale_radius[k] = float(row[i_rs]) * _RS_KPC_TO_MPC
                for dest, axis_cols in vector_targets:
                    for axis, col in enumerate(axis_cols):
                        dest[k, axis] = float(row[col])
        except (IndexError, ValueError) as exc:
            raise ReaderError(
                f"Malformed halo row in {tree.file_path} "
                f"(tree_root_id {tree.tree_root_id}, offset {tree.offset})"
            ) from exc

        # Galacticus convention: nodes with no parent have hostIndex==nodeIndex,
        # not -1.
        no_host = halos["hostIndex"] == -1
        halos["hostIndex"][no_host] = halos["nodeIndex"][no_host]
        return halos