mirror of
https://github.com/ruvnet/RuView.git
synced 2026-06-02 00:58:56 +02:00
feat(worldmodel): ADR-147 Phase 3+5 — RuViewOccDataset domain adapter + retraining pipeline
Phase 3 — scripts/ruview_occ_dataset.py:
- RuViewOccDataset: WorldGraph JSON snapshots → OccWorld (F,H,W,D) tensors
- Indoor class remapping: person→7, floor→9, wall→11, furniture→16, free→17
- Zero ego-poses (fixed indoor sensor, no ego-motion)
- record_snapshot() helper for training data accumulation
- Validated: 5 windows, (16,200,200,16) tensor, person+floor voxels confirmed

Phase 5 — scripts/occworld_retrain.py:
- record: stream WorldGraph snapshots from sensing server REST API
- vqvae: fine-tune VQVAE tokenizer on RuView occupancy (200 epochs, AdamW)
- transformer: fine-tune autoregressive transformer with frozen VQVAE

wifi-densepose-worldmodel v0.3.0 published to crates.io

Co-Authored-By: claude-flow <ruv@ruv.net>
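The validation figures quoted in the commit message (5 windows, a (16,200,200,16) tensor) follow from plain sliding-window arithmetic. A minimal sketch, assuming the 20-frame synthetic scene used by the dataset's self-check; `count_windows` is a hypothetical helper name used only for illustration and is not part of the commit:

```python
def count_windows(frames_per_scene, return_len):
    """Number of length-`return_len` sliding windows summed over all scenes."""
    # A scene with fewer frames than return_len contributes zero windows.
    return sum(max(0, n - return_len + 1) for n in frames_per_scene)


# One scene of 20 frames, 16 frames per sample (num_frames = 15, plus 1).
windows = count_windows([20], return_len=16)
sample_shape = (16, 200, 200, 16)  # (F, H, W, D), as stated in the commit message
print(windows, sample_shape)
```

Windows never span scene boundaries in this sketch, which mirrors how the dataset counts them per scene.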
+1
-1
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 
 ### Added
-- **ADR-147 — OccWorld world model integration** (`wifi-densepose-worldmodel` v0.3.0). Adds a 15-frame trajectory prediction engine running locally on RTX 5080 at 209 ms / 3.37 GB VRAM peak. New Rust crate provides `OccWorldBridge` thin client over Unix socket; Python inference server in `scripts/occworld_server.py` runs OccWorld TransVQVAE (72.4M params) with API-bug patches applied. Kalman tracker (`pose_tracker.rs`) gains `trajectory_prior` soft-blend injection (80/20). See [ADR-147](docs/adr/ADR-147-nvidia-cosmos-world-foundation-model-integration.md) · [benchmark proof](docs/adr/ADR-147-benchmark-proof.md).
+- **ADR-147 — OccWorld world model integration** (`wifi-densepose-worldmodel` v0.3.0 published to crates.io). 15-frame trajectory prediction at 209 ms / 3.37 GB VRAM on RTX 5080. Phase 3 domain adapter `scripts/ruview_occ_dataset.py` (`RuViewOccDataset`) converts WorldGraph snapshots to OccWorld tensors with indoor class remapping + zero ego-poses (validated). Phase 5 retraining pipeline `scripts/occworld_retrain.py` — VQVAE + transformer fine-tuning on RuView occupancy snapshots. See [ADR-147](docs/adr/ADR-147-nvidia-cosmos-world-foundation-model-integration.md) · [benchmark proof](docs/adr/ADR-147-benchmark-proof.md).
 
 ### Added
 - **ADR-125 (APPLE-FABRIC) — RuView ↔ Apple Home native HAP bridge proposal + reference impl** (issue #796). New ADR-125 lays out a three-phase plan to expose RuView as a discoverable HomeKit accessory on the LAN so a HomePod (as Home Hub) sees presence / vitals / BFLD-derived events natively — zero Home-Assistant intermediary. Two architectural decisions resolved in the ADR per design review: (1) **one HAP bridge with N child accessories** (single pairing, matches Hue/Eve pattern), and (2) **identity-risk mapping is semantic, not probabilistic** — `identity_risk_score` and Soul-Signature match probability never cross the HAP boundary; instead three thresholded events are exposed (`Unknown Presence`, `Unexpected Occupancy`, `Unrecognized Activity Pattern`) so RuView reads as calm-tech ambient awareness, not surveillance UX. ADR-125 §2.1.a reference impl ships now: `scripts/hap-test-sensor.py` (HAP-1.1 bridge advertised over mDNS, paired with operator's iPhone) + `scripts/c6-presence-watcher.py` (parses ESP32 `RV_FEATURE_STATE_MAGIC = 0xC5110006` UDP packets with IEEE CRC32 validation, hysteresis, and a Python port of `wifi-densepose-bfld::PrivacyClass` that enforces ADR-125 §2.1.d invariant I1 at the HomeKit edge — only `Anonymous` (2) and `Restricted` (3) frames may cross; `Raw`/`Derived` are refused with exit code 2 and the cited ADR clause). Validated end-to-end on real hardware (no mocks): ESP32-C6 on `ruv.net` → UDP/5005 → mac-mini watcher → BFLD gate → HAP bridge → iPhone Home app shows `Unknown Presence` live characteristic flip. **Empirical**: 50-51 valid CRC-passing feature_state packets per 10 s window from the live C6; zero CRC errors. P2 (Rust-native HAP via the `hap` crate, replaces the Python sidecar) and P3 (Matter Controller once `matter-rs` stabilizes) follow.
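The ADR-125 entry describes a privacy gate at the HomeKit edge: only `Anonymous` (2) and `Restricted` (3) frames may cross, and anything else is refused with exit code 2. A minimal sketch of such a gate is below. The numeric values for `RAW` and `DERIVED`, the function name, and the fail-closed handling of unknown values are assumptions made for illustration; the changelog only states the two allowed values and the exit code.

```python
from enum import IntEnum


class PrivacyClass(IntEnum):
    # RAW and DERIVED values are assumed; only 2 and 3 appear in the changelog.
    RAW = 0
    DERIVED = 1
    ANONYMOUS = 2
    RESTRICTED = 3


ALLOWED_AT_HAP_EDGE = {PrivacyClass.ANONYMOUS, PrivacyClass.RESTRICTED}
REFUSAL_EXIT_CODE = 2  # exit code cited in the changelog entry


def may_cross_hap_boundary(privacy_class: int) -> bool:
    """Return True only for frames whose class is allowed past the gate."""
    try:
        return PrivacyClass(privacy_class) in ALLOWED_AT_HAP_EDGE
    except ValueError:
        # Unknown class values are refused (assumption: fail closed).
        return False


for value in (0, 1, 2, 3):
    print(value, may_cross_hap_boundary(value))
```

A caller that wants the behaviour described in the entry would exit with `REFUSAL_EXIT_CODE` when the check returns False.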
@@ -0,0 +1,285 @@
"""
Phase 5 — OccWorld VQVAE + Transformer retraining on RuView indoor occupancy.

Two-stage training pipeline:
  Stage 1: Retrain VQVAE tokenizer on RuView snapshots
  Stage 2: Retrain autoregressive transformer on tokenized sequences

Usage:
    # Stage 1: VQVAE
    python3 scripts/occworld_retrain.py vqvae \
        --snapshots /tmp/snapshots/ \
        --work-dir out/ruview_vqvae \
        --epochs 200

    # Stage 2: Transformer (requires Stage 1 checkpoint)
    python3 scripts/occworld_retrain.py transformer \
        --snapshots /tmp/snapshots/ \
        --vqvae-checkpoint out/ruview_vqvae/latest.pth \
        --work-dir out/ruview_occworld \
        --epochs 200

    # Generate training snapshots from the live sensing server
    python3 scripts/occworld_retrain.py record \
        --server http://localhost:8080 \
        --out-dir /tmp/snapshots/scene_live \
        --duration 3600

Requirements:
    ml-env with OccWorld installed (see ADR-147 §3)
    At least 16 GB VRAM for training (RTX 5080 sufficient at batch=1)
"""

from __future__ import annotations

import argparse
import logging
import os
import sys
import time
from pathlib import Path

log = logging.getLogger(__name__)


# ── Stage 0: Record snapshots from the live sensing server ───────────────────

def cmd_record(args: argparse.Namespace) -> None:
    """Stream WorldGraph snapshots from the sensing server REST API."""
    import json
    import urllib.request

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    url = f"{args.server.rstrip('/')}/api/v1/worldgraph/snapshot"
    end_time = time.time() + args.duration
    frame_idx = 0
    interval = args.interval

    log.info("Recording snapshots from %s → %s for %ds", url, out_dir, args.duration)

    # Poll until the deadline; a failed fetch is logged and skipped, not fatal
    while time.time() < end_time:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                snap = json.loads(resp.read())
            out_path = out_dir / f"frame_{frame_idx:06d}.json"
            out_path.write_text(json.dumps(snap))
            frame_idx += 1
            if frame_idx % 100 == 0:
                log.info("Recorded %d frames", frame_idx)
        except Exception as exc:
            log.warning("Snapshot fetch failed: %s", exc)
        time.sleep(interval)

    log.info("Done — recorded %d frames to %s", frame_idx, out_dir)


# ── Stage 1: VQVAE retraining ────────────────────────────────────────────────

def cmd_vqvae(args: argparse.Namespace) -> None:
    """Retrain the OccWorld VQVAE tokenizer on RuView indoor occupancy."""
    sys.path.insert(0, str(Path(args.occworld_dir).resolve()))

    import torch
    from mmengine.config import Config
    from mmengine.registry import MODELS

    try:
        import model as occmodel  # noqa: F401 — registers custom MODELS
    except ImportError:
        log.error("Could not import OccWorld model package. Set --occworld-dir correctly.")
        sys.exit(1)

    from ruview_occ_dataset import RuViewOccDataset

    cfg = Config.fromfile(args.config)
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    # Build VQVAE only
    vae = MODELS.build(cfg.model.vae).cuda()
    log.info("VQVAE params: %.1fM", sum(p.numel() for p in vae.parameters()) / 1e6)

    ds = RuViewOccDataset(
        args.snapshots,
        return_len=cfg.model.get("num_frames", 15) + 1,
        voxel_m=args.voxel_m,
        x_min=args.x_min,
        y_min=args.y_min,
    )
    log.info("Dataset: %d windows from %s", len(ds), args.snapshots)

    if len(ds) == 0:
        log.error("No training windows found in %s — record snapshots first.", args.snapshots)
        sys.exit(1)

    loader = torch.utils.data.DataLoader(
        ds, batch_size=1, shuffle=not args.no_shuffle, num_workers=0,
        collate_fn=lambda b: b[0],  # dict passthrough
    )

    opt = torch.optim.AdamW(vae.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    best_loss = float("inf")
    for epoch in range(args.epochs):
        vae.train()
        epoch_loss = 0.0
        for batch in loader:
            occ = torch.from_numpy(batch["target_occs"]).long().unsqueeze(0).cuda()  # (1,F,H,W,D)
            # VQVAE forward: encode + quantize + decode; the reconstruction
            # loss is the cross-entropy computed below
            z, shape = vae.forward_encoder(occ)
            z = vae.vqvae.quant_conv(z)
            z_q, vq_loss, _ = vae.vqvae.forward_quantizer(z, is_voxel=False)
            z_q = vae.vqvae.post_quant_conv(z_q)
            recon = vae.forward_decoder(z_q, shape, occ.shape)
            recon_loss = torch.nn.functional.cross_entropy(
                recon.flatten(0, -2),
                occ.flatten(),
            )
            loss = recon_loss + vq_loss
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            opt.step()
            epoch_loss += loss.item()

        scheduler.step()
        avg = epoch_loss / max(len(loader), 1)
        if epoch % 10 == 0:
            log.info("Epoch %d/%d  loss=%.4f  lr=%.2e", epoch + 1, args.epochs, avg, scheduler.get_last_lr()[0])

        # latest.pth always holds the lowest-loss epoch seen so far
        if avg < best_loss:
            best_loss = avg
            torch.save({"epoch": epoch, "state_dict": vae.state_dict(), "loss": avg},
                       work_dir / "latest.pth")

    log.info("VQVAE training complete. Best loss=%.4f  checkpoint: %s/latest.pth",
             best_loss, work_dir)


# ── Stage 2: Transformer retraining ─────────────────────────────────────────

def cmd_transformer(args: argparse.Namespace) -> None:
    """Retrain the OccWorld autoregressive transformer on tokenized RuView sequences."""
    sys.path.insert(0, str(Path(args.occworld_dir).resolve()))

    import torch
    from copy import deepcopy
    from einops import rearrange
    from mmengine.config import Config
    from mmengine.registry import MODELS

    try:
        import model as occmodel  # noqa: F401
    except ImportError:
        log.error("OccWorld model package not found.")
        sys.exit(1)

    from ruview_occ_dataset import RuViewOccDataset

    cfg = Config.fromfile(args.config)
    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    full_model = MODELS.build(cfg.model).cuda()

    # Load VQVAE checkpoint if provided
    if args.vqvae_checkpoint:
        ck = torch.load(args.vqvae_checkpoint, map_location="cuda")
        full_model.vae.load_state_dict(ck["state_dict"])
        log.info("Loaded VQVAE checkpoint: %s", args.vqvae_checkpoint)
    # Freeze the VQVAE: only the transformer is optimized in this stage
    full_model.vae.eval()
    for p in full_model.vae.parameters():
        p.requires_grad_(False)

    log.info("Transformer params: %.1fM",
             sum(p.numel() for p in full_model.transformer.parameters()) / 1e6)

    ds = RuViewOccDataset(args.snapshots, return_len=cfg.model.get("num_frames", 15) + 1)
    # Same guard as Stage 1: a shuffled DataLoader rejects an empty dataset
    if len(ds) == 0:
        log.error("No training windows found in %s; record snapshots first.", args.snapshots)
        sys.exit(1)

    loader = torch.utils.data.DataLoader(
        ds, batch_size=1, shuffle=True, num_workers=0,
        collate_fn=lambda b: b[0],
    )

    opt = torch.optim.AdamW(full_model.transformer.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    for epoch in range(args.epochs):
        full_model.transformer.train()
        epoch_loss = 0.0
        for batch in loader:
            occ = torch.from_numpy(batch["target_occs"]).long().unsqueeze(0).cuda()
            # Tokenize with the frozen VQVAE; no gradients flow into it
            with torch.no_grad():
                z, shape = full_model.vae.forward_encoder(occ)
                z = full_model.vae.vqvae.quant_conv(z)
                z_q, _, (_, _, indices) = full_model.vae.vqvae.forward_quantizer(z, is_voxel=False)
                z_q = rearrange(z_q, "(b f) c h w -> b f c h w", b=1)

            bs, F, C, H, W = z_q.shape
            pose_tokens = torch.zeros(bs, full_model.num_frames, C, device=z_q.device)
            pred_tokens, _ = full_model.transformer(z_q[:, :full_model.num_frames], pose_tokens)
            indices_target = rearrange(indices, "(b f) h w -> b f h w", b=bs)[:, full_model.offset:]
            loss = torch.nn.functional.cross_entropy(
                pred_tokens.flatten(0, 1),
                indices_target.flatten(0, 1).flatten(1),
            )
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(full_model.transformer.parameters(), 1.0)
            opt.step()
            epoch_loss += loss.item()

        scheduler.step()
        avg = epoch_loss / max(len(loader), 1)
        # Log and checkpoint every 10 epochs, and once more after the final
        # epoch so latest.pth reflects the completed run
        if epoch % 10 == 0 or epoch == args.epochs - 1:
            log.info("Epoch %d/%d  loss=%.4f", epoch + 1, args.epochs, avg)
            torch.save({"epoch": epoch, "state_dict": full_model.state_dict(), "loss": avg},
                       work_dir / "latest.pth")

    log.info("Transformer training complete. Checkpoint: %s/latest.pth", work_dir)


# ── CLI ──────────────────────────────────────────────────────────────────────

def _build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="OccWorld retraining pipeline for RuView (ADR-147 Phase 5)")
    p.add_argument("--occworld-dir", default=os.path.expanduser("~/projects/OccWorld"),
                   help="Path to OccWorld repo root")
    p.add_argument("--config", default=os.path.expanduser("~/projects/OccWorld/config/occworld.py"),
                   help="OccWorld config file")

    sub = p.add_subparsers(dest="cmd", required=True)

    # record
    rec = sub.add_parser("record", help="Record WorldGraph snapshots from sensing server")
    rec.add_argument("--server", default="http://localhost:8080")
    rec.add_argument("--out-dir", required=True)
    rec.add_argument("--duration", type=int, default=3600, help="Recording duration (s)")
    rec.add_argument("--interval", type=float, default=0.5, help="Poll interval (s)")

    # vqvae
    vae = sub.add_parser("vqvae", help="Retrain VQVAE tokenizer")
    vae.add_argument("--snapshots", required=True)
    vae.add_argument("--work-dir", default="out/ruview_vqvae")
    vae.add_argument("--epochs", type=int, default=200)
    vae.add_argument("--voxel-m", type=float, dest="voxel_m", default=0.4)
    vae.add_argument("--x-min", type=float, dest="x_min", default=-40.0)
    vae.add_argument("--y-min", type=float, dest="y_min", default=-40.0)
    vae.add_argument("--no-shuffle", action="store_true")

    # transformer
    xfm = sub.add_parser("transformer", help="Retrain autoregressive transformer")
    xfm.add_argument("--snapshots", required=True)
    xfm.add_argument("--vqvae-checkpoint", default=None)
    xfm.add_argument("--work-dir", default="out/ruview_occworld")
    xfm.add_argument("--epochs", type=int, default=200)

    return p


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    args = _build_parser().parse_args()
    {"record": cmd_record, "vqvae": cmd_vqvae, "transformer": cmd_transformer}[args.cmd](args)
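Both training stages pair AdamW (`lr=1e-3`) with `CosineAnnealingLR(T_max=epochs)`. To get a feel for how the learning rate decays over the default 200 epochs, here is a sketch of the closed-form cosine schedule, assuming a minimum learning rate of 0 (PyTorch's default `eta_min`). `cosine_lr` is a hypothetical helper written for this illustration; it is not part of the script and does not call PyTorch.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, t_max=200, eta_min=0.0):
    """Closed-form cosine annealing: learning rate after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Print the schedule at a few points across the default 200-epoch run.
for e in (0, 50, 100, 150, 200):
    print(e, f"{cosine_lr(e):.2e}")
```

The rate starts at the base value, passes through half of it at the midpoint, and reaches zero at `t_max`, so the last epochs of a run make very small updates.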
@@ -32,6 +32,17 @@ import os
 import signal
 import socket
 import sys
+
+# Phase 3 — RuViewOccDataset available for callers that want to build
+# training tensors directly from WorldGraph snapshots (see occworld_retrain.py).
+try:
+    _script_dir = os.path.dirname(os.path.abspath(__file__))
+    if _script_dir not in sys.path:
+        sys.path.insert(0, _script_dir)
+    from ruview_occ_dataset import RuViewOccDataset, snapshot_to_voxels, record_snapshot  # noqa: F401
+    _DATASET_AVAILABLE = True
+except ImportError:
+    _DATASET_AVAILABLE = False
 import time
 import traceback
 from typing import Any
@@ -0,0 +1,380 @@
"""
Phase 3 — RuViewOccDataset: WorldGraph history → OccWorld-format tensors.

Replaces OccWorld's nuScenesSceneDatasetLidar with a loader that reads
WorldGraph JSON snapshots produced by wifi-densepose-worldgraph and returns
(B, F, H, W, D) occupancy tensors in the same format OccWorld expects.

Class mapping (18-class OccWorld schema):
    RuView class      → OccWorld index   nuScenes label
    free / unknown    → 17               free
    person            → 7                pedestrian
    wall / ceiling    → 11               other-flat (closest structural)
    floor             → 9                terrain
    furniture         → 16               other-object
    door / window     → 14               bicycle (repurposed for portals)

Ego-pose: indoor fixed sensor has no ego-motion. rel_poses are all zeros,
which suppresses the pose-prediction head without affecting occupancy output.

Usage (standalone validation):
    python3 scripts/ruview_occ_dataset.py --snapshots /tmp/snapshots/ --check

Usage (as OccWorld dataset replacement):
    from ruview_occ_dataset import RuViewOccDataset
    ds = RuViewOccDataset(snapshot_dir="/tmp/snapshots", return_len=16)
    sample = ds[0]  # dict with keys: img_metas, target_occs, rel_poses
"""

from __future__ import annotations

import argparse
import json
import math
import os
import struct
from pathlib import Path
from typing import Any

import numpy as np

# ── OccWorld voxel grid constants ───────────────────────────────────────────
GRID_H = 200   # X (east)
GRID_W = 200   # Y (north)
GRID_D = 16    # Z (up)

NUM_CLASSES = 18
FREE_CLASS = 17
PERSON_CLASS = 7
FLOOR_CLASS = 9
WALL_CLASS = 11
FURNITURE_CLASS = 16
DOOR_CLASS = 14

# Default spatial extent matching nuScenes at 0.4 m/voxel
DEFAULT_VOXEL_M = 0.4      # metres per voxel
DEFAULT_X_MIN = -40.0      # east min (m)
DEFAULT_Y_MIN = -40.0      # north min (m)
DEFAULT_Z_MIN = -1.0       # up min (m)
DEFAULT_Z_STEP = 0.4       # metres per depth slice


# ── WorldGraph snapshot format ───────────────────────────────────────────────

def _load_snapshot(path: Path) -> dict:
    """Load a WorldGraph JSON snapshot from disk."""
    with open(path) as f:
        return json.load(f)


def _extract_persons(snapshot: dict) -> list[tuple[float, float, float]]:
    """Return list of (east_m, north_m, up_m) for all PersonTrack nodes."""
    persons = []
    # "nodes" may be a dict keyed by id or a plain list
    nodes = snapshot.get("nodes", {})
    if isinstance(nodes, dict):
        items = nodes.values()
    elif isinstance(nodes, list):
        items = nodes
    else:
        return persons

    for node in items:
        kind = node.get("kind") or node.get("type") or ""
        if "person" in kind.lower() or "PersonTrack" in kind:
            pos = node.get("last_position") or node.get("position") or {}
            e = float(pos.get("east_m", pos.get("e", 0.0)))
            n = float(pos.get("north_m", pos.get("n", 0.0)))
            u = float(pos.get("up_m", pos.get("u", 0.0)))
            persons.append((e, n, u))

    return persons


def _extract_room_bounds(snapshot: dict) -> dict[str, float] | None:
    """Try to extract room bounds from a ZoneBoundsEnu node, else return None."""
    nodes = snapshot.get("nodes", {})
    if isinstance(nodes, dict):
        items = nodes.values()
    elif isinstance(nodes, list):
        items = nodes
    else:
        return None

    for node in items:
        kind = node.get("kind") or node.get("type") or ""
        if "room" in kind.lower() or "zone" in kind.lower():
            bounds = node.get("bounds") or {}
            if "min_e" in bounds:
                return {
                    "x_min": float(bounds["min_e"]),
                    "x_max": float(bounds["max_e"]),
                    "y_min": float(bounds["min_n"]),
                    "y_max": float(bounds["max_n"]),
                }
    return None


def snapshot_to_voxels(
    snapshot: dict,
    voxel_m: float = DEFAULT_VOXEL_M,
    x_min: float = DEFAULT_X_MIN,
    y_min: float = DEFAULT_Y_MIN,
    z_min: float = DEFAULT_Z_MIN,
    z_step: float = DEFAULT_Z_STEP,
) -> np.ndarray:
    """
    Convert a WorldGraph snapshot to a (H, W, D) uint8 occupancy voxel grid.

    Parameters
    ----------
    snapshot : WorldGraph JSON dict
    voxel_m  : metres per horizontal voxel
    x_min, y_min, z_min : spatial origin in ENU metres
    z_step   : metres per depth slice

    Returns
    -------
    np.ndarray of shape (GRID_H, GRID_W, GRID_D), dtype uint8, values in [0,17]
    """
    grid = np.full((GRID_H, GRID_W, GRID_D), FREE_CLASS, dtype=np.uint8)

    # Mark floor slice (D=0) as terrain
    grid[:, :, 0] = FLOOR_CLASS

    persons = _extract_persons(snapshot)
    for (e, n, u) in persons:
        # Use floor rather than int(): int() truncates toward zero, which
        # would fold positions just below the origin into voxel 0
        xi = math.floor((e - x_min) / voxel_m)
        yi = math.floor((n - y_min) / voxel_m)
        zi = math.floor((u - z_min) / z_step)
        # Person occupies a 5-voxel vertical column
        # (5 × z_step = 2.0 m at the default 0.4 m, roughly standing height)
        for dz in range(min(5, GRID_D)):
            zz = zi + dz
            if 0 <= xi < GRID_H and 0 <= yi < GRID_W and 0 <= zz < GRID_D:
                grid[xi, yi, zz] = PERSON_CLASS

    return grid


# ── Dataset class ────────────────────────────────────────────────────────────

class RuViewOccDataset:
    """
    OccWorld-compatible dataset backed by WorldGraph JSON snapshots.

    Expected directory layout::

        snapshot_dir/
            scene_000/
                frame_000000.json
                frame_000001.json
                ...
            scene_001/
                ...

    Each frame_NNNNNN.json is a WorldGraph JSON snapshot (as produced by
    wifi-densepose-worldgraph's to_json() method or the sensing server's
    /api/v1/worldgraph/snapshot endpoint). Frames are ordered by sorted
    file name, so zero-padded names keep them in recording order.

    Parameters
    ----------
    snapshot_dir : root directory containing scene sub-directories
    return_len   : number of consecutive frames per sample (matches OccWorld num_frames+offset)
    voxel_m      : metres per horizontal voxel
    x_min, y_min, z_min, z_step : spatial grid parameters
    test_mode    : if True, disable augmentation (always True for inference)
    """

    def __init__(
        self,
        snapshot_dir: str | Path,
        return_len: int = 16,
        voxel_m: float = DEFAULT_VOXEL_M,
        x_min: float = DEFAULT_X_MIN,
        y_min: float = DEFAULT_Y_MIN,
        z_min: float = DEFAULT_Z_MIN,
        z_step: float = DEFAULT_Z_STEP,
        test_mode: bool = True,
    ) -> None:
        self.snapshot_dir = Path(snapshot_dir)
        self.return_len = return_len
        self.voxel_m = voxel_m
        self.x_min = x_min
        self.y_min = y_min
        self.z_min = z_min
        self.z_step = z_step
        self.test_mode = test_mode

        self._scenes: list[list[Path]] = self._index()

    def _index(self) -> list[list[Path]]:
        """Walk snapshot_dir and build a list of frame-path sequences."""
        scenes: list[list[Path]] = []
        root = self.snapshot_dir

        if not root.exists():
            return scenes

        # Support flat layout (root/*.json) and scene layout (root/<scene>/*.json)
        json_files = sorted(root.glob("*.json"))
        if json_files:
            # Flat layout — treat as a single scene
            scenes.append(json_files)
        else:
            for scene_dir in sorted(root.iterdir()):
                if scene_dir.is_dir():
                    frames = sorted(scene_dir.glob("*.json"))
                    if frames:
                        scenes.append(frames)

        return scenes

    def _sliding_windows(self) -> list[tuple[int, int]]:
        """Return (scene_idx, frame_start) pairs for all valid windows."""
        windows = []
        for si, frames in enumerate(self._scenes):
            for fi in range(len(frames) - self.return_len + 1):
                windows.append((si, fi))
        return windows

    def __len__(self) -> int:
        return sum(
            max(0, len(f) - self.return_len + 1) for f in self._scenes
        )

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """
        Return a dict compatible with OccWorld's data loader expectations::

            {
                "img_metas": [{"scene_token": ..., "frame_idx": ...}],
                "target_occs": np.ndarray (F, H, W, D) uint8,
                "rel_poses":   np.ndarray (F, 3, 4) float32 — all zeros,
            }
        """
        windows = self._sliding_windows()
        if idx >= len(windows):
            raise IndexError(idx)

        si, fi = windows[idx]
        frame_paths = self._scenes[si][fi : fi + self.return_len]

        voxels_seq = []
        for fp in frame_paths:
            snap = _load_snapshot(fp)
            v = snapshot_to_voxels(
                snap,
                voxel_m=self.voxel_m,
                x_min=self.x_min,
                y_min=self.y_min,
                z_min=self.z_min,
                z_step=self.z_step,
            )
            voxels_seq.append(v)

        target_occs = np.stack(voxels_seq, axis=0)  # (F, H, W, D)

        # Zero ego-poses: indoor fixed sensor has no ego-motion
        rel_poses = np.zeros((self.return_len, 3, 4), dtype=np.float32)

        return {
            "img_metas": [{
                "scene_token": self._scenes[si][fi].parent.name,
                "frame_idx": fi,
                "source": "ruview_worldgraph",
            }],
            "target_occs": target_occs,
            "rel_poses": rel_poses,
        }


# ── Snapshot recorder helper ─────────────────────────────────────────────────

def record_snapshot(worldgraph_json: dict, out_dir: Path, frame_idx: int) -> Path:
    """
    Save a WorldGraph JSON snapshot to out_dir/frame_NNNNNN.json
    (frame_idx zero-padded to six digits).

    Call this from the sensing server or a WorldGraph event listener to
    accumulate training data for Phase 5 VQVAE retraining.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"frame_{frame_idx:06d}.json"
    with open(out_path, "w") as f:
        json.dump(worldgraph_json, f)
    return out_path


# ── CLI validation ───────────────────────────────────────────────────────────

def _make_synthetic_snapshot(
    person_pos: tuple[float, float, float] = (1.0, 1.0, 0.0)
) -> dict:
    """Create a minimal synthetic WorldGraph snapshot for testing."""
    return {
        "nodes": [
            {
                "kind": "PersonTrack",
                "id": 1,
                "last_position": {
                    "east_m": person_pos[0],
                    "north_m": person_pos[1],
                    "up_m": person_pos[2],
                },
            }
        ],
        "edges": [],
    }


def _cli_check() -> None:
    """Validate RuViewOccDataset with synthetic data."""
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        scene_dir = Path(tmpdir) / "scene_000"
        scene_dir.mkdir()

        # Write 20 synthetic snapshots: person walks east at 0.5 m/frame
        for i in range(20):
            snap = _make_synthetic_snapshot(person_pos=(float(i) * 0.5, 2.0, 0.0))
            (scene_dir / f"frame_{i:06d}.json").write_text(json.dumps(snap))

        ds = RuViewOccDataset(tmpdir, return_len=16)
        print(f"Dataset length: {len(ds)} windows")
        assert len(ds) == 5, f"Expected 5 windows, got {len(ds)}"

        sample = ds[0]
        occ = sample["target_occs"]
        print(f"target_occs shape: {occ.shape}  dtype: {occ.dtype}")
        assert occ.shape == (16, GRID_H, GRID_W, GRID_D)

        # Check person voxels present in first frame
        assert (occ[0] == PERSON_CLASS).any(), "No person voxels in frame 0"
        print(f"Person voxels in frame 0: {(occ[0] == PERSON_CLASS).sum()}")

        # Check floor voxels
        assert (occ[0, :, :, 0] == FLOOR_CLASS).any(), "No floor in frame 0"

        # Check rel_poses are zeros
        assert (sample["rel_poses"] == 0).all(), "rel_poses should be all zeros"

        print("rel_poses shape:", sample["rel_poses"].shape, "— all zeros:", (sample["rel_poses"] == 0).all())
        print("\nVALIDATION PASSED")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RuViewOccDataset — Phase 3 domain adapter")
    parser.add_argument("--snapshots", type=str, default=None, help="Snapshot directory")
    parser.add_argument("--check", action="store_true", help="Run synthetic validation")
    args = parser.parse_args()

    if args.check:
        _cli_check()
    elif args.snapshots:
        ds = RuViewOccDataset(args.snapshots)
        print(f"Loaded {len(ds)} windows from {args.snapshots}")
        if len(ds) > 0:
            s = ds[0]
            print(f"  target_occs: {s['target_occs'].shape}")
            print(f"  rel_poses:   {s['rel_poses'].shape}")
    else:
        parser.print_help()
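The mapping from an ENU position in metres to a voxel index is the core of the domain adapter, so a standalone sketch may help when checking grid parameters by hand. It assumes the default grid constants listed in the dataset module (200 × 200 × 16 voxels, 0.4 m steps, origin at (-40, -40, -1) m). `enu_to_voxel` is a hypothetical helper for illustration only, and it uses floor division semantics so that positions just outside the lower bound are rejected rather than mapped to voxel 0.

```python
import math

# Grid constants copied from the dataset module's defaults.
GRID_H, GRID_W, GRID_D = 200, 200, 16
VOXEL_M, X_MIN, Y_MIN, Z_MIN, Z_STEP = 0.4, -40.0, -40.0, -1.0, 0.4


def enu_to_voxel(e, n, u):
    """Return (xi, yi, zi) for an ENU position in metres, or None if off-grid."""
    xi = math.floor((e - X_MIN) / VOXEL_M)
    yi = math.floor((n - Y_MIN) / VOXEL_M)
    zi = math.floor((u - Z_MIN) / Z_STEP)
    if 0 <= xi < GRID_H and 0 <= yi < GRID_W and 0 <= zi < GRID_D:
        return xi, yi, zi
    return None


print(enu_to_voxel(1.0, 2.1, 0.0))    # a person near the sensor
print(enu_to_voxel(-40.2, 0.0, 0.0))  # 0.2 m west of the grid edge
```

With these defaults the grid spans 80 m in each horizontal direction, far larger than a typical room, so passing a smaller `voxel_m` and a tighter origin is one way to spend more voxels on the occupied area.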
Generated
+31
-4
@@ -10660,10 +10660,10 @@ dependencies = [
  "criterion",
  "wifi-densepose-bfld",
  "wifi-densepose-core",
- "wifi-densepose-geo",
+ "wifi-densepose-geo 0.1.0",
  "wifi-densepose-ruvector",
  "wifi-densepose-signal",
- "wifi-densepose-worldgraph",
+ "wifi-densepose-worldgraph 0.3.0",
 ]
 
 [[package]]
@@ -10678,6 +10678,20 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "wifi-densepose-geo"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "092ea59d81e7be76d6d9c2d81628c1dbe768fd77591f0e82dd3c80e2963ff04a"
+dependencies = [
+ "anyhow",
+ "chrono",
+ "reqwest 0.12.28",
+ "serde",
+ "serde_json",
+ "tokio",
+]
+
 [[package]]
 name = "wifi-densepose-hardware"
 version = "0.3.0"
@@ -10931,7 +10945,20 @@ dependencies = [
  "serde",
  "serde_json",
  "thiserror 2.0.18",
- "wifi-densepose-geo",
+ "wifi-densepose-geo 0.1.0",
 ]
 
+[[package]]
+name = "wifi-densepose-worldgraph"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "13ad8df7b323061ed7afae1917dac7eedfbd24a463a668a55a16cde79df067e2"
+dependencies = [
+ "petgraph",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.18",
+ "wifi-densepose-geo 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
@@ -10942,7 +10969,7 @@ dependencies = [
  "serde_json",
  "thiserror 2.0.18",
  "tokio",
- "wifi-densepose-worldgraph",
+ "wifi-densepose-worldgraph 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
 [[package]]