
URL: https://deepwiki.com/inclusionAI/AReaL/7.2-ppo-implementation

PPO Implementation | inclusionAI/AReaL | DeepWiki


Last indexed: 7 May 2026 (2e12c1)

PPO Implementation

This page documents the Proximal Policy Optimization (PPO) algorithm implementation in AReaL, including configuration options, loss computation, advantage estimation, and actor-critic coordination. PPO is a policy gradient method that uses a clipped surrogate objective to prevent excessively large policy updates.

Scope: This page covers PPO-specific configuration (PPOActorConfig, PPOCriticConfig) and algorithmic components. For general trainer orchestration, see 7.4 Trainer Orchestration. For asynchronous off-policy training features, see 7.5 Asynchronous Training. For group-based variants like GRPO, see 7.3 GRPO and Variants.

PPO Overview in AReaL

PPO in AReaL is implemented as a configuration of the general RL training system. The key distinction between PPO and GRPO is that PPO uses a learned critic model (value function) for advantage estimation, while GRPO uses Monte Carlo returns or group-based baselines.
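The difference between the two baselines can be illustrated with a minimal sketch. The function names and the sample numbers below are hypothetical and are not taken from the AReaL codebase; the sketch only contrasts subtracting a critic's value estimate with normalizing rewards inside a sampling group.

```python
from statistics import mean, pstdev


def critic_advantages(returns, values):
    """PPO-style baseline: subtract a learned value estimate from each return."""
    return [g - v for g, v in zip(returns, values)]


def group_advantages(rewards, eps=1e-6):
    """GRPO-style baseline: normalize each reward against its own sampling group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four responses sampled for the same prompt, each with a scalar reward.
rewards = [1.0, 0.0, 0.0, 1.0]
# Hypothetical critic predictions for the same four responses.
values = [0.5, 0.5, 0.25, 0.75]

ppo_adv = critic_advantages(rewards, values)
grpo_adv = group_advantages(rewards)
```

The critic-based variant needs a second trained model but can give a different baseline for every response, while the group-based variant needs only the rewards of the sampled group.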

The PPO implementation is split into two primary classes:

  1. PPOActor: Manages policy log-probability computation, advantage estimation, and the policy gradient update areal/trainer/ppo/actor.py43-68
  2. PPOCritic: Manages value function predictions and the value loss update areal/trainer/ppo/critic.py25-40

The PPOTrainer orchestrates these components, initializing the actor, critic, and reference models based on the provided PPOConfig. It handles the logic for offloading models (ref, teacher, critic) to save VRAM and manages the transition between rollout and training phases areal/trainer/rl_trainer.py105-189

Sources: areal/trainer/ppo/actor.py43-68 areal/trainer/ppo/critic.py25-40 areal/trainer/rl_trainer.py105-189

System Architecture and Data Flow

The following diagram bridges the mathematical PPO concepts to the specific code entities in the AReaL codebase.

Diagram: PPO Logic to Code Mapping


Sources: areal/trainer/ppo/actor.py43-162 areal/trainer/ppo/critic.py25-149 areal/trainer/ppo/stats.py10-38

Configuration Dataclasses

PPOActorConfig

The PPOActorConfig class adds PPO-specific parameters for the policy (actor) model. During initialization, the PPOActor logs its configuration to distinguish between standard on-policy PPO and decoupled off-policy PPO areal/trainer/ppo/actor.py71-124

Key parameters and their roles in PPOActor:

| Parameter | Code Usage | Purpose |
|---|---|---|
| recompute_logprob | _compute_advantages | If true, overwrites rollout logprobs with fresh ones areal/trainer/ppo/actor.py178-183 |
| use_decoupled_loss | _compute_advantages | Switches to off-policy logic using behavior policies areal/trainer/ppo/actor.py184-188 |
| kl_ctl | __init__ | Coefficient for KL penalty areal/trainer/ppo/actor.py52 |
| adv_norm | __init__ | Normalization instance for advantages areal/trainer/ppo/actor.py55 |
| prox_logp_method | _log_configuration | Method to compute proximal policy (π_prox) areal/trainer/ppo/actor.py93-101 |

Sources: areal/trainer/ppo/actor.py44-124 areal/api/cli_args.py27

PPOCriticConfig

The PPOCriticConfig configures the value function model. The PPOCritic uses this to manage the value loss clipping and normalization areal/trainer/ppo/critic.py26-28

| Parameter | Default | Description |
|---|---|---|
| ppo_n_minibatches | 4 | Number of minibatches for the critic update areal/trainer/ppo/critic.py63 |
| eps_clip | 0.2 | Clipping factor for the value loss areal/trainer/ppo/critic.py108 |
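To make the role of ppo_n_minibatches concrete, the sketch below splits a batch into that many contiguous chunks. The helper split_minibatches is hypothetical; AReaL's actual splitting logic may differ (for example, it could balance by token count rather than by sample count).

```python
def split_minibatches(items, n_minibatches):
    """Split items into n contiguous chunks whose sizes differ by at most one."""
    if n_minibatches <= 0:
        raise ValueError("n_minibatches must be positive")
    base, extra = divmod(len(items), n_minibatches)
    chunks, start = [], 0
    for i in range(n_minibatches):
        # The first `extra` chunks each take one additional item.
        size = base + (1 if i < extra else 0)
        chunks.append(items[start:start + size])
        start += size
    return chunks


# Ten samples with the documented default of four minibatches.
batch = list(range(10))
minibatches = split_minibatches(batch, n_minibatches=4)
```

Each chunk would then drive one optimizer step of the critic update.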

Sources: areal/trainer/ppo/critic.py25-149 areal/api/cli_args.py29

Loss Computation Details

Actor Loss (Policy Gradient)

The actor loss is computed in ppo_actor_loss_fn (called via the training engine). It handles:

  1. Ratio Calculation: $r_t(\theta) = \exp(\log \pi_\theta - \log \pi_{prox})$.
  2. Clipping: Standard PPO clipping using eps_clip.
  3. KL Penalty: Using the configured KLEstimator areal/trainer/ppo/actor.py52-53
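The first two steps can be sketched in plain Python. This is a minimal illustration of the standard clipped surrogate, not the body of ppo_actor_loss_fn: the function names, the list-based inputs, and the eps_clip value of 0.2 are assumptions, and the KL penalty from step 3 is omitted.

```python
import math


def ppo_token_loss(logp, prox_logp, advantage, eps_clip):
    """Clipped surrogate loss for a single token (lower is better)."""
    ratio = math.exp(logp - prox_logp)
    clipped_ratio = max(1.0 - eps_clip, min(1.0 + eps_clip, ratio))
    # Take the pessimistic (smaller) objective, then negate it to get a loss.
    return -min(ratio * advantage, clipped_ratio * advantage)


def ppo_actor_loss(logps, prox_logps, advantages, loss_mask, eps_clip=0.2):
    """Average the per-token loss over tokens where the mask is set."""
    losses = [
        ppo_token_loss(lp, pl, adv, eps_clip)
        for lp, pl, adv, m in zip(logps, prox_logps, advantages, loss_mask)
        if m
    ]
    return sum(losses) / max(len(losses), 1)


# Three tokens with probability ratios of 1.5, 0.5, and 1.0.
logps = [math.log(0.6), math.log(0.2), math.log(0.5)]
prox_logps = [math.log(0.4), math.log(0.4), math.log(0.5)]
advantages = [1.0, -1.0, 0.5]
loss_mask = [1, 1, 1]

loss = ppo_actor_loss(logps, prox_logps, advantages, loss_mask)
```

In the first token the ratio of 1.5 is clipped to 1.2 because the advantage is positive; in the second the ratio of 0.5 is clipped to 0.8 because the advantage is negative. In both cases the clip removes the incentive to move the policy further in the direction it has already moved.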

Proximal Policy Approximation: In off-policy scenarios, AReaL can approximate the proximal policy log-probs instead of recomputing them, for example by log-linear interpolation between the behavior and current policy versions. The selected method is logged during initialization in _log_configuration areal/trainer/ppo/actor.py92-101. Supported methods include PROX_LOGP_METHOD_RECOMPUTE, PROX_LOGP_METHOD_LOGLINEAR, and PROX_LOGP_METHOD_METRICS areal/utils/constants.py22-24
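A log-linear interpolation of this kind can be sketched as follows. Both helpers are hypothetical, and the version-based weighting rule in particular is an assumption made for illustration rather than a description of what PROX_LOGP_METHOD_LOGLINEAR does internally.

```python
def version_alpha(behave_version, prox_version, current_version):
    """Assumed weighting rule: how far the proximal version sits between the two."""
    span = current_version - behave_version
    if span <= 0:
        # Behavior and current policy coincide, so the weight is irrelevant.
        return 0.0
    return (prox_version - behave_version) / span


def loglinear_prox_logp(behave_logp, current_logp, alpha):
    """Interpolate per-token log-probs linearly in log space."""
    return [b + alpha * (c - b) for b, c in zip(behave_logp, current_logp)]


behave_logp = [-2.0, -1.0]
current_logp = [-1.0, -3.0]

alpha = version_alpha(behave_version=3, prox_version=4, current_version=5)
prox_logp = loglinear_prox_logp(behave_logp, current_logp, alpha)
```

With alpha at 0 the approximation returns the behavior log-probs and with alpha at 1 it returns the current ones, so no extra forward pass is needed to obtain an intermediate estimate.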

Critic Loss (Value Function)

The critic loss is computed in ppo_critic_loss_fn areal/trainer/ppo/critic.py117-123. It supports clipped value updates, which prevent the value function from changing too rapidly between updates.
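A common form of value clipping, as used in many PPO implementations, is sketched below. It is an assumption that ppo_critic_loss_fn follows this exact form; the real function may use a different base loss (for example Huber) or scaling. The function name and arguments are illustrative.

```python
def clipped_value_loss(value, old_value, target, eps_clip=0.2):
    """Squared-error value loss with the new prediction kept near the old one."""
    # Limit how far the new prediction may move from the rollout-time prediction.
    delta = max(-eps_clip, min(eps_clip, value - old_value))
    clipped_value = old_value + delta
    loss_unclipped = (value - target) ** 2
    loss_clipped = (clipped_value - target) ** 2
    # Take the larger of the two so the clip can only increase the loss.
    return 0.5 * max(loss_unclipped, loss_clipped)


# The new prediction jumped from 0.0 to the target 1.0, beyond the clip range.
large_step = clipped_value_loss(value=1.0, old_value=0.0, target=1.0)
# The new prediction moved only 0.1, so clipping has no effect.
small_step = clipped_value_loss(value=0.1, old_value=0.0, target=1.0)
```

In the first case the unclipped error is zero, yet the loss stays positive because the clipped prediction of 0.2 is still far from the target. The critic therefore gets no credit for a jump larger than eps_clip in a single update.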


Sources: areal/trainer/ppo/critic.py105-149 areal/trainer/ppo/actor.py52-53 areal/utils/constants.py22-24

Advantage Estimation and Reward Processing

The PPOActor._compute_advantages method performs the following sequence areal/trainer/ppo/actor.py145-220:

  1. Overlong Penalty: Applies penalties if the response exceeds overlong_tokens areal/trainer/ppo/actor.py153-164
  2. Reward Transformation: Applies reward_bias, reward_scaling, and reward_clip areal/trainer/ppo/actor.py167-171
  3. KL Reward Integration: Combines the environment reward with the KL penalty. If ref_logp is missing, it defaults to zero areal/trainer/ppo/actor.py189-192
  4. GAE Calculation: Uses discount and gae_lambda to compute advantages and returns areal/trainer/ppo/actor.py208-220
  5. Normalization: If adv_norm is enabled, advantages are normalized across the batch or group using the Normalization utility areal/trainer/ppo/actor.py55
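Steps 2 through 5 can be sketched end to end in plain Python. This is an illustration under stated assumptions, not the body of _compute_advantages: the helper names, the order of bias and scaling, the placement of the sequence reward on the final token, and all numeric values are hypothetical. The overlong penalty from step 1 is omitted.

```python
from statistics import mean, pstdev


def shape_reward(reward, reward_bias, reward_scaling, reward_clip):
    """Shift, scale, then clip a scalar sequence reward (assumed order)."""
    shaped = (reward + reward_bias) * reward_scaling
    return max(-reward_clip, min(reward_clip, shaped))


def token_rewards(seq_reward, logp, ref_logp, kl_ctl):
    """Per-token KL penalty, with the sequence reward added on the last token."""
    rewards = [-kl_ctl * (lp - rp) for lp, rp in zip(logp, ref_logp)]
    rewards[-1] += seq_reward
    return rewards


def gae(rewards, values, discount, gae_lambda):
    """Generalized Advantage Estimation over one finished sequence."""
    advantages = [0.0] * len(rewards)
    last = 0.0
    for t in reversed(range(len(rewards))):
        # The value after the final token is taken to be zero.
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + discount * next_value - values[t]
        last = delta + discount * gae_lambda * last
        advantages[t] = last
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def normalize(advantages, eps=1e-5):
    """Zero-mean, unit-variance normalization across the given advantages."""
    mu, sigma = mean(advantages), pstdev(advantages)
    return [(a - mu) / (sigma + eps) for a in advantages]


seq_reward = shape_reward(10.0, reward_bias=0.0, reward_scaling=1.0, reward_clip=2.0)
logp = [-1.0, -1.0, -1.0]
ref_logp = [-1.0, -1.0, -1.0]  # identical to logp, so the KL term vanishes
rewards = token_rewards(seq_reward, logp, ref_logp, kl_ctl=0.1)
values = [0.5, 1.0, 1.5]  # hypothetical critic predictions

advantages, returns = gae(rewards, values, discount=1.0, gae_lambda=1.0)
normalized = normalize(advantages)
```

With discount and gae_lambda both at 1, each advantage reduces to the remaining reward minus the critic's prediction at that token, which is why every return here equals the clipped sequence reward.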

Sources: areal/trainer/ppo/actor.py145-220 areal/api/cli_args.py27

On-Policy Distillation (KDRL)

AReaL supports on-policy knowledge distillation combined with reinforcement learning (KDRL) docs/en/algorithms/distillation.md1-13. In this framework, a teacher model provides guidance on trajectories sampled from the student's own policy.

Joint Loss Strategy

When a teacher is configured, the trainer uses a Joint Loss strategy:

$$J_{KDRL}(\theta) = J_{GRPO/PPO}(\theta) - \beta D_{KL}(\pi_\theta \parallel \pi_T)$$

This is implemented in areal/trainer/ppo/actor.py by calculating the reward as teacher_logp - logprobs and applying a negative coefficient to the loss docs/en/algorithms/distillation.md52-78

Teacher Configuration: The teacher is configured with specific loss weights for RL and distillation examples/distillation/gsm8k_grpo_distill.yaml90-103:

  • rl_loss_weight: Weight for the standard RL objective.
  • distill_loss_weight: Weight for the Reverse KL distillation penalty.

Sources: docs/en/algorithms/distillation.md1-78 examples/distillation/gsm8k_grpo_distill.yaml90-103 areal/trainer/rl_trainer.py183-189

Statistics and Monitoring

AReaL uses a dedicated statistics tracking system to monitor PPO training. A critical component is infer_token_denominator, which keeps token counts consistent even under Context Parallelism by prioritizing metadata such as attention_mask or cu_seqlens over raw tensor shapes areal/trainer/ppo/stats.py10-20
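The prioritization idea can be sketched with plain lists standing in for tensors. The function sketch_token_denominator is hypothetical and does not reproduce the signature of infer_token_denominator; the lookup order between attention_mask and cu_seqlens and the input_ids fallback key are assumptions.

```python
def sketch_token_denominator(batch):
    """Count real tokens, preferring metadata over raw array shape."""
    if "attention_mask" in batch:
        # Padded layout: count positions marked as real tokens.
        return sum(sum(row) for row in batch["attention_mask"])
    if "cu_seqlens" in batch:
        # Packed layout: the last cumulative length is the total token count.
        return batch["cu_seqlens"][-1]
    # Fallback: raw shape, which also counts any padding positions.
    return sum(len(row) for row in batch["input_ids"])


padded = {
    "input_ids": [[11, 12, 13, 0], [21, 22, 0, 0]],
    "attention_mask": [[1, 1, 1, 0], [1, 1, 0, 0]],
}
packed = {"input_ids": [[11, 12, 13, 21, 22]], "cu_seqlens": [0, 3, 5]}
shape_only = {"input_ids": [[11, 12, 13, 0], [21, 22, 0, 0]]}

counts = [sketch_token_denominator(b) for b in (padded, packed, shape_only)]
```

The padded and packed layouts agree on five tokens, while the shape-only fallback reports eight because it cannot tell padding from real tokens. Relying on a local tensor shape can misstate the count in a similar way when a rank holds only part of each sequence.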

Diagram: Statistics Data Flow


Sources: areal/trainer/ppo/stats.py10-38 areal/trainer/ppo/critic.py126-148 tests/test_ppo_stats.py11-52

Integration with Training Engines

Both PPOActor and PPOCritic interact with a TrainEngine (e.g., FSDP, Archon).

Sources: areal/trainer/ppo/actor.py33-37 areal/trainer/ppo/critic.py60-73 areal/trainer/ppo/critic.py17-21