前面的博客介绍了ray.rllib中算法的配置和构建,也包含了算法训练的代码。 但是rllib中实现算法训练的方式不止一种,本博客对此进行介绍。很多教程使用 PPOTrainer 进行训练,但是 PPOTrainer 在最近的 ray 版本中已经取消了。
环境配置:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
方式1: algo.train()
rllib 中的 Algorithm 类自带了.train() 函数,实现算法训练,前面几个博客教程均是采用的这种方式。这里仅再提供一下示例, 不再赘述:
import os
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print
## 配置算法
storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path ## 设置过程文件的存储路径
## 构建算法
algo = config.build()
## 训练
for i in range(3):
result = algo.train()
print(f"episode_{i}")
方式2:tune.Tuner()
以上方式只能固定训练超参数,不能对训练超参数寻优。ray中还有一个模块 tune, 专门用于算法训练过程中超参数调参。
在使用tune.Tuner()执行rllib算法训练时, 可以默认为tune背后自动执行了以下操作:
algo = PPOConfig().build() ## 构建算法
result = algo.train() ## 算法训练
print(pretty_print(result)) ## 每完成一次algo.train, 打印一次阶段性训练结果
algo.save_checkpoint() ## 保存训练模型
并且遍历了多个超参数组合,多次进行训练。直到达到停止训练的条件(自己配置)。
基于tune的rllib训练示例如下(代码篇幅比较大是因为添加的功能模块和注释比较多,后面的介绍主要以 方式一为主,所以对于这种方式,这里介绍的多一些):
import ray
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray import train, tune
import torch
import os
import shutil
from ray.tune.logger import pretty_print
import gymnasium as gym
ray.init()
#### 配置算法 ####
config = PPOConfig()
config = config.training(lr=tune.grid_search([0.01, 0.001]))
config = config.environment(env="CartPole-v1")
#### 配置 tune ####
## 准备 tune 的 stop_condition,多个条件之间是”或“的关系。有一个满足即停止训练。
## 'episode_reward_mean'关键字将在 ”ray-2.40”版中中被抛弃,
## 届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'
stop_condition = {
'episode_reward_mean':10, ## 这里设置的结束条件很宽松,所以能够快速结束训练。
"training_iteration":3
}
## 准备 tune 的过程文件存储路径
storage_path = "F:/codes/RLlib_study/ray_results"
os.makedirs(storage_path, exist_ok=True)
## 准备 tune 的 checkpoint_config
## tune 默认保存每个 algo.train() 训练得到的 checkpoint.
## 通过以下配置,可以对此进行自定义修改
checkpoint_config = train.CheckpointConfig(num_to_keep=None, ## 保存几个checkpoint, None 表示保存所有checkpoint
checkpoint_at_end=True) ## 是否在训练结束后保存 checkpoint.
## 配置 tuner
tuner = tune.Tuner(
PPO, ## 需要是一个 rllib 的 Algorithm 类, 从 ray.rllib.algorithms 导入, 也可以是自定义的,后面介绍
run_config = train.RunConfig(
stop = stop_condition,
checkpoint_config = checkpoint_config, ## 用于设置保存哪个checkpoint.
storage_path = storage_path, ## 如果不设置, 默认存储路径是 “~/ray-results” 或 “C:/用户/xxx/ray_results”
),
param_space=config, ## 这里定义了参数搜索调优空间,是一个 PPOConfig 对象
)
## 执行训练
results = tuner.fit() ## tuner 返回一个Result表格对象,该对象允许进一步分析训练结果并检索经过训练的智能体的checkpoint。
print("====训练结束====")
## 获取最佳训练结果
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")
## 以 "episode_reward_mean" 为选择指标, 从results里面选择checkpoint, 选择模式是“max”,
## 'episode_reward_mean'关键字将在ray-2.40版中中被抛弃,届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'
## 从最佳训练结果中提取对应的 checkpoint , 并保存
checkpoint_save_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
os.makedirs(checkpoint_save_dir, exist_ok=True)
best_checkpoint = best_result.checkpoint
if best_checkpoint:
with best_checkpoint.as_directory() as checkpoint_dir:
print(f"====最佳模型路径位于:{checkpoint_dir}====")
## 把最佳模型转存到指定位置。
shutil.rmtree(checkpoint_save_dir)
shutil.copytree(checkpoint_dir,checkpoint_save_dir)
print(f"====保存最佳模型到:{checkpoint_save_dir}====")
## 加载保存的最佳模型
checkpoint_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"==== 加载最佳模型: {checkpoint_dir}")
## evaluate 模型
env_name = "CartPole-v1"
env = gym.make(env_name)
## 模型推断: method-1
step = 0
episode_reward = 0
terminated = truncated = False
obs,info = env.reset()
while not terminated and not truncated:
action = algo.compute_single_action(obs)
obs, reward, terminated, truncated, info = env.step(action)
episode_reward += reward
step += 1
print(f"step = {step}, reward = {reward}, action = {action}, obs = {obs}, episode_reward = {episode_reward}")