Source code for prt_sim.jhu.registry

from __future__ import annotations
from dataclasses import dataclass
from importlib import import_module
from importlib.metadata import entry_points
from typing import Any, Callable, Dict, Optional

# Module-level registry state (in memory, supports runtime registration).
# Maps each registry id to its "module:attr" target string; populated by
# register() and by installed entry-point discovery.
_REGISTRY: Dict[str, str] = dict()
# Set to True once installed entry points have been scanned.
_DISCOVERED: bool = False

@dataclass(frozen=True)
class Spec:
    """Immutable record describing one registry entry."""

    # Registry identifier (the string passed to ``make``).
    id: str
    # Import target written as "pkg.module:callable".
    entry_point: str
def _entry_point_target(ep: Any) -> str:
    """Return a clean ``"module:attr"`` target string for entry point *ep*.

    ``ep.value`` is preferred. Any ``[extras]`` marker and all whitespace are
    removed (``"pkg.mod : Cls [gpu]"`` -> ``"pkg.mod:Cls"``) so the stored string
    contains only the import target. Objects without a usable ``value`` fall
    back to ``ep.module`` / ``ep.attr``.
    """
    value = getattr(ep, "value", None)
    if not value:
        return f"{ep.module}:{ep.attr}"
    # Drop the optional extras marker, e.g. "pkg.mod:Cls [gpu]".
    target = value.split("[", 1)[0]
    # Module and attribute names never contain whitespace, so removing all of it is safe.
    return "".join(target.split())


def _discover_entry_points(group: str = "jhu.envs") -> None:
    """Load installed entry points of *group* into ``_REGISTRY`` one time.

    Each entry point's name becomes the registry id and its target becomes the
    ``"module:attr"`` string. Ids already present in ``_REGISTRY`` (e.g. added
    via ``register``) are left untouched, so runtime registrations win.

    NOTE(review): completion is tracked by the single ``_DISCOVERED`` flag, so
    after the first call every later call returns immediately, even when a
    different *group* is passed.
    """
    global _DISCOVERED
    if _DISCOVERED:
        return
    eps = entry_points()
    # Python 3.10+ exposes ``select``; 3.8/3.9 return a plain dict keyed by group.
    if hasattr(eps, "select"):
        group_eps = eps.select(group=group)
    else:
        group_eps = eps.get(group, [])
    for ep in group_eps:
        if ep.name not in _REGISTRY:  # don't clobber runtime registrations
            _REGISTRY[ep.name] = _entry_point_target(ep)
    _DISCOVERED = True
def register(id: str, entry_point: str) -> None:
    """
    Add or replace a registry entry at runtime (e.g. from tests or plugins).

    Example:
        register("JHU/MyEnv-v0", "my_pkg.my_mod:MyEnvClass")

    The target string is stored as given, without validation; it is imported
    only when ``make`` is later called with this id. An existing entry with
    the same id is overwritten.
    """
    _REGISTRY.update({id: entry_point})
def _load_callable(spec: str) -> Callable[..., Any]:
    """
    Import and return the callable referenced by ``"module:attr"``.

    ``attr`` may be a dotted path (``"pkg.mod:Outer.factory"``), matching the
    entry-point object-reference syntax. Whitespace around either part is ignored.

    Raises:
        ValueError: if *spec* contains no ``":"``.
        ImportError: if the attribute path does not resolve on the module.
            Errors from ``import_module`` itself (e.g. ModuleNotFoundError)
            propagate unchanged.
        TypeError: if the resolved object is not callable.
    """
    if ":" not in spec:
        raise ValueError(f"Entry point must be 'module:attr', got: {spec!r}")
    module_name, attr = (part.strip() for part in spec.split(":", 1))
    mod = import_module(module_name)
    obj: Any = mod
    try:
        # Resolve "Outer.inner" one attribute at a time.
        for name in attr.split("."):
            obj = getattr(obj, name)
    except AttributeError as err:
        raise ImportError(f"Cannot find attribute {attr!r} in module {module_name!r}") from err
    if not callable(obj):
        raise TypeError(f"Entry point target must be callable, got: {type(obj)} from {spec}")
    return obj
def make(id: str, /, **kwargs: Any) -> Any:
    """
    Build an object from its registered string ID.

    The registered target may be a class or any factory callable; it is called
    with ``**kwargs`` and the result is returned. Installed entry points are
    discovered first.

    Example:
        env = jhu.make("JHU/ImagePipeline-v0", width=64, height=64)

    Raises:
        KeyError: if *id* is not registered (the message lists the known IDs).
        Errors raised while importing or calling the target propagate unchanged.
    """
    _discover_entry_points()
    spec = _REGISTRY.get(id)
    if spec is not None:
        factory = _load_callable(spec)
        return factory(**kwargs)
    # Unknown id: list everything registered to help the caller.
    known = sorted(_REGISTRY)
    available = ", ".join(known) or "<none>"
    raise KeyError(f"Unknown JHU id {id!r}. Available: {available}")
def specs(prefix: Optional[str] = None) -> Dict[str, Spec]:
    """Return a ``Spec`` for every registered ID, keyed by that ID.

    Installed entry points are discovered first.

    Args:
        prefix: when given, only IDs starting with this string are included.
    """
    _discover_entry_points()
    return {
        key: Spec(id=key, entry_point=target)
        for key, target in _REGISTRY.items()
        if prefix is None or key.startswith(prefix)
    }