Synced (机器之心) column
Author: 4Paradigm Reinforcement Learning Team
OpenRL, a reinforcement learning research framework built on PyTorch, has been open-sourced on GitHub. It can be installed via pip:

pip install openrl

or via conda:

conda install -c openrl openrl
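A quick way to confirm the installation (a minimal check; the __version__ attribute is an assumption and may be exposed differently across releases):

python -c "import openrl; print(openrl.__version__)"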
Training a PPO agent on CartPole takes only a few lines of code:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

env = make("CartPole-v1", env_num=9)  # create the environment, with 9 parallel copies
net = Net(env)  # create the neural network
agent = Agent(net)  # initialize the agent
agent.train(total_time_steps=20000)  # start training, for a total of 20000 environment steps
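To see how the trained agent performs, the same agent and environment can be reused in an evaluation loop. The sketch below simply mirrors the MPE test script shown later in this article and is not part of the original example:

# evaluation sketch, reusing env and agent from train_ppo.py above
obs, _ = env.reset()
while True:
    action, _ = agent.act(obs)  # the agent predicts actions from observations
    obs, r, done, info = env.step(action)
    if done.any():  # stop once any of the parallel environments finishes
        break
env.close()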
OpenRL handles multi-agent training with the same interface. The following script trains PPO on the MPE simple_spread task:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

def train():
    # create the MPE environment; asynchronous=True makes the parallel environments run independently
    env = make(
        "simple_spread",
        env_num=100,
        asynchronous=True,
    )
    # create the neural network and train on GPU
    net = Net(env, device="cuda")
    agent = Agent(net)  # initialize the trainer
    # start training
    agent.train(total_time_steps=5000000)
    # save the trained agent
    agent.save("./ppo_agent/")

if __name__ == "__main__":
    train()
Hyperparameters can be supplied through a YAML configuration file:

# mpe_ppo.yaml
seed: 0  # set the random seed so that runs are reproducible
lr: 7e-4  # learning rate
episode_length: 25  # length of each episode
use_recurrent_policy: true  # whether to use an RNN policy
use_joint_action_loss: true  # whether to use the JRPO algorithm
use_valuenorm: true  # whether to use value normalization
The configuration file is then passed on the command line:

python train_ppo.py --config mpe_ppo.yaml
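Note that for the --config flag to take effect, train_ppo.py has to parse it. The excerpt below is a sketch based on the create_config_parser usage from the NLP training script later in this article; passing cfg through make and Net for MPE is an assumption:

# train_ppo.py (sketch of the config-reading variant)
from openrl.configs.config import create_config_parser

def train():
    # parse command-line arguments such as --config mpe_ppo.yaml
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()
    env = make("simple_spread", env_num=100, asynchronous=True, cfg=cfg)
    net = Net(env, cfg=cfg, device="cuda")  # pass the parsed config through
    ...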
To watch the agents during testing, set render_mode when creating the environment:

env = make("simple_spread", env_num=9, render_mode="group_human")
The rollout can also be recorded as a GIF with GIFWrapper:

from openrl.envs.wrappers import GIFWrapper
env = GIFWrapper(env, "test_simple_spread.gif")
Putting it together, the following script loads the trained agent, runs it, and saves the result as a GIF:

# test_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.envs.wrappers import GIFWrapper  # used to generate GIFs

def test():
    # create the MPE environment
    env = make("simple_spread", env_num=4)
    # wrap it with GIFWrapper to record a GIF
    env = GIFWrapper(env, "test_simple_spread.gif")
    agent = Agent(Net(env))  # create the agent
    # load the agent trained by train_ppo.py
    agent.load("./ppo_agent/")
    # start testing
    obs, _ = env.reset()
    while True:
        # the agent predicts the next action from the observation
        action, _ = agent.act(obs)
        obs, r, done, info = env.step(action)
        if done.any():
            break
    env.close()

if __name__ == "__main__":
    test()
OpenRL also supports natural-language tasks such as dialogue generation. Training on the DailyDialog dataset looks like this:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.configs.config import create_config_parser

def train():
    # read the configuration file
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()
    # create the NLP environment
    env = make("daily_dialog", env_num=2, asynchronous=True, cfg=cfg)
    net = Net(env, cfg=cfg, device="cuda")
    agent = Agent(net)
    agent.train(total_time_steps=5000000)

if __name__ == "__main__":
    train()
The matching configuration file:

# nlp_ppo.yaml
data_path: daily_dialog  # dataset path
env:  # environment parameters
  args: {"tokenizer_path": "gpt2"}  # path used to load the tokenizer
seed: 0  # set the random seed so that runs are reproducible
lr: 1e-6  # learning rate of the policy model
critic_lr: 1e-6  # learning rate of the critic model
episode_length: 20  # length of each episode
use_recurrent_policy: true
A pre-trained language model can be plugged in by extending the configuration file:

# nlp_ppo.yaml
# path of the pre-trained model
model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
use_share_model: true  # whether the policy and value networks share one model
ppo_epoch: 5  # number of PPO training epochs
data_path: daily_dialog  # dataset name or path
env:  # environment parameters
  args: {"tokenizer_path": "gpt2"}  # path used to load the tokenizer
lr: 1e-6  # learning rate of the policy model
critic_lr: 1e-6  # learning rate of the critic model
episode_length: 128  # length of each episode
num_mini_batch: 20
The GPT policy-value network is then passed to the training script through model_dict:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.configs.config import create_config_parser
from openrl.modules.networks.policy_value_network_gpt import (
    PolicyValueNetworkGPT as PolicyValueNetwork,
)

def train():
    # read the configuration file
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()
    # create the NLP environment
    env = make("daily_dialog", env_num=2, asynchronous=True, cfg=cfg)
    # create the custom neural network
    model_dict = {"model": PolicyValueNetwork}
    net = Net(env, cfg=cfg, model_dict=model_dict)
    # create the training agent
    agent = Agent(net)
    agent.train(total_time_steps=5000000)

if __name__ == "__main__":
    train()
Separate custom policy and value networks can be supplied the same way (a sketch of such a network follows this snippet):

model_dict = {
    "policy": CustomPolicyNetwork,
    "critic": CustomValueNetwork,
}
net = Net(env, model_dict=model_dict)
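For illustration, a custom value network could be an ordinary PyTorch module like the sketch below. This is hypothetical: CustomValueNetwork is just the placeholder name from the snippet above, and the constructor and method signatures that OpenRL actually expects from entries in model_dict are not shown in this article.

# custom_networks.py - hypothetical sketch; OpenRL's real network interface
# may require a different constructor and additional methods
import torch
import torch.nn as nn

class CustomValueNetwork(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int = 64):
        super().__init__()
        # a small MLP mapping observations to a scalar state value
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)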
The reward model is configured in the YAML file as well:

# nlp_ppo.yaml
reward_class:
  id: NLPReward  # name of the reward model
  args: {
    # name or path of the model used for intent classification
    "intent_model": rajkumarrrk/roberta-daily-dialog-intent-classifier,
    # name or path of the pre-trained model used to compute the KL penalty
    "ref_model": roberta-base,
  }
Users can also register their own reward models:

# train_ppo.py
from openrl.rewards.nlp_reward import CustomReward
from openrl.rewards import RewardFactory

RewardFactory.register("CustomReward", CustomReward)
and reference them in the configuration file:

# nlp_ppo.yaml
reward_class:
  id: "CustomReward"  # name of the custom reward model
  args: {}  # any parameters the user-defined reward function needs
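For a sense of what such a reward model might contain, here is a purely illustrative sketch; the base-class interface that RewardFactory actually expects is defined by OpenRL and not shown in this article, so the method name and signature below are assumptions:

# my_reward.py - hypothetical sketch; OpenRL's real reward interface may
# require a different base class and method signatures
import numpy as np

class CustomReward:
    """Toy reward model: gives every step a small negative reward."""

    def __init__(self, env):
        self.env = env  # assumed: the wrapped environment is passed in

    def reward(self, actions: np.ndarray) -> np.ndarray:
        # return one scalar reward per parallel environment
        return np.full(len(actions), -0.01, dtype=np.float32)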
OpenRL can print task-specific training information and upload logs to wandb:

# nlp_ppo.yaml
vec_info_class:
  id: "NLPVecInfo"  # use the NLPVecInfo class to print reward information for NLP tasks
# wandb settings
wandb_entity: openrl  # the wandb team name; replace openrl with your own team name
experiment_name: train_nlp  # experiment name
run_dir: ./run_results/  # directory where experiment data is saved
log_interval: 1  # upload data to wandb every this many episodes
# fill in other parameters as needed...
wandb logging is then switched on in the training call:

# train_ppo.py
agent.train(total_time_steps=100000, use_wandb=True)
A custom information-output class can be registered in the same way:

# train_ppo.py
# register the custom output-information class
VecInfoFactory.register("CustomVecInfo", CustomVecInfo)
and selected in the configuration file:

# nlp_ppo.yaml
vec_info_class:
  id: "CustomVecInfo"  # use the custom CustomVecInfo class to output custom information
Mixed-precision training can be enabled with a single switch:

# nlp_ppo.yaml
use_amp: true  # enable mixed-precision training
Finally, the trained dialogue agent can be talked to through ChatAgent:

# chat.py
from openrl.runners.common import ChatAgent as Agent

def chat():
    agent = Agent.load("./ppo_agent", tokenizer="gpt2")
    history = []
    print("Welcome to OpenRL!")
    while True:
        input_text = input("> User: ")
        if input_text == "quit":
            break
        elif input_text == "reset":
            history = []
            print("Welcome to OpenRL!")
            continue
        response = agent.chat(input_text, history)
        print(f"> OpenRL Agent: {response}")
        history.append(input_text)
        history.append(response)

if __name__ == "__main__":
    chat()
The OpenRL framework is developed by the OpenRL Lab team, the reinforcement learning research team of 4Paradigm. 4Paradigm has long been committed to the research, development, and industrial application of reinforcement learning. To promote the integration of academia and industry in reinforcement learning, 4Paradigm founded the OpenRL Lab research team, whose goals are open-sourcing advanced technology and exploring the frontiers of artificial intelligence. In less than a year since its founding, the OpenRL Lab team has published three papers at AAMAS and took third place in the Google Research Football 11 vs 11 competition. The team's TiZero agent was the first to train a full-game Google Research Football agent from scratch, using curriculum learning, distributed reinforcement learning, and self-play.