前面两节讲完了critic、actor以及缓冲区的设计,下面就到了actor和critic的损失函数的环节了。对于神经网络来说,最重要的就是计算损失函数进行反向传播更新参数了。在计算损失函数之前,需要有batch的数据,所以上一节也把缓冲区的设计完成。这节完成损失函数的设计和整个PPO算法的更新。

价值函数的损失函数

价值函数的损失计算较为简单,我们一般使用当前obs的价值函数与reward-to-go的值进行差值平方取平均运算,即

    def compute_loss_v(data):
        obs, ret = data['obs'], data['ret']
        ppo_logger.log("obs={},ret={},loss_v={}".format(obs,ret,((ac.v(obs) - ret) ** 2).mean()))
        return ((ac.v(obs) - ret) ** 2).mean()
`

策略函数的损失函数

PPO的策略函数对应有两种形式:
(1)PPO-Penalty,对应的损失函数为
在这里插入图片描述

其在损失函数中使用KL散度来作为损失函数中的惩罚项,而不是像CLIP形式那样进行严格的限制,这种惩罚形式类似TRPO的更新。
(2)PPO-Clip,其损失函数为
在这里插入图片描述
它使用一个截断系数,来限制pi与pi old不会差别过大,因为pi的过大更新会导致难以收敛,同时,PPO算法为了提高数据的利用率,在新旧策略的更新上使用了相同的数据,虽然提高了样本利用率,但是由重要性采样可知,如果想对多个策略使用同样的数据进行参数更新,那么多个策略必须是相同的策略分布区间,所以PPO的clip达到了限制新旧策略差别过大的影响,即减小了由于策略分布不同而导致的更新错误。

在下面的实现中,我们使用Clip形式。这也是大多数算法库采用的形式。

ppo的策略函数的损失函数相对较为复杂,我们直接根据PPO的论文提供的损失函数来编写代码。
在这里插入图片描述

虽然PPO的原文使用的是下面这个较为复杂的公式
在这里插入图片描述
即包含clip的损失,价值函数的损失和熵的损失,但是我们可以直接使用CLIP的损失即可,在达到近似的性能下计算开销更小。

在编写代码的时候,我们将CLIP 的公式稍微做一下简化,
在这里插入图片描述

在代码实现中,我们同时实现了KL散度,但是不是作为损失函数的计算值,而是辅助更新作用,当当前的KL值大于1.5倍的target kl时,直接停止更新当前策略的参数,这对于限制策略的更新幅度同样起了作用。

 def compute_loss_pi(data):
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']

        # Policy loss
        pi, logp = ac.pi(obs, act)

        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac)

        return loss_pi, pi_info

下面,就只剩下PPO算法的更新过程了。

PPO更新策略参数

我们将根据这个算法流程来编写程序
在这里插入图片描述

   def update():
        data = buf.get()

        pi_l_old, pi_info_old = compute_loss_pi(data)
        #ppo_logger.log("pi_l_old={},pi_info_old={}".format(pi_l_old,pi_info_old))

        pi_l_old = pi_l_old.item()
        v_l_old = compute_loss_v(data).item()
        #ppo_logger.log("pi_l_old={},v_l_old={}".format(pi_l_old,v_l_old))

        # Train policy with multiple steps of gradient descent
        for i in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, pi_info = compute_loss_pi(data)
            #ppo_logger.log("loss_pi={},pi_info={}".format(loss_pi,pi_info))

            kl = mpi_avg(pi_info['kl'])
            if kl > 1.5 * target_kl:
                logger.log('Early stopping at step %d due to reaching max kl.' % i)
                break
            loss_pi.backward()
            mpi_avg_grads(ac.pi)  # average grads across MPI processes
            pi_optimizer.step()

        logger.store(StopIter=i)

        # Value function learning
        for i in range(train_v_iters):
            vf_optimizer.zero_grad()
            loss_v = compute_loss_v(data)
            ppo_logger.log("loss_v={}".format(loss_v))
            loss_v.backward()
            mpi_avg_grads(ac.v)  # average grads across MPI processes
            vf_optimizer.step()

        # Log changes from update
        kl, ent, cf = pi_info['kl'], pi_info_old['ent'], pi_info['cf']
        logger.store(LossPi=pi_l_old, LossV=v_l_old,
                     KL=kl, Entropy=ent, ClipFrac=cf,
                     DeltaLossPi=(loss_pi.item() - pi_l_old),
                     DeltaLossV=(loss_v.item() - v_l_old))

    # Prepare for interaction with environment
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0
   # ppo_logger.log("o={},ep_ret={},ep_len={}".format(o,ep_ret,ep_len))

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for t in range(local_steps_per_epoch):
            a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
          #  ppo_logger.log("a={},v={},logp={}".format(a,v,logp))

           # print('a={}'.format(a))
            next_o, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            # save and log
           # print(Back.RED+'o={},\na={},\nr={},\nv={},\nlogp={}'.format(o,a,r,v,logp))
            buf.store(o, a, r, v, logp)
            logger.store(VVals=v)

            # Update obs (critical!)
            o = next_o

            timeout = ep_len == max_ep_len
            terminal = d or timeout
            epoch_ended = t == local_steps_per_epoch - 1

            if terminal or epoch_ended:
                if epoch_ended and not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' % ep_len, flush=True)
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
                else:
                    v = 0
                buf.finish_path(v)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, ep_ret, ep_len = env.reset(), 0, 0

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, None)

        # Perform PPO update!
        update()

以上就是PPO算法的全部了,全部代码可以在这里找到,这里面还包含了一些其他的相关资料。
在这里插入图片描述

下一次,我们就要从头开始编写环境了。


猜你想看: