wandb是tensorboard plus版,它有几个优点:
1.代码量少,几行代码就能把你想记录的各个实验数据全部变成图表形式
2.只要代码在跑,不管你在哪。登录网站就能看到实时更新的实验结果
3.自动调参。可以自己设定想调的参数范围,内置搜索方法可以帮你找到优秀的解。而且支持多进程,很多机器一起跑,效率拉满。
基础使用步骤:
1.安装
pip install wandb
2.注册和登录
wandb login
wandb会提示你login,根据它给你的链接去网站注册一个账号。然后会给你一个API,输入即可。
3.基本用法步骤
1.初始化:
import wandb
# 初始化wandb,project是项目名,config是实验参数,参数要设为字典形式
wandb.init(project='pong-dqn', config={
"lr": 0.0001,
"batch_size": 64,
"epsilon_decay": 1000000,
...
})
#注:代码中引用的参数写成:wandb.config.xxx,如学习率wandb.config.lr
2.记录指标:(这里记录的指标都会被自动做为图表,任何一个指标都可以作为横坐标)
wandb.log({"epoch": i_ep, "reward": ep_reward, "Epislon": agent.epsilon, 'ep_step': ep_step})
3.保存和加载模型
torch.save(model.state_dict(), "dqn_pong_model.pth")
wandb.save("dqn_pong_model.pth")
4.加载模型时,可以直接从wandb
运行的文件夹中加载:
model.load_state_dict(torch.load("dqn_pong_model.pth"))
5.结束运行
wandb.finish()
这行代码不加好像也没事。
运行代码后,代码会提示你在一个网站上查看结果。
最后的可视化效果:(红箭头处改变横坐标指标,再往右一个选项可改变曲线的平滑度)
这里我用epoch为横坐标,但是loss的记录以step为单位所以无法显示,改成step就可以看到loss了。
4.调参方法
1.在代码的文件夹处创建一个配置文件(通常是YAML格式)在这个文件中定义你想要自动调整的超参数及其搜索空间(其余的超参数在代码里)例如:
program: dqn.py#你要运行的代码名,里面已经做了以上3中的基本用法步骤
method: bayes # 选择搜索算法,如随机搜索(random)、网格搜索(grid)或贝叶斯优化(bayes)
metric:
name: reward #定义优化目标
goal: maximize # maximize/minimize代表最大/最小化目标
parameters:
lr:#连续搜索空间写法
min: 0.00001
max: 0.0001
batch_size:
values: [32, 64, 128]#离散搜索空间写法
target_update:
values: [100,500,1000,2000]
memory_capacity:
values: [100000,200000]
注:随机搜索在超参数空间内随机选择参数组合来进行搜索。网格搜索通过遍历给定超参数空间的所有组合来寻找最优参数的方法。贝叶斯优化实现更为复杂,计算复杂度高,但是通常能够更快地找到最优解。
参数维度较低(少量参数),且有足够的计算资源和时间进行穷举搜索选网格搜索。参数维度较高,参数空间大,预算有限,想要在较短时间内快速探索参数空间选随机搜索。参数维度适中或较高,需要找到尽可能优的参数组合,且计算资源充足选贝叶斯优化。
2.运行命令
wandb sweep --entity 你的wandb账户名 --project 你的项目名 调参配置文件名.yaml
蓝色网址是可以查看结果的网址,黄色命令是在命令行上执行的命令,在代码所在文件夹执行命令即可,执行后机器会自动尝试调参。
如果你有多台机器可用,你可以在每台机器上启动一个或多个agent。在每台机器上,使用相同的sweep_id
运行上述wandb agent
命令。这样,你可以利用所有机器的计算资源来并行化实验。
用两个机器一起运行命令后可以看到实时更新的结果:
这里还能监控显存消耗,cpu温度等信息:
如果你想自动给每次调参命名:
run = wandb.init(...)
# 获取由Sweep自动注入的参数
params = run.config
# 构建基于参数的实验名称
experiment_name = f"lr={params.learning_rate}_bs={params.batch_size}_opt={params.optimizer}"
# 更新实验的名称
wandb.run.name = experiment_name
wandb.run.save()
5.常见错误:
报错:Network error (ProxyError), entering retry loop.
把代理关上即可。