# Source code for astrosylva.readers.ahf

"""AHF (Amiga Halo Finder) reader.

AHF emits one ``.AHF_halos`` catalogue per snapshot plus ``.AHF_mtree``
(or ``.AHF_mtree_idx``) files linking halos across snapshots. To produce
merger trees in Galacticus format we walk pairs of snapshots, stitch
descendant pointers, and partition the resulting halo set into forests
via union-find on descendant + host edges.

Source keys
-----------

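- ``snapshots`` : list of ``{halos: path, mtree: path | null, a: float}``
  ordered from earliest to latest. ``a`` is the snapshot's expansion
  factor. ``mtree`` on entry *i* links the halos of snapshot *i* to their
  descendants in snapshot *i + 1*, so the last snapshot has ``mtree: null``
  (any value given there is ignored). An earlier entry with ``mtree: null``
  simply leaves its halos without descendants.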

mtree formats
-------------

Two AHF mtree variants are recognised:

- ``idx``   — two columns per line: ``ProgID DescID``  (the
              ``.AHF_mtree_idx`` format).
- ``block`` — blocks of ``DescID HaloPart NumProgenitors`` followed by
              ``NumProgenitors`` lines of ``SharedPart ProgID HaloPart``
              (the standard ``.AHF_mtree`` format).

``options.mtree_format`` selects ``"auto"`` (default), ``"idx"``, or
``"block"``. Auto-detect picks ``idx`` when every record is two
tokens wide and ``block`` when every record is at least three; mixed
widths raise.

Units
-----

Standard AHF emits positions in Mpc/h, Rvir in kpc/h, and Mvir in
M_sun/h; the defaults match. Override per build via ``options.units``:

.. code-block:: yaml

   options:
     units:
       position: kpc/h   # default Mpc/h
       radius:   Mpc/h   # default kpc/h
       mass:     Msun    # default Msun/h; requires options.hubble_h
     hubble_h: 0.704     # only needed when units.mass = "Msun"

Supported length units: ``Mpc/h``, ``kpc/h``. Supported mass units:
``Msun/h`` (default), ``Msun`` (physical solar masses; the reader
multiplies by ``options.hubble_h`` to land in canonical M_sun/h).
Velocities are always taken to be km/s and are never rescaled; every AHF
build we have seen writes them in km/s.

Status: experimental. The column layout is auto-detected from the
``.AHF_halos`` header (``#name(1) name(2) ...``). Any column the header
does not name falls back to the default layout defined in this module.
Override individual columns via ``options.columns`` (a mapping of column
name to 0-based index). Halo IDs are expected to be globally unique
across snapshots (standard AHF behaviour).
"""

from __future__ import annotations

import re
from collections.abc import Iterator
from pathlib import Path
from typing import Any, ClassVar

import numpy as np

from astrosylva.exceptions import ReaderError
from astrosylva.readers._forests import clamp_hosts_to_forest, group_by_union_find
from astrosylva.readers.base import TreeReader
from astrosylva.schema import DEFAULT_UNITS, HALO_DTYPE, Forest, Metadata

# 0-based positions of the columns we use in a common AHF .AHF_halos layout.
# Each value is the 1-based ``name(N)`` header position minus one, i.e. the
# same index ``_parse_ahf_header`` would derive from a header line (the
# header label is noted beside each entry). Real AHF builds vary, so this is
# only a fallback: auto-detected header values and ``options.columns``
# overrides take precedence.
_AHF_DEFAULT_COLUMNS: dict[str, int] = {
    "ID": 0,  # ID(1)
    "hostHalo": 1,  # hostHalo(2)
    "Mvir": 3,  # Mvir(4)
    "Xc": 5,  # Xc(6)
    "Yc": 6,  # Yc(7)
    "Zc": 7,  # Zc(8)
    "VXc": 8,  # VXc(9)
    "VYc": 9,  # VYc(10)
    "VZc": 10,  # VZc(11)
    "Rvir": 11,  # Rvir(12)
    "lambda": 19,  # lambda(20) -- must agree with what header detection yields
    "Lx": 21,  # Lx(22)
    "Ly": 22,  # Ly(23)
    "Lz": 23,  # Lz(24)
}
# Every column of the default layout is needed to fill a halo record.
_AHF_REQUIRED_COLUMNS = frozenset(_AHF_DEFAULT_COLUMNS)

# Matches the leading ``name(N)`` part of a header token. ``re.match`` only
# anchors at the start, so trailing text after ``(N)`` (for example a
# ``[Msun/h]`` unit annotation) does not prevent a match.
_AHF_HEADER_TOKEN = re.compile(r"^([A-Za-z_][A-Za-z0-9_]*)\((\d+)\)")


def _parse_ahf_header(line: str) -> dict[str, int]:
    """Parse ``#name(1) other(2) ...`` into a ``{name: 0-based-index}`` map.

    Leading ``#`` characters are stripped and the rest is split on
    whitespace. Each token only needs to *start* with ``name(N)``, so a
    unit-annotated token such as ``Mvir(4)[Msun/h]`` still yields
    ``{"Mvir": 3}``. Tokens that do not start with ``name(N)`` (plain
    words, a detached ``[Msun/h]``, ...) are silently skipped. The user can
    fill any resulting gaps with ``options.columns``.

    ``name(0)`` is also skipped. AHF header positions are 1-based, so it
    would turn into index -1, which Python indexing reads as the last field
    of the row. When two tokens claim the same position, the later one
    replaces the earlier one.
    """
    by_position: dict[int, str] = {}
    for tok in line.lstrip("#").split():
        m = _AHF_HEADER_TOKEN.match(tok)
        if m is None:
            continue
        pos = int(m.group(2)) - 1  # AHF headers are 1-based
        if pos < 0:
            continue
        by_position[pos] = m.group(1)
    return {name: pos for pos, name in by_position.items()}


# Values accepted for ``options.mtree_format``.
_VALID_MTREE_FORMATS = ("auto", "idx", "block")

# Length units the reader understands, each mapped to the factor that turns
# a value expressed in that unit into the canonical Mpc/h.
_LENGTH_UNIT_TO_MPCH: dict[str, float] = {
    "kpc/h": 1e-3,
    "Mpc/h": 1.0,
}

# Accepted ``units.mass`` strings. ``Msun`` additionally needs
# ``options.hubble_h``: going from M_sun to M_sun/h means multiplying by h.
_MASS_UNITS = ("Msun/h", "Msun")

# Unit assumed for each quantity when ``options.units`` does not override
# it (positions in Mpc/h, Rvir in kpc/h, Mvir in M_sun/h).
_DEFAULT_UNITS: dict[str, str] = {
    "position": "Mpc/h",
    "radius": "kpc/h",
    "mass": "Msun/h",
}


def _length_scale_factor(value: str, *, key: str) -> float:
    """Return the multiplier that converts ``value``-unit lengths to Mpc/h.

    ``key`` names the option being resolved (e.g. ``"units.position"``)
    and only appears in the error message.

    Raises:
        ReaderError: when ``value`` is not one of the supported length units.
    """
    factor = _LENGTH_UNIT_TO_MPCH.get(value)
    if factor is None:
        choices = sorted(_LENGTH_UNIT_TO_MPCH)
        raise ReaderError(f"Unknown length unit {value!r} for {key}; expected one of {choices}")
    return factor


class AHFReader(TreeReader):
    """Reader for AHF halo catalogues + merger-tree files.

    Options are validated eagerly in ``__init__``. The halo and mtree files
    are only read on the first ``len()`` or iteration; the resulting forests
    are then cached on the instance.
    """

    name: ClassVar[str] = "ahf"
    aliases: ClassVar[tuple[str, ...]] = ("amiga",)

    def __init__(self, source: Any, options: dict[str, Any] | None = None) -> None:
        """Validate the reader options. No files are opened here.

        Raises:
            ReaderError: if ``mtree_format``, ``columns``, ``units`` or
                ``hubble_h`` is invalid. This includes negative column
                indices and a non-finite ``hubble_h``.
        """
        super().__init__(source, options)
        self._mtree_format: str = self.options.get("mtree_format", "auto")
        if self._mtree_format not in _VALID_MTREE_FORMATS:
            raise ReaderError(
                f"mtree_format must be one of 'auto', 'idx', 'block'; got {self._mtree_format!r}"
            )

        overrides = self.options.get("columns", {}) or {}
        if not isinstance(overrides, dict):
            raise ReaderError(
                "options.columns must be a mapping of column-name to 0-based index; "
                f"got {type(overrides).__name__}"
            )
        try:
            self._column_overrides: dict[str, int] = {str(k): int(v) for k, v in overrides.items()}
        except (TypeError, ValueError) as exc:
            raise ReaderError("options.columns values must be integer column indices") from exc
        # A negative index would not fail later on. Python indexing would
        # silently read a field counted from the end of each row.
        negative = sorted(name for name, idx in self._column_overrides.items() if idx < 0)
        if negative:
            raise ReaderError(
                "options.columns values must be non-negative 0-based indices; "
                f"got negative indices for {negative}"
            )

        units = self.options.get("units", {}) or {}
        if not isinstance(units, dict):
            raise ReaderError(
                "options.units must be a mapping of quantity to unit string; "
                f"got {type(units).__name__}"
            )
        unknown = set(units.keys()) - set(_DEFAULT_UNITS.keys())
        if unknown:
            raise ReaderError(
                f"Unknown options.units keys: {sorted(unknown)}; "
                f"expected a subset of {sorted(_DEFAULT_UNITS)}"
            )
        # Multipliers applied to raw catalogue values to reach canonical units.
        self._position_scale = _length_scale_factor(
            units.get("position", _DEFAULT_UNITS["position"]), key="units.position"
        )
        self._radius_scale = _length_scale_factor(
            units.get("radius", _DEFAULT_UNITS["radius"]), key="units.radius"
        )
        self._mass_scale = self._resolve_mass_scale(units.get("mass", _DEFAULT_UNITS["mass"]))

    def _resolve_mass_scale(self, value: str) -> float:
        """Multiplier that converts the source mass unit to canonical M_sun/h.

        ``Msun/h`` passes through (factor 1.0). ``Msun`` requires
        ``options.hubble_h`` so we can multiply by h.

        Raises:
            ReaderError: on an unknown unit, or when ``hubble_h`` is missing,
                non-numeric, non-finite, or not positive.
        """
        if value not in _MASS_UNITS:
            raise ReaderError(
                f"Unknown mass unit {value!r} for units.mass; expected one of {list(_MASS_UNITS)}"
            )
        if value == "Msun/h":
            return 1.0
        # value == "Msun"
        if "hubble_h" not in self.options:
            raise ReaderError(
                "options.units.mass='Msun' requires options.hubble_h "
                "so the reader can convert M_sun -> M_sun/h."
            )
        try:
            h = float(self.options["hubble_h"])
        except (TypeError, ValueError) as exc:
            raise ReaderError(
                f"options.hubble_h must be a number; got {self.options['hubble_h']!r}"
            ) from exc
        # float() accepts "nan" and "inf"; NaN would also slip past ``h <= 0``.
        if not np.isfinite(h):
            raise ReaderError(f"options.hubble_h must be finite; got {h}")
        if h <= 0:
            raise ReaderError(f"options.hubble_h must be > 0; got {h}")
        return h

    def metadata(self) -> Metadata:
        """Return metadata whose units are a copy of the schema's ``DEFAULT_UNITS``."""
        return Metadata(units=dict(DEFAULT_UNITS))

    def __len__(self) -> int:
        """Number of forests; triggers loading on first use."""
        self._ensure_loaded()
        assert self._forests is not None
        return len(self._forests)

    def __iter__(self) -> Iterator[Forest]:
        """Yield the cached forests; triggers loading on first use."""
        self._ensure_loaded()
        assert self._forests is not None
        yield from self._forests

    @staticmethod
    def _snapshot_halos_and_a(index: int, snap: Any) -> tuple[Path, float]:
        """Return ``(halos path, expansion factor)`` for snapshot entry ``index``.

        Raises:
            ReaderError: if the entry lacks ``halos``/``a`` or they have the
                wrong type, rather than leaking a bare KeyError/TypeError.
        """
        try:
            return Path(snap["halos"]), float(snap["a"])
        except (KeyError, TypeError, ValueError) as exc:
            raise ReaderError(
                f"AHF snapshot {index} must define 'halos' (a path) and 'a' "
                f"(the expansion factor); got {snap!r}"
            ) from exc

    def _ensure_loaded(self) -> None:
        """Build and cache ``self._forests`` on first use; later calls are no-ops.

        Raises:
            ReaderError: if no snapshots are configured, a snapshot entry is
                malformed, or a referenced halo/mtree file is missing or
                malformed.
        """
        if getattr(self, "_forests", None) is not None:
            return
        snapshots: list[dict[str, Any]] = self.source.require("snapshots")
        if not snapshots:
            raise ReaderError("AHF reader requires at least one snapshot")

        # Load halos snapshot-by-snapshot, oldest first.
        per_snap: list[np.ndarray] = []
        for i, snap in enumerate(snapshots):
            halos_path, a = self._snapshot_halos_and_a(i, snap)
            per_snap.append(self._load_halo_catalogue(halos_path, a))

        # Build descendant links via .AHF_mtree (current -> next snapshot).
        # The last snapshot has no later snapshot, so its mtree is never used.
        for i, snap in enumerate(snapshots[:-1]):
            mtree_path = snap.get("mtree")
            if mtree_path is None:
                continue
            self._apply_mtree(Path(mtree_path), per_snap[i])

        # ``snapshots`` is non-empty, so ``per_snap`` is too.
        halos = np.concatenate(per_snap)

        # Partition into self-contained forests via union-find on
        # descendant + host edges. ``root_desc`` is just nodeIndex
        # because AHF has no separate root-descendant concept.
        forest_index = group_by_union_find(
            halos["nodeIndex"],
            halos["nodeIndex"],
            halos["descendantIndex"],
            halos["hostIndex"],
        )
        forests: list[Forest] = []
        for forest_id, indices in forest_index.items():
            forest_halos = halos[indices].copy()
            forest_halos["hostIndex"] = clamp_hosts_to_forest(
                forest_halos["hostIndex"], forest_halos["nodeIndex"]
            )
            forests.append(Forest(forest_id=forest_id, halos=forest_halos))
        self._forests = forests

    def _load_halo_catalogue(self, path: Path, a: float) -> np.ndarray:
        """Read one ``.AHF_halos`` file into a ``HALO_DTYPE`` array.

        Masses, radii and positions are multiplied by the scale factors
        resolved in ``__init__``. Velocities and angular momenta are stored
        as read. ``descendantIndex`` starts at -1 (it is filled in later from
        the mtree files) and ``expansionFactor`` is ``a`` for every halo.

        Raises:
            ReaderError: if the file is missing, a row is too short for the
                resolved column layout, or a used field is not numeric.
        """
        if not path.is_file():
            raise ReaderError(f"AHF halos file not found: {path}")
        header_columns, rows = _read_ahf_halos_file(path)
        row_width = len(rows[0]) if rows else 0
        c = self._resolve_columns(header_columns, path, row_width)
        max_index = max(c.values()) if c else -1
        for k, row in enumerate(rows):
            if len(row) <= max_index:
                raise ReaderError(
                    f"AHF row {k} in {path} has {len(row)} fields but column "
                    f"layout needs at least {max_index + 1}."
                )

        def column_values(name: str, convert: type, dtype: Any) -> np.ndarray:
            # Convert one named column for every row; malformed text -> ReaderError.
            idx = c[name]
            try:
                return np.array([convert(row[idx]) for row in rows], dtype=dtype)
            except ValueError as exc:
                raise ReaderError(
                    f"Non-numeric {name!r} value (column {idx}) in {path}: {exc}"
                ) from exc

        halos = np.empty(len(rows), dtype=HALO_DTYPE)
        halos["nodeIndex"] = column_values("ID", int, np.int64)
        halos["descendantIndex"] = -1
        halos["hostIndex"] = column_values("hostHalo", int, np.int64)
        halos["expansionFactor"] = a
        halos["nodeMass"] = column_values("Mvir", float, np.float64) * self._mass_scale
        halos["scaleRadius"] = column_values("Rvir", float, np.float64) * self._radius_scale
        halos["halfMassRadius"] = np.nan  # AHF default output lacks half-mass radius
        for axis, key in enumerate(("Xc", "Yc", "Zc")):
            halos["position"][:, axis] = (
                column_values(key, float, np.float64) * self._position_scale
            )
        for axis, key in enumerate(("VXc", "VYc", "VZc")):
            halos["velocity"][:, axis] = column_values(key, float, np.float64)
        for axis, key in enumerate(("Lx", "Ly", "Lz")):
            halos["angularMomentum"][:, axis] = column_values(key, float, np.float64)
        halos["spin"] = column_values("lambda", float, np.float64)

        # Galacticus convention: no-host -> self. hostHalo == 0 means "no
        # host". Negative values cannot name a loaded halo either (presumably
        # another "no host" sentinel), so they are mapped to self instead of
        # being left as dangling host IDs.
        no_host = halos["hostIndex"] <= 0
        halos["hostIndex"][no_host] = halos["nodeIndex"][no_host]
        return halos

    def _resolve_columns(
        self,
        header_columns: dict[str, int],
        path: Path,
        row_width: int,
    ) -> dict[str, int]:
        """Merge defaults, header-detected, and user-supplied column maps.

        Precedence (lowest → highest): defaults → header → user overrides.
        Any entry pointing past the file's actual row width is treated as
        absent (the default may be a phantom that the user's build doesn't
        actually emit).

        Raises:
            ReaderError: if any required column is still missing. The message
                reports the out-of-range indices responsible.
        """
        column_map: dict[str, int] = dict(_AHF_DEFAULT_COLUMNS)
        column_map.update(header_columns)
        column_map.update(self._column_overrides)
        out_of_range: dict[str, int] = {}
        if row_width > 0:
            out_of_range = {name: idx for name, idx in column_map.items() if idx >= row_width}
            column_map = {name: idx for name, idx in column_map.items() if idx < row_width}
        missing = sorted(_AHF_REQUIRED_COLUMNS - column_map.keys())
        if missing:
            # Required names always have a default, so one can only go
            # missing by resolving to an index past the end of the row.
            culprits = {name: out_of_range[name] for name in missing if name in out_of_range}
            raise ReaderError(
                f"AHF file {path} is missing required columns {missing}; "
                f"resolved indices {culprits} exceed the {row_width} fields per row; "
                "supply them via options.columns."
            )
        return column_map

    def _apply_mtree(self, path: Path, current: np.ndarray) -> None:
        """Fill ``current["descendantIndex"]`` from the mtree file at ``path``."""
        if not path.is_file():
            raise ReaderError(f"AHF mtree file not found: {path}")
        records = _read_mtree_records(path)
        if not records:
            return
        fmt = self._resolve_mtree_format(records, path)
        id_to_idx = {int(h): i for i, h in enumerate(current["nodeIndex"])}
        if fmt == "idx":
            _apply_mtree_idx(records, id_to_idx, current, path)
        else:
            _apply_mtree_block(records, id_to_idx, current, path)

    def _resolve_mtree_format(self, records: list[list[str]], path: Path) -> str:
        """Return the configured mtree format, or auto-detect it from record widths.

        Raises:
            ReaderError: in auto mode, when the record widths are mixed.
        """
        if self._mtree_format != "auto":
            return self._mtree_format
        widths = {len(r) for r in records}
        if widths == {2}:
            return "idx"
        if min(widths) >= 3:
            return "block"
        raise ReaderError(
            f"AHF mtree at {path} has inconsistent record widths {sorted(widths)}; "
            "set options.mtree_format explicitly ('idx' or 'block')."
        )
def _read_ahf_halos_file(path: Path) -> tuple[dict[str, int], list[list[str]]]:
    """Split an ``.AHF_halos`` file into its column header and data rows.

    Every ``#`` comment line is fed to ``_parse_ahf_header`` and the results
    are merged, so a column list wrapped over several comment lines is
    reassembled. Blank lines are ignored; every other line becomes one row
    of whitespace-separated string fields.

    Returns ``({name: 0-based-index}, rows)``. An empty dict means that no
    usable header was found.
    """
    columns: dict[str, int] = {}
    table: list[list[str]] = []
    with path.open() as handle:
        for line in map(str.strip, handle):
            if line.startswith("#"):
                columns.update(_parse_ahf_header(line))
            elif line:
                table.append(line.split())
    return columns, table


def _read_mtree_records(path: Path) -> list[list[str]]:
    """Return the whitespace-split fields of every non-blank, non-``#`` line."""
    with path.open() as handle:
        cleaned = (line.strip() for line in handle)
        return [line.split() for line in cleaned if line and not line.startswith("#")]


def _apply_mtree_idx(
    records: list[list[str]],
    id_to_idx: dict[int, int],
    current: np.ndarray,
    path: Path,
) -> None:
    """Apply ``ProgID DescID`` pairs from an .AHF_mtree_idx-style file.

    Progenitor IDs absent from ``id_to_idx`` are ignored. When an ID
    repeats, the last pair wins.

    Raises:
        ReaderError: on a line without exactly two fields, or a non-integer field.
    """
    for fields in records:
        if len(fields) != 2:
            raise ReaderError(
                f"AHF idx-format mtree expected 2 fields per line, got {fields!r} in {path}"
            )
        try:
            prog_id = int(fields[0])
            desc_id = int(fields[1])
        except ValueError as exc:
            raise ReaderError(f"Non-integer field in {path}: {fields!r}") from exc
        row = id_to_idx.get(prog_id)
        if row is not None:
            current["descendantIndex"][row] = desc_id


def _apply_mtree_block(
    records: list[list[str]],
    id_to_idx: dict[int, int],
    current: np.ndarray,
    path: Path,
) -> None:
    """Apply descendant pointers from a block-format ``.AHF_mtree``.

    Each block starts with a header ``DescID HaloPart NumProgenitors`` and is
    followed by ``NumProgenitors`` progenitor lines ``SharedPart ProgID
    HaloPart``. Only ``ProgID`` (the second field) is read from those lines.
    Progenitors not in ``id_to_idx`` are ignored. A progenitor listed under
    several descendants ends up pointing at the last one encountered.

    Raises:
        ReaderError: on a malformed or negative header, a truncated block,
            or a short / non-integer progenitor line.
    """
    stream = iter(records)
    for header in stream:
        if len(header) < 3:
            raise ReaderError(
                f"AHF block-format mtree header expects >=3 fields: {header!r} in {path}"
            )
        try:
            desc_id = int(header[0])
            n_prog = int(header[2])
        except ValueError as exc:
            raise ReaderError(
                f"AHF block-format mtree header not parseable: {header!r} in {path}"
            ) from exc
        if n_prog < 0:
            raise ReaderError(
                f"AHF block-format mtree NumProgenitors negative: {header!r} in {path}"
            )
        # Pull this block's progenitor lines from the same iterator, so the
        # outer loop resumes at the next block header.
        for _ in range(n_prog):
            prog_line = next(stream, None)
            if prog_line is None:
                raise ReaderError(
                    f"Truncated AHF block-format mtree near descendant {desc_id} in {path}"
                )
            if len(prog_line) < 2:
                raise ReaderError(
                    f"AHF block-format progenitor line expects >=2 fields: {prog_line!r} in {path}"
                )
            try:
                prog_id = int(prog_line[1])
            except ValueError as exc:
                raise ReaderError(f"Non-integer progenitor ID in {path}: {prog_line!r}") from exc
            row = id_to_idx.get(prog_id)
            if row is not None:
                current["descendantIndex"][row] = desc_id