Dynadot Jtti 搬瓦工 腾讯云

PyTorch PPO 源码解读 (pytorch-a2c-ppo-acktr-gail)

GigsGigsCloud

论文告一段落,今天开始会陆续整理一下之前论文用到的一些代码,做一个后续整理工作,以备之后有需要的时候再用。本文整理一下 PyTorch PPO 源码解读,这份解读对快速理解 PPO 代码的帮助还是挺大的,之前了解过 PPO 但是还没有写过代码的朋友们可以看一下。

Introduction

本文介绍的Proximal Policy Optimization (PPO)实现是基于PyTorch的,其Github地址在这里。实际上它一共实现了三个算法,包括PPO、A2C以及ACKTR。这份代码的逻辑抽象做得不错,三个算法共用了很多代码,因此看懂了PPO对于理解另外两个算法的实现有很大帮助。

这份PPO代码依赖于OpenAI baselines,主要用到了其并行环境的wrapper。由于PPO和OpenAI baselines的代码一直在更新,所以最新的代码跟本文可能有所出入。本文所述代码对应的commit id前缀为:

  • PPO:3aea397
  • OpenAI baselines:f272969

下面先介绍代码的逻辑框架,然后再看具体代码。具体代码部分以一个类似于深度优先的方式展开。

Logical framework

这份代码按照RL的思想,将代码分成了environment和robot(之所以不用agent这个词,主要是为了跟下面的变量名做区分)两大模块。environment部分主要是envs这个变量,它通过OpenAI baselines的wrapper实现了多环境并行化。robot部分分为了三个子模块,第一个是parametric policy变量actor_critic;第二个是用来保存trajectory的rollouts变量;最后一个是用于参数更新的RL algorithm变量agent。总结一下,其逻辑框架如下:

  • environment
    • multiprocess wrapper: envs
  • robot
    • parametric policy: actor_critic
    • trajectory: rollouts
    • RL algorithm: agent

整个交互流程基本上是robot从envs中得到一个observation,然后根据actor_critic作出反应,并且将交互过程记录到rollouts中,接着使用agent根据rollouts中的信息来更新actor_critic

Source code

main

# PPO_ROOT/main.py
def main():
    ...
    
    # 创建并行交互环境envs
    envs = [make_env(args.env_name, args.seed, i, args.log_dir, args.add_timestep)
                for i in range(args.num_processes)]
    
    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        ...
    
    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs, gamma=args.gamma)
    
    ...
    
    for j in range(num_updates):
        for step in range(args.num_steps):
            # 通过rollouts获得t时刻的的信息(e.g., o_t),使用actor_critic(i.e., policy)来计算a_t
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                        rollouts.observations[step],
                        rollouts.states[step],
                        rollouts.masks[step])
    
            ...
    
            # 使用a_t跟envs(i.e., environment)进行交互
            obs, reward, done, info = envs.step(cpu_actions)
            
            ...
    
            # 将交互信息存入rollouts中
            rollouts.insert(current_obs, states, action, action_log_prob, value, reward, masks)
    
        ...
        
        # 根据rollouts提供的信息,使用RL algorithm(i.e., PPO)对actor_critic(i.e., policy)进行参数更新
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
    
        ...

以上代码将main函数的核心部分提了出来。可见,main函数主要有以下核心步骤

  1. 创建并行交互环境envs
  2. 通过actor_critic与环境进行交互;
  3. 将交互信息存入rollouts
  4. 使用agent进行参数更新。

接下来分别展开介绍这四个变量。


envs

main函数中可见,envs变量通过make_envSubprocVecEnvVecNormalize得到,接下来分别介绍这三个函数。

make_env

# PPO_ROOT/envs.py
def make_env(env_id, seed, rank, log_dir, add_timestep):
    def _thunk():
    	# 最基本的gym env
        if env_id.startswith("dm"):
        	...
        else:
            env = gym.make(env_id)
        
        ...
        # 加了一个Monitor wrapper(注意这个Monitor不是保存视频的gym.Wrappers.Monitor)来记录信息(默认是/tmp/gym/)
        if log_dir is not None:
            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))
        ...
        return env
    # 返回的是一个函数而不是env
    return _thunk

从上述代码可以看出,make_env函数返回了一个函数_thunk,调用该函数可以得到一个gym env。(make_env选择返回一个函数而不直接返回一个gym env的原因我不太清楚,我推测是跟序列化和反序列化有关。从下面的SubprocVecEnv可见,make_env的返回结果要传给一个在另一个进程运行的函数worker,这个过程应该是需要做序列化和反序列化的,可能传函数比较好实现或者比较高效?)

SubprocVecEnv

# BASELINE_ROOT/common/vec_env/subproc_vec_env.py
def worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    # 实例化env,相当于_thunk()(之所以需要.x()是因为用了CloudpickleWrapper,估计是为了处理函数序列化和反序列化的一些问题)
    env = env_fn_wrapper.x()
    # 相当于一个一直在跑的env服务
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'render':
            remote.send(env.render(mode='rgb_array'))
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        else:
            raise NotImplementedError

SubprocVecEnv首先创建不同的subprocess去运行worker,每一个worker基本上相当于一个env服务,每收到一个指令就对env执行相应的操作并且返回信息。worker的参数env_fn_wrapper就是_thunk加了一个CloudpickleWrapper,该wrapper好像是用来辅助序列化和反序列化的。

# BASELINE_ROOT/common/vec_env/subproc_vec_env.py
class SubprocVecEnv(VecEnv):
    ...
    # 给subprocess发送step指令
    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True
    # 接收并拼接subprocess返回的交互信息
    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos
    ...
    # 这个函数在基类VecEnv中
    def step(self, actions):
        self.step_async(actions)
        return self.step_wait()

有了上面的subprocee,主进程中的SubprocVecEnv只需要重载stepreset等关键函数为发送相应指令给subprocess,然后将数据拼接起来,就实现了跟环境的并行交互。

VecNormalize

# BASELINE_ROOT/common/vec_env/vec_normalize.py
class VecNormalize(VecEnvWrapper):
    def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        VecEnvWrapper.__init__(self, venv)
        # mean, std计算器
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=()) if ret else None
    ...
    def step_wait(self):
        obs, rews, news, infos = self.venv.step_wait()
        self.ret = self.ret * self.gamma + rews
        # 对obs做normalization(mean, std),具体代码见下面
        obs = self._obfilt(obs)
        # 对reward做normalization(不用mean只用std,而且用的是return的std)
        if self.ret_rms:
            self.ret_rms.update(self.ret)
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        return obs, rews, news, infos
    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs
    ...

从上述代码可见,VecNormalize也是一个wrapper,主要功能是对observation和reward做normalization


actor_critic

# PPO_ROOT/model.py
class Policy(nn.Module):
    ...
    def act(self, inputs, states, masks, deterministic=False):
    	# self.base集actor和critic于一身
    	# value:V(o_t)
    	# actor_features:actor网络的倒数第二层
    	# states:RNN的hidden states
        value, actor_features, states = self.base(inputs, states, masks)
        # dist最后的action分布\pi(A_t | o_t)
        dist = self.dist(actor_features)
       	...
    ...
    def evaluate_actions(self, inputs, states, masks, action):
        ...
        # value: V(o_t)
        # action_log_probs: \log(\pi(a_t | o_t))
        # dist_entropy: Entropy(\pi(A_t | o_t))
        # states:RNN的hidden states
        return value, action_log_probs, dist_entropy, states

表示parametric policy的actor_critic变量对应的类为Policy,基本上就是包含actor和critic的神经网络。如上所示,其主要函数为actevaluate_actions,前者用于求action,后者用于从网络中获取参数更新要用到的信息(e.g., V(ot),logπ(at|ot)V(ot),log⁡π(at|ot))。


rollouts

# PPO_ROOT/storage.py
class RolloutStorage(object):
    ...
    # 保存交互信息
    # current_obs: o_{t+1}
    # state: RNN中的hidden state,
    # action: a_t
    # action_log_prob: \pi(a_t | o_t)
    # value_pred: V(o_t)
    # reward: r_t
    # mask: 标识o_{t+1}时刻是否通过gym.reset()得到,是的话为0,否的话为1
    def insert(self, current_obs, state, action, action_log_prob, value_pred, reward, mask):
        ...
    ...
    # 计算return
    # 可以选择使用GAE(GAE+V)还是MC来计算return
    def compute_returns(self, next_value, use_gae, gamma, tau):
        if use_gae:
            ...
        else:
            ...
    # 将num_steps * num_processes份transitions随机分成num_mini_batch个mini_batch,每次返回一个mini_batch
    def feed_forward_generator(self, advantages, num_mini_batch):
        ...
    # 考虑到RNN的顺序有关性质,以process为单位将num_steps * num_processes份transitions随机分成num_mini_batch个mini_batch,每次返回一个mini_batch
    def recurrent_generator(self, advantages, num_mini_batch):
        ...

用于储存trajectory的rollouts变量对应的类为RolloutStorage,主要功能是存和取,其对应的成员函数分别是insertxxx_generator。还有一个compute_returns函数用于辅助计算更新参数所需的return。


agent

# PPO_ROOT/algo/ppo.py
class PPO(object):
    ...
    def update(self, rollouts):
    	# \hat{A)_t in PPO paper(return的计算可能用的GAE,也可能用的MC)
        # 并且对advantage做normalization
        advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
        advantages = (advantages - advantages.mean()) / (
            advantages.std() + 1e-5)
        ...
        for e in range(self.ppo_epoch):
            ...
            for sample in data_generator:
                ...
                # r_t(\theta) in the PPO paper
                ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
                surr1 = ratio * adv_targ
                surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                           1.0 + self.clip_param) * adv_targ
                # -L_t^{CLIP}(\theta) in the PPO paper
                action_loss = -torch.min(surr1, surr2).mean()
               	# L_t^{VF}(\theta) in the PPO paper
                value_loss = F.mse_loss(return_batch, values)
                self.optimizer.zero_grad()
                # -L_t^{CLIP+VF+S}(\theta) in the PPO paper(因为我们想做minization,而lr是正的,因此要加负号)
                (value_loss * self.value_loss_coef + action_loss -
                 dist_entropy * self.entropy_coef).backward()
                nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
                                         self.max_grad_norm)
                self.optimizer.step()
                ...
        ...

表示RL algorithm的变量agent对应的类是PPO,由上可见,就是利用rollouts中的信息对actor_critic进行参数更新,更新的方式参照PPO的论文。

Summary

至此,这份PPO代码的核心部分已经描述完毕。可见,代码的实现跟论文所说的还是有很多细节上的差异的,比如说对observation和reward的normalization,对return的计算,对advantage的normalization以及对gradient的clip等等。

Dynadot Hostwinds
赞(11)
版权声明:本文采用知识共享 署名4.0国际许可协议 [BY-NC-SA] 进行授权
文章名称:《PyTorch PPO 源码解读 (pytorch-a2c-ppo-acktr-gail)
文章链接:https://oldtang.com/6768.html
本站资源仅供个人学习交流,请于下载后24小时内删除,不允许用于商业用途,否则法律问题自行承担。