49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
try:
|
|
from stable_baselines3 import PPO
|
|
except Exception as err: # pragma: no cover
|
|
raise RuntimeError(
|
|
"walker_sim.train_ppo requires stable-baselines3. Install with: uv sync --extra rl"
|
|
) from err
|
|
|
|
from .env import WalkerEnv
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for PPO training.

    Returns:
        argparse.Namespace with ``steps`` (int), ``save`` (str path for the
        serialized model), and ``model`` (str MJCF path override or None).
    """
    parser = argparse.ArgumentParser(description="Train PPO on WalkerEnv")
    parser.add_argument("--steps", type=int, default=200_000, help="Training timesteps")
    parser.add_argument(
        "--save",
        type=str,
        default="../../artifacts/ppo_walker.zip",
        help="Output model path",
    )
    parser.add_argument("--model", type=str, default=None, help="MJCF path override")
    return parser.parse_args()
|
|
|
|
|
|
def main() -> None:
    """Train a PPO agent on WalkerEnv and write the model to disk.

    Parses CLI arguments, ensures the output directory exists, runs PPO
    training for ``--steps`` timesteps, and saves the resulting model.
    """
    args = parse_args()

    out = Path(args.save)
    # Create parent directories up front so model.save cannot fail on a
    # missing artifacts folder.
    out.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): these are common PPO hyperparameters for locomotion
    # tasks; they are fixed here rather than exposed on the CLI.
    hyperparams = dict(
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=128,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.0,
    )

    env = WalkerEnv(model_path=args.model)
    agent = PPO("MlpPolicy", env, verbose=1, **hyperparams)
    agent.learn(total_timesteps=args.steps)
    agent.save(str(out))
    print(f"Saved PPO model to {out}")
|
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|