66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
import tempfile
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Iterator
|
|
|
|
from .sim import DEFAULT_MODEL_PATH
|
|
|
|
|
|
# (geom name, fromto template) pairs consumed by apply_foot_length_cm.
# Each template holds the six numbers "x1 y1 z1 x2 y2 z2" of a geom's
# ``fromto`` attribute; ``{y}`` is filled with the foot length in metres.
# The "near_*" segments run from y=0 to +y, the "far_*" ones from -y to y=0,
# while x and z stay fixed per foot.
FOOT_LATERAL_PATTERNS: list[tuple[str, str]] = [
    ("near_front_foot_lateral", "-0.2795 0 -0.3535 -0.2795 {y:.4f} -0.3535"),
    ("near_rear_foot_lateral", "0.2801 0 -0.3522 0.2801 {y:.4f} -0.3522"),
    ("far_front_foot_lateral", "-0.2795 -{y:.4f} -0.3535 -0.2795 0 -0.3535"),
    ("far_rear_foot_lateral", "0.2801 -{y:.4f} -0.3522 0.2801 0 -0.3522"),
]
|
|
|
|
|
|
def apply_foot_length_cm(
    model_xml: str,
    foot_cm: float,
    *,
    patterns: list[tuple[str, str]] | None = None,
) -> str:
    """Return *model_xml* with each foot geom's ``fromto`` value rewritten.

    For every ``(geom_name, fromto_fmt)`` pair, the ``fromto="..."`` value of
    the ``<geom name="geom_name" ...>`` tag is replaced with ``fromto_fmt``
    formatted with ``y`` set to the foot length in metres.

    Args:
        model_xml: Model XML text to patch. The input string is not modified.
        foot_cm: Foot length in centimetres; must be a finite number > 0.
        patterns: ``(geom_name, fromto_fmt)`` pairs to apply. Defaults to
            ``FOOT_LATERAL_PATTERNS``.

    Returns:
        The patched XML text.

    Raises:
        ValueError: If ``foot_cm`` is not > 0, or is NaN / infinite.
        RuntimeError: If a geom's ``fromto`` is not matched exactly once.
    """
    import math  # local import: only needed for the finiteness check below

    if foot_cm <= 0:
        raise ValueError("--foot-cm must be > 0")
    # NaN compares False against everything, so the check above lets it (and
    # +inf) through; they would otherwise be written as "nan"/"inf" into XML.
    if not math.isfinite(foot_cm):
        raise ValueError("--foot-cm must be a finite number")

    if patterns is None:
        patterns = FOOT_LATERAL_PATTERNS

    length_m = foot_cm / 100.0
    updated = model_xml
    for geom_name, fromto_fmt in patterns:
        # re.escape: match the geom name literally, never as a regex fragment.
        pattern = rf'(<geom name="{re.escape(geom_name)}"[^>]*fromto=")[^"]+("[^>]*>)'
        replacement = fromto_fmt.format(y=length_m)
        # A callable replacement keeps backslashes in `replacement` from being
        # interpreted as regex group references.
        updated, count = re.subn(
            pattern,
            lambda m: f"{m.group(1)}{replacement}{m.group(2)}",
            updated,
        )
        # Exactly one match is required: zero means the geom/attribute was not
        # found; more than one means the name is ambiguous in the XML.
        if count != 1:
            raise RuntimeError(f"Could not update geom '{geom_name}' fromto")
    return updated
|
|
|
|
|
|
def describe_model_config(base_model: str | None, foot_cm: float | None) -> str:
    """Summarise the model selection as a single ``key=value`` line.

    Args:
        base_model: Path to a model XML; ``None`` or ``""`` selects
            ``DEFAULT_MODEL_PATH``.
        foot_cm: Foot length override in centimetres, or ``None`` when the
            model's own foot geometry is kept.

    Returns:
        ``"model=<path> foot_cm=<value>"``, where ``<value>`` is the length
        with one decimal place, or ``default`` when no override is given.
    """
    if base_model:
        chosen_path = Path(base_model)
    else:
        chosen_path = DEFAULT_MODEL_PATH

    if foot_cm is None:
        foot_label = "default"
    else:
        foot_label = format(foot_cm, ".1f")

    return "model={} foot_cm={}".format(chosen_path, foot_label)
|
|
|
|
|
|
@contextmanager
def resolved_model_path(base_model: str | None, foot_cm: float | None) -> Iterator[str]:
    """Yield the path (as ``str``) of the model XML that should be loaded.

    The source model is ``base_model`` when it is a non-empty string, else
    ``DEFAULT_MODEL_PATH``.

    * ``foot_cm is None``: the source path itself is yielded; nothing is
      written.
    * otherwise: the source XML is read, patched with
      ``apply_foot_length_cm``, and written as ``walker_foot_<NNNN>mm.xml``
      (the length rounded to whole millimetres) inside a new temporary
      directory. That file's path is yielded, and the directory is removed
      when the ``with`` block exits.

    NOTE(review): the patched copy lives outside the source model's directory,
    so any file references in the XML that are relative to that directory may
    not resolve — confirm the model has none.

    Raises:
        OSError: if the source model cannot be read.
        ValueError, RuntimeError: propagated from ``apply_foot_length_cm``.
    """
    source = DEFAULT_MODEL_PATH if not base_model else Path(base_model)

    if foot_cm is None:
        yield str(source)
    else:
        patched_xml = apply_foot_length_cm(source.read_text(encoding="utf-8"), foot_cm)
        with tempfile.TemporaryDirectory(prefix="walker_foot_model_") as workdir:
            length_mm = int(round(foot_cm * 10))
            target = Path(workdir).joinpath(f"walker_foot_{length_mm:04d}mm.xml")
            target.write_text(patched_xml, encoding="utf-8")
            yield str(target)
|