77 lines
2.1 KiB
Python
77 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import mujoco
|
|
import numpy as np
|
|
|
|
|
|
# Directory two levels above this file (parent of the file's own directory);
# used as the project root for locating bundled assets.
ROOT = Path(__file__).resolve().parent.parent

# MJCF model that load_model() falls back to when no path is supplied.
DEFAULT_MODEL_PATH = ROOT.joinpath("model", "walker.xml")
|
|
|
|
|
|
@dataclass()
class StepLog:
    """One sampled frame of walker state, as built by :func:`snapshot`.

    Every position is stored as its x and z components only; the y
    component of the underlying 3-vector is not kept.
    """

    # Simulation time of the sample (snapshot() copies ``data.time``).
    t: float

    # First qpos entry of each crank joint; the ``_rad`` suffix indicates
    # an angle in radians (presumably hinge joints -- confirm in the MJCF).
    near_crank_rad: float
    far_crank_rad: float

    # Position of the "torso" body.
    torso_x: float
    torso_z: float

    # Positions of the "near_F_site" and "near_G_site" sites.
    near_f_x: float
    near_f_z: float
    near_g_x: float
    near_g_z: float

    # Positions of the "far_F_site" and "far_G_site" sites.
    far_f_x: float
    far_f_z: float
    far_g_x: float
    far_g_z: float
|
|
|
|
|
|
def load_model(
    model_path: str | Path | None = None,
) -> tuple[mujoco.MjModel, mujoco.MjData]:
    """Compile an MJCF model and return it with freshly reset simulation data.

    Args:
        model_path: Path to the model XML. Any falsy value (``None`` or an
            empty string) selects ``DEFAULT_MODEL_PATH`` instead.

    Returns:
        ``(model, data)``. ``data`` has already been passed through
        ``reset_to_stand``, so it starts at the "stand" keyframe when the
        model defines one.
    """
    xml_path = DEFAULT_MODEL_PATH if not model_path else Path(model_path)
    mj_model = mujoco.MjModel.from_xml_path(str(xml_path))
    mj_data = mujoco.MjData(mj_model)
    reset_to_stand(mj_model, mj_data)
    return mj_model, mj_data
|
|
|
|
|
|
def reset_to_stand(model: mujoco.MjModel, data: mujoco.MjData) -> None:
    """Reset ``data`` in place, preferring the model's "stand" keyframe.

    If the model has a keyframe named "stand", ``data`` is reset to it.
    Otherwise it is reset to the model defaults. In both cases
    ``mj_forward`` is run afterwards so that derived quantities (body and
    site positions, etc.) match the reset state.
    """
    # mj_name2id returns -1 when no keyframe with that name exists.
    stand_key = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_KEY, "stand")
    if stand_key < 0:
        mujoco.mj_resetData(model, data)
    else:
        mujoco.mj_resetDataKeyframe(model, data, stand_key)

    mujoco.mj_forward(model, data)
|
|
|
|
|
|
def site_xyz(data: mujoco.MjData, site_name: str) -> np.ndarray:
    """Return the named site's position as a new float64 array.

    The result is a copy, so it stays valid after ``data`` is stepped again.
    """
    position = data.site(site_name).xpos
    return np.array(position, dtype=np.float64, copy=True)
|
|
|
|
|
|
def torso_xyz(data: mujoco.MjData) -> np.ndarray:
    """Return the "torso" body's position as a new float64 array.

    The result is a copy, so it stays valid after ``data`` is stepped again.
    """
    torso_body = data.body("torso")
    return np.array(torso_body.xpos, dtype=np.float64, copy=True)
|
|
|
|
|
|
def snapshot(data: mujoco.MjData) -> StepLog:
    """Capture the current simulation state of ``data`` as a :class:`StepLog`.

    The sample records the simulation time, the first qpos entry of the
    near and far crank joints, and the x/z components (index 0 and index 2)
    of the torso body and the four F/G sites. The y component is dropped.
    """
    torso_x, _, torso_z = torso_xyz(data)
    nf_x, _, nf_z = site_xyz(data, "near_F_site")
    ng_x, _, ng_z = site_xyz(data, "near_G_site")
    ff_x, _, ff_z = site_xyz(data, "far_F_site")
    fg_x, _, fg_z = site_xyz(data, "far_G_site")

    near_crank = data.joint("near_crank_joint").qpos[0]
    far_crank = data.joint("far_crank_joint").qpos[0]

    return StepLog(
        t=float(data.time),
        near_crank_rad=float(near_crank),
        far_crank_rad=float(far_crank),
        torso_x=float(torso_x),
        torso_z=float(torso_z),
        near_f_x=float(nf_x),
        near_f_z=float(nf_z),
        near_g_x=float(ng_x),
        near_g_z=float(ng_z),
        far_f_x=float(ff_x),
        far_f_z=float(ff_z),
        far_g_x=float(fg_x),
        far_g_z=float(fg_z),
    )
|