feat(cog-pose): per-room LoRA calibration adapter in the Rust inference path

Ports the calibration mechanism (ADR-150 §3.5-3.6, reference impl in
aether-arena/calibration/) into the real product pose engine. The Candle
InferenceEngine now loads an optional per-room adapter safetensors and
applies low-rank deltas (y + (x.A).B) on the fc1/fc2 head at inference.
Architecture-agnostic LoRA; base behaviour unchanged when no adapter.
New API: with_weights_and_adapter(), is_calibrated(). Tested: adapter
detection + output-change integration test (6/6 pass).

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv
2026-05-31 02:26:48 -04:00
parent 4db727649a
commit 3760db6c9a
2 changed files with 147 additions and 6 deletions
+77 -6
View File
@@ -46,6 +46,40 @@ impl PoseOutput {
}
}
/// Per-room LoRA calibration adapter (ADR-150 §3.53.6). Low-rank deltas on the pose
/// head: `delta = (x · A) · B`, with `A:[in,r]`, `B:[r,out]` (scale baked into `B` at
/// save time). A handful of labeled in-room samples fit this ~few-KB adapter and recover
/// SOTA-level pose for an unseen room/person, on top of the frozen shared base.
/// Adapter safetensors keys: `fc1.a`, `fc1.b`, `fc2.a`, `fc2.b` (any subset).
#[derive(Clone)]
struct PoseLora {
fc1: Option<(Tensor, Tensor)>,
fc2: Option<(Tensor, Tensor)>,
}
impl PoseLora {
/// Load from an adapter safetensors. Missing layer keys are simply skipped.
fn load(path: &Path, device: &Device) -> candle_core::Result<Self> {
let t = candle_core::safetensors::load(path, device)?;
let pair = |a: &str, b: &str| match (t.get(a), t.get(b)) {
(Some(x), Some(y)) => Some((x.clone(), y.clone())),
_ => None,
};
Ok(Self {
fc1: pair("fc1.a", "fc1.b"),
fc2: pair("fc2.a", "fc2.b"),
})
}
/// `y + (x · A) · B` when an adapter for this layer is present, else `y` unchanged.
fn apply(slot: &Option<(Tensor, Tensor)>, x: &Tensor, y: Tensor) -> candle_core::Result<Tensor> {
match slot {
Some((a, b)) => y + x.matmul(a)?.matmul(b)?,
None => Ok(y),
}
}
}
/// Internal model — mirrors the training script's `PoseModel` exactly.
struct PoseNet {
c1: Conv1d,
@@ -53,6 +87,8 @@ struct PoseNet {
c3: Conv1d,
fc1: Linear,
fc2: Linear,
/// Optional per-room calibration adapter (none = shared base behaviour).
adapter: Option<PoseLora>,
}
impl PoseNet {
@@ -108,20 +144,31 @@ impl PoseNet {
c3,
fc1,
fc2,
adapter: None,
})
}
/// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`.
/// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`. Applies the per-room
/// LoRA calibration adapter on the head layers when one is attached.
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let h = self.c1.forward(x)?.relu()?;
let h = self.c2.forward(&h)?.relu()?;
let h = self.c3.forward(&h)?.relu()?;
// Global average pool over time dim (last dim) -> [B, 128]
let h = h.mean(2)?;
let h = self.fc1.forward(&h)?.relu()?;
let h = self.fc2.forward(&h)?;
let pooled = h.mean(2)?;
// fc1 (+ adapter delta) -> ReLU
let mut h1 = self.fc1.forward(&pooled)?;
if let Some(ad) = &self.adapter {
// The delta is computed from fc1's input (`pooled`) and added to fc1's
// pre-activation output, so the ReLU below sees base + delta.
h1 = PoseLora::apply(&ad.fc1, &pooled, h1)?;
}
let h1 = h1.relu()?;
// fc2 (+ adapter delta)
let mut h2 = self.fc2.forward(&h1)?;
if let Some(ad) = &self.adapter {
// Same pattern for fc2: delta from its input (`h1`), added before the sigmoid.
h2 = PoseLora::apply(&ad.fc2, &h1, h2)?;
}
// sigmoid -> keep in [0, 1]
candle_nn::ops::sigmoid(&h)
candle_nn::ops::sigmoid(&h2)
}
}
@@ -148,6 +195,17 @@ impl InferenceEngine {
/// in `cog-pose-estimation run`). If `weights_path` is `None`, the
/// stub fallback is used.
pub fn with_weights(weights_path: Option<&Path>) -> Result<Self, Box<dyn std::error::Error>> {
// Delegates with no adapter path, so no calibration adapter is attached.
Self::with_weights_and_adapter(weights_path, None)
}
/// Create an engine with a shared base **and an optional per-room calibration
/// adapter** (ADR-150 §3.5). The adapter is a tiny LoRA safetensors fitted from a
/// short labeled in-room capture (`aether-arena/calibration/calibrate.py`); attaching
/// it recovers SOTA-level pose in an unseen room/person. `None` = uncalibrated base.
pub fn with_weights_and_adapter(
weights_path: Option<&Path>,
adapter_path: Option<&Path>,
) -> Result<Self, Box<dyn std::error::Error>> {
let device = pick_device();
let inner = match weights_path {
Some(p) if p.exists() => {
@@ -158,7 +216,12 @@ impl InferenceEngine {
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)?
};
let net = PoseNet::new(vb)?;
let mut net = PoseNet::new(vb)?;
if let Some(ap) = adapter_path {
if ap.exists() {
net.adapter = Some(PoseLora::load(ap, &device)?);
}
}
Some(Arc::new(LoadedModel { net }))
}
_ => None,
@@ -166,6 +229,14 @@ impl InferenceEngine {
Ok(Self { inner, device })
}
/// Whether a per-room calibration adapter is attached **and** carries at least one
/// layer delta.
///
/// Returns `false` when no model is loaded, when no adapter was attached, or when the
/// attached adapter holds neither an `fc1` nor an `fc2` delta (e.g. the adapter file
/// contained none of the recognised keys), since inference is then identical to the
/// uncalibrated base.
pub fn is_calibrated(&self) -> bool {
    self.inner
        .as_ref()
        .and_then(|m| m.net.adapter.as_ref())
        .map_or(false, |ad| ad.fc1.is_some() || ad.fc2.is_some())
}
/// Where the weights actually came from. Useful for the run.started event.
pub fn backend(&self) -> &'static str {
match (&self.inner, &self.device) {
@@ -63,6 +63,76 @@ fn real_weights_load_when_available() {
);
}
#[test]
fn per_room_adapter_changes_inference_output() {
    // Build a minimal valid base + a non-trivial LoRA adapter in a scratch dir, then verify
    // the calibration adapter (ADR-150 §3.5) is detected and actually alters the output.
    use candle_core::{DType, Device, Tensor};
    use std::collections::HashMap;

    /// Removes the scratch directory on drop, so a failing assertion does not leak it.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let dev = Device::Cpu;
    // Declared before the engines: locals drop in reverse order, so the directory is
    // removed only after the engines (which mmap the base file) have been released.
    let scratch = ScratchDir(
        std::env::temp_dir().join(format!("cogpose_adapter_test_{}", std::process::id())),
    );
    std::fs::create_dir_all(&scratch.0).unwrap();
    let base_p = scratch.0.join("base.safetensors");
    let adapter_p = scratch.0.join("room.adapter.safetensors");

    // --- base weights (random but finite) matching PoseNet's VarBuilder keys ---
    let mut w: HashMap<String, Tensor> = HashMap::new();
    let mut put = |k: &str, t: Tensor| {
        w.insert(k.to_string(), t);
    };
    put("enc.c1.weight", Tensor::randn(0f32, 0.1, (64, 56, 3), &dev).unwrap());
    put("enc.c1.bias", Tensor::zeros(64, DType::F32, &dev).unwrap());
    put("enc.c2.weight", Tensor::randn(0f32, 0.1, (128, 64, 3), &dev).unwrap());
    put("enc.c2.bias", Tensor::zeros(128, DType::F32, &dev).unwrap());
    put("enc.c3.weight", Tensor::randn(0f32, 0.1, (128, 128, 3), &dev).unwrap());
    put("enc.c3.bias", Tensor::zeros(128, DType::F32, &dev).unwrap());
    put("head.fc1.weight", Tensor::randn(0f32, 0.1, (256, 128), &dev).unwrap());
    put("head.fc1.bias", Tensor::zeros(256, DType::F32, &dev).unwrap());
    put("head.fc2.weight", Tensor::randn(0f32, 0.1, (34, 256), &dev).unwrap());
    put("head.fc2.bias", Tensor::zeros(34, DType::F32, &dev).unwrap());
    candle_core::safetensors::save(&w, &base_p).unwrap();

    // --- adapter: non-zero low-rank deltas on both head layers (scale baked into B) ---
    let r = 4usize;
    let mut ad: HashMap<String, Tensor> = HashMap::new();
    ad.insert("fc1.a".into(), Tensor::randn(0f32, 0.5, (128, r), &dev).unwrap());
    ad.insert("fc1.b".into(), Tensor::randn(0f32, 0.5, (r, 256), &dev).unwrap());
    ad.insert("fc2.a".into(), Tensor::randn(0f32, 0.5, (256, r), &dev).unwrap());
    ad.insert("fc2.b".into(), Tensor::randn(0f32, 0.5, (r, 34), &dev).unwrap());
    candle_core::safetensors::save(&ad, &adapter_p).unwrap();

    let base = InferenceEngine::with_weights(Some(&base_p)).expect("base load");
    let cal = InferenceEngine::with_weights_and_adapter(Some(&base_p), Some(&adapter_p))
        .expect("calibrated load");
    assert!(!base.is_calibrated(), "base must report uncalibrated");
    assert!(cal.is_calibrated(), "adapter engine must report calibrated");

    // Non-zero input — a zero window would zero the LoRA delta (x·A·B = 0).
    let win = cog_pose_estimation::inference::CsiWindow {
        data: (0..INPUT_SUBCARRIERS * INPUT_TIMESTEPS)
            .map(|i| ((i % 7) as f32 - 3.0) * 0.2)
            .collect(),
    };
    let a = base.infer(&win).expect("base infer");
    let b = cal.infer(&win).expect("calibrated infer");
    assert!(a.is_finite() && b.is_finite());

    let diff: f32 = a
        .keypoints
        .iter()
        .zip(&b.keypoints)
        .map(|(x, y)| (x - y).abs())
        .sum();
    assert!(
        diff > 1e-4,
        "per-room adapter must change the output (sum|Δ| = {diff})"
    );
}
#[test]
fn manifest_roundtrips() {
let spec = ManifestSpec::embedded("pose-estimation", "0.0.1");