Files
walkersim/sim/mujoco/walker_sim/env.py

90 lines
3.3 KiB
Python

from __future__ import annotations
import numpy as np
try:
import gymnasium as gym
from gymnasium import spaces
except Exception as err: # pragma: no cover
raise RuntimeError(
"walker_sim.env requires gymnasium. Install with: uv sync --extra rl"
) from err
import mujoco
from .sim import DEFAULT_MODEL_PATH
class WalkerEnv(gym.Env):
    """Gymnasium environment around a MuJoCo walker model.

    Actions are normalized to ``[-1, 1]`` per actuator and rescaled linearly
    onto each actuator's control range. Observations are ``qpos`` followed by
    ``qvel``, cast to ``float32``.

    The loaded model must contain geoms named ``"torso_geom"`` and ``"floor"``
    and a body named ``"torso"``; all three are resolved once at construction
    so a model missing any of them fails immediately.

    Args:
        model_path: Path of the model XML file. ``DEFAULT_MODEL_PATH`` is used
            when this is ``None`` (or any other falsy value).
        frame_skip: Number of physics steps run per call to :meth:`step`.
            Must be at least 1.
        max_steps: Number of :meth:`step` calls after which the episode is
            reported as truncated.

    Raises:
        ValueError: If ``frame_skip`` is smaller than 1.
    """

    metadata = {"render_modes": []}

    def __init__(self, model_path: str | None = None, frame_skip: int = 4, max_steps: int = 2000):
        super().__init__()

        frame_skip = int(frame_skip)
        if frame_skip < 1:
            # frame_skip == 0 would make dt zero and crash the velocity
            # computation in step(); negative values would run no physics.
            raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")

        self.model = mujoco.MjModel.from_xml_path(str(model_path or DEFAULT_MODEL_PATH))
        self.data = mujoco.MjData(self.model)
        self.frame_skip = frame_skip
        self.max_steps = int(max_steps)
        self.step_count = 0

        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(self.model.nu,), dtype=np.float32)
        obs_dim = self.model.nq + self.model.nv
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

        # Private copies of the per-actuator control bounds used for rescaling.
        self._ctrl_min = np.array(self.model.actuator_ctrlrange[:, 0], dtype=np.float64)
        self._ctrl_max = np.array(self.model.actuator_ctrlrange[:, 1], dtype=np.float64)

        # An actuator without control limits reports a zero-width range, which
        # would map every action to the same constant control. Use [-1, 1] for
        # those so the normalized action passes through unchanged.
        unlimited = ~np.asarray(self.model.actuator_ctrllimited, dtype=bool)
        dead = unlimited & (self._ctrl_min == self._ctrl_max)
        self._ctrl_min[dead] = -1.0
        self._ctrl_max[dead] = 1.0

        # Resolve names once instead of on every step.
        self._torso_gid = self.model.geom("torso_geom").id
        self._floor_gid = self.model.geom("floor").id
        self._torso_bid = self.model.body("torso").id

    def _torso_hit_floor(self) -> bool:
        """Return whether any current contact pairs the torso geom with the floor geom."""
        torso, floor = self._torso_gid, self._floor_gid
        contacts = (self.data.contact[i] for i in range(self.data.ncon))
        # A contact may list the two geoms in either order.
        return any(
            (c.geom1 == torso and c.geom2 == floor) or (c.geom1 == floor and c.geom2 == torso)
            for c in contacts
        )

    def _obs(self) -> np.ndarray:
        """Return the observation: ``qpos`` followed by ``qvel`` as ``float32``."""
        return np.concatenate([self.data.qpos, self.data.qvel]).astype(np.float32)

    def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[np.ndarray, dict]:
        """Reset the simulation and perturb the initial state with small noise.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.step_count = 0
        mujoco.mj_resetData(self.model, self.data)

        # Draw position noise before velocity noise so that seeded resets keep
        # producing the same initial states.
        qpos_noise = self.np_random.normal(0.0, 0.01, size=self.model.nq)
        qvel_noise = self.np_random.normal(0.0, 0.01, size=self.model.nv)
        self.data.qpos[:] = self.data.qpos + qpos_noise
        self.data.qvel[:] = self.data.qvel + qvel_noise

        # Recompute derived quantities for the perturbed state.
        mujoco.mj_forward(self.model, self.data)
        return self._obs(), {}

    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict[str, float]]:
        """Apply ``action`` for ``frame_skip`` physics steps.

        Args:
            action: One value per actuator. Values are clipped to ``[-1, 1]``
                and then mapped linearly onto the actuator control range.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

            * ``reward`` is ``1.2 * forward_vel + 0.4 - ctrl_cost``, where
              ``forward_vel`` is the change of the torso body's x-position
              over the elapsed simulation time and ``ctrl_cost`` is ``1e-4``
              times the squared norm of the applied controls.
            * ``terminated`` is true when the torso geom touches the floor
              geom or the torso body's z-position is below ``0.08`` or above
              ``2.0``.
            * ``truncated`` is true once ``max_steps`` calls have been made
              since the last reset.
            * ``info`` holds ``forward_vel``, ``torso_z`` and ``ctrl_cost``.
        """
        self.step_count += 1

        action = np.clip(np.asarray(action, dtype=np.float64), -1.0, 1.0)
        # Affine map from [-1, 1] onto [ctrl_min, ctrl_max].
        self.data.ctrl[:] = self._ctrl_min + 0.5 * (action + 1.0) * (self._ctrl_max - self._ctrl_min)

        x_before = float(self.data.xpos[self._torso_bid, 0])
        for _ in range(self.frame_skip):
            mujoco.mj_step(self.model, self.data)
        x_after = float(self.data.xpos[self._torso_bid, 0])

        dt = self.model.opt.timestep * self.frame_skip
        forward_vel = (x_after - x_before) / dt
        torso_z = float(self.data.xpos[self._torso_bid, 2])
        ctrl_cost = 1e-4 * float(np.dot(self.data.ctrl, self.data.ctrl))

        reward = 1.2 * forward_vel + 0.4 - ctrl_cost
        terminated = self._torso_hit_floor() or torso_z < 0.08 or torso_z > 2.0
        truncated = self.step_count >= self.max_steps

        info = {
            "forward_vel": forward_vel,
            "torso_z": torso_z,
            "ctrl_cost": ctrl_cost,
        }
        return self._obs(), reward, terminated, truncated, info