# walkersim/sim/mujoco/walker_sim/train_ppo.py
# PPO training script for WalkerEnv (49 lines, 1.3 KiB, Python).
from __future__ import annotations
import argparse
from pathlib import Path
try:
from stable_baselines3 import PPO
except Exception as err: # pragma: no cover
raise RuntimeError(
"walker_sim.train_ppo requires stable-baselines3. Install with: uv sync --extra rl"
) from err
from .env import WalkerEnv
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the PPO training run."""
    parser = argparse.ArgumentParser(description="Train PPO on WalkerEnv")
    parser.add_argument(
        "--steps", type=int, default=200_000, help="Training timesteps"
    )
    parser.add_argument(
        "--save",
        type=str,
        default="../../artifacts/ppo_walker.zip",
        help="Output model path",
    )
    parser.add_argument(
        "--model", type=str, default=None, help="MJCF path override"
    )
    return parser.parse_args()
def main() -> None:
    """Train PPO on WalkerEnv and save the trained model to disk.

    Reads options from the command line (see ``parse_args``), creates the
    output directory if needed, trains for ``--steps`` timesteps, and writes
    the model zip to ``--save``.
    """
    args = parse_args()

    # Resolve the output location up front so training cannot finish only to
    # fail on a missing parent directory.
    save_path = Path(args.save)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    env = WalkerEnv(model_path=args.model)
    try:
        model = PPO(
            "MlpPolicy",
            env,
            verbose=1,
            learning_rate=3e-4,
            n_steps=2048,
            batch_size=128,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.0,
        )
        model.learn(total_timesteps=args.steps)
        model.save(str(save_path))
        print(f"Saved PPO model to {save_path}")
    finally:
        # Release the environment's native resources even if training aborts
        # (standard Gymnasium Env API — presumably MuJoCo holds sim handles;
        # confirm WalkerEnv.close() exists).
        env.close()
if __name__ == "__main__":
main()