
URL: https://deepwiki.com/inclusionAI/AReaL/7.6-reference-model-and-critic



Last indexed: 7 May 2026 (2e12c1)

Reference Model and Critic

This page documents the reference model and critic components used in policy-gradient RL algorithms such as PPO and GRPO. The reference model provides KL regularization to limit policy drift, while the critic (used in PPO) estimates value functions for advantage computation.

Overview

In AReaL's RL implementation, auxiliary models complement the primary actor model to stabilize training:

  • Reference Model: A frozen copy of the policy (usually the initial SFT model), used to compute KL divergence penalties. This prevents the RL policy from collapsing or deviating too far from the base language capabilities.
  • Critic Model: A value function estimator used in PPO to predict expected returns ($V(s)$), enabling Generalized Advantage Estimation (GAE).
  • Teacher Model: An optional, often larger model used for on-policy knowledge distillation (KDRL), providing a target distribution for the student to mimic. See docs/en/algorithms/distillation.md1-14

Both reference and critic models typically share the same transformer backbone architecture as the actor but differ in their heads, training status, and hardware colocation.
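To make the critic's role in Generalized Advantage Estimation concrete, the following is a minimal sketch in plain Python. It is not AReaL's implementation: the function name `gae` and the default `gamma` and `lam` values are assumptions chosen for illustration, and the sketch handles a single sequence with a zero bootstrap value after the last token.

```python
def gae(rewards, values, gamma=1.0, lam=1.0):
    """Generalized Advantage Estimation for one sequence (illustrative sketch).

    rewards[t] is the per-token reward and values[t] is the critic's V(s_t).
    The value after the final token is assumed to be zero.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # TD residual: how much better this step was than the critic predicted.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of future TD residuals.
        running = delta + gamma * lam * running
        advantages[t] = running
    # Returns serve as regression targets for the critic.
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


if __name__ == "__main__":
    adv, ret = gae(rewards=[0.0, 0.0, 1.0], values=[0.5, 0.5, 0.5])
    print(adv, ret)
```

With `gamma = lam = 1`, the advantage at each step reduces to the remaining reward minus the critic's estimate, which is why a well-calibrated critic lowers the variance of the policy gradient.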


Reference Model Implementation

KL Regularization and Estimation

The reference model is used by the PPOActor to compute the KL penalty. The PPOActor initializes a KLEstimator based on the kl_estimator configuration string areal/trainer/ppo/actor.py53

During the advantage computation phase in _compute_advantages, the actor retrieves ref_logp (log-probabilities from the reference model). If ref_logp is missing, it defaults to zeros areal/trainer/ppo/actor.py192-194. The KL divergence is then estimated and used to regularize the reward signal areal/trainer/ppo/actor.py202-204.
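The KL-regularized reward described above can be sketched as follows. This is an illustration under stated assumptions, not the code in areal/trainer/ppo/actor.py: the estimator names "k1", "k2", and "k3" follow the commonly used sample-based KL approximations and may not match the accepted values of kl_estimator, and `kl_coef` and both function names are hypothetical.

```python
import math


def kl_estimate(logp, ref_logp, kind="k1"):
    """Per-token estimate of KL(pi_theta || pi_ref) from one sampled token."""
    log_ratio = logp - ref_logp  # log(pi_theta / pi_ref)
    if kind == "k1":
        return log_ratio  # unbiased, can be negative for a single sample
    if kind == "k2":
        return 0.5 * log_ratio ** 2  # biased, always non-negative
    if kind == "k3":
        # (r - 1) - log(r) with r = pi_ref / pi_theta; unbiased and non-negative
        return math.exp(-log_ratio) - 1.0 + log_ratio
    raise ValueError(f"unknown estimator: {kind}")


def kl_regularized_rewards(logps, ref_logps, final_reward, kl_coef, kind="k1"):
    """Subtract a per-token KL penalty and add the task reward on the last token."""
    if ref_logps is None:
        # Mirrors the documented fallback: missing ref_logp is treated as zeros.
        ref_logps = [0.0] * len(logps)
    rewards = [-kl_coef * kl_estimate(lp, rlp, kind)
               for lp, rlp in zip(logps, ref_logps)]
    rewards[-1] += final_reward
    return rewards


if __name__ == "__main__":
    print(kl_regularized_rewards([-1.0, -2.0], [-1.5, -2.0], 1.0, kl_coef=0.1))
```

Tokens where the policy assigns higher probability than the reference receive a negative reward proportional to the estimated divergence, which pulls the policy back toward the reference.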

Colocation and Scheduling

Reference models can be scheduled using different strategies defined in SchedulingStrategy. While often separate, AReaL supports specific colocation patterns to optimize GPU memory:

  1. Separation: The reference model runs on its own dedicated set of GPUs.
  2. Colocation: The reference model shares GPUs with the actor. This is configured via scheduling_strategy.type: colocation in the YAML config examples/distillation/gsm8k_grpo_distill.yaml116-118. When configured, the system manages the memory footprint so that both models can reside on the same device. The PPOTrainer handles the offloading logic for colocated components so that they do not exceed device memory during training steps areal/trainer/rl_trainer.py136-146

Reference Model Data Flow

The following diagram illustrates how the Reference Model interacts with the Actor during the PPO update cycle.

[Diagram: Reference Model KL Integration]


Sources: areal/trainer/ppo/actor.py128-136 areal/trainer/ppo/actor.py192-204 areal/trainer/rl_trainer.py178-181


Critic Model and Training

Architecture and Initialization

Critics in AReaL are typically value-head models. In the training orchestration, the PPOCritic class wraps a TrainEngine configured for value estimation areal/trainer/ppo/critic.py25-28. The PPOTrainer creates the critic based on the PPOCriticConfig areal/trainer/rl_trainer.py172-177.

Critic Training Loop

The PPOCritic class manages the value function updates:

  • Value Computation: The compute_values method performs a forward pass in eval mode to get baseline estimates areal/trainer/ppo/critic.py32-40
  • PPO Update: The ppo_update method switches the engine to train mode, splits data into micro-batches using MicroBatchSpec, and performs gradient steps areal/trainer/ppo/critic.py44-74
  • Loss Function: The ppo_loss_fn computes the MSE loss between predicted values and targets (returns), applying a clip defined by eps_clip to stabilize training areal/trainer/ppo/critic.py105-149
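The clipped value loss described in the last bullet can be sketched as below. This follows the standard PPO value-clipping recipe and is an assumption about the general shape of ppo_loss_fn, not a copy of it: the 0.5 factor, the element-wise max, and the name `clipped_value_loss` are illustrative choices.

```python
def clipped_value_loss(values, old_values, returns, mask, eps_clip):
    """Masked, clipped MSE value loss over one flat token sequence (sketch).

    values:     current critic predictions
    old_values: predictions recorded before the update (from the eval pass)
    returns:    regression targets
    mask:       1 for tokens that contribute to the loss, 0 otherwise
    """
    total, count = 0.0, 0
    for v, old_v, ret, m in zip(values, old_values, returns, mask):
        if not m:
            continue
        # Keep the new prediction within eps_clip of the old one.
        v_clipped = old_v + max(-eps_clip, min(eps_clip, v - old_v))
        loss_unclipped = (v - ret) ** 2
        loss_clipped = (v_clipped - ret) ** 2
        # Taking the larger term makes the clip pessimistic: a large jump
        # toward the target is not rewarded beyond the clip range.
        total += 0.5 * max(loss_unclipped, loss_clipped)
        count += 1
    return total / max(count, 1)


if __name__ == "__main__":
    print(clipped_value_loss([1.0, 9.0], [0.0, 0.0], [1.0, 0.0], [1, 0], 0.2))
```

In the usage line, the second token is masked out, and the first token's prediction moved from 0.0 to 1.0 in one step, so the clipped branch dominates even though the unclipped error is zero.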

[Diagram: Critic Value Estimation and Update]


Sources: areal/trainer/ppo/critic.py60-74 areal/trainer/ppo/critic.py105-123 areal/trainer/rl_trainer.py172-177


On-Policy Distillation (Teacher Model)

AReaL supports on-policy knowledge distillation with a reverse-KL (RKL) objective, where a student mimics a teacher model on trajectories sampled from the student's own policy docs/en/algorithms/distillation.md30-39

Joint Loss Strategy

When rl_loss_weight > 0 and distill_loss_weight > 0, the system optimizes a joint objective:

$$J_{KDRL}(\theta) = J_{GRPO}(\theta) - \beta D_{KL}(\pi_\theta \parallel \pi_T)$$

The implementation treats RKL as a direct penalty, minimizing logprobs - teacher_logp docs/en/algorithms/distillation.md61-77. This is managed by the PPOActor, which incorporates distill_loss_weight and teacher_logp into the final loss aggregation areal/trainer/ppo/actor.py55-58.
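A minimal sketch of how such a joint loss could be aggregated is shown below. The names rl_loss_weight, distill_loss_weight, logprobs, and teacher_logp come from the text above; the function name `joint_loss`, the token-mean reduction, and the flat-list inputs are assumptions made for illustration rather than the PPOActor code.

```python
def joint_loss(rl_loss, logprobs, teacher_logp, mask,
               rl_loss_weight, distill_loss_weight):
    """Combine an RL loss with a reverse-KL distillation penalty (sketch).

    On tokens sampled from the student, logprobs - teacher_logp is a
    single-sample estimate of KL(student || teacher), so minimizing its
    masked mean pushes the student toward the teacher.
    """
    n_tokens = sum(mask)
    rkl = sum((lp - tlp) * m
              for lp, tlp, m in zip(logprobs, teacher_logp, mask))
    rkl /= max(n_tokens, 1)
    return rl_loss_weight * rl_loss + distill_loss_weight * rkl


if __name__ == "__main__":
    loss = joint_loss(
        rl_loss=0.5,
        logprobs=[-1.0, -2.0],
        teacher_logp=[-1.5, -2.0],
        mask=[1, 1],
        rl_loss_weight=1.0,
        distill_loss_weight=0.5,
    )
    print(loss)
```

Setting rl_loss_weight to zero in this sketch leaves pure distillation, and setting distill_loss_weight to zero leaves the RL loss unchanged.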


Configuration and Resource Patterns

Model Roles and Normalization

The PPOActor and PPOCritic use Normalization objects to stabilize training signals areal/trainer/ppo/actor.py55-58

| Component | Configuration Attribute | Purpose |
|---|---|---|
| Advantage Norm | adv_norm | Stabilizes policy gradients in PPOActor areal/trainer/ppo/actor.py55 |
| Reward Norm | reward_norm | Scales external rewards before KL penalty areal/trainer/ppo/actor.py56-58 |
| KL Estimator | kl_estimator | Logic for estimating $D_{KL}$ between $\pi_{\theta}$ and the reference policy |

Token Denominator Inference

Stats are tracked using infer_token_denominator, which keeps logging consistent across parallelism strategies (such as Context Parallelism) by inspecting attention_mask or cu_seqlens areal/trainer/ppo/stats.py10-38. It is used in both the actor and critic loss functions to normalize losses correctly areal/trainer/ppo/critic.py126-130 tests/test_ppo_stats.py91-113
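The idea of deriving a token count from either field can be sketched as follows. This is a hypothetical stand-in, not the signature or logic of infer_token_denominator: the function name `token_denominator_sketch`, the dict-based batch, and the plain-list inputs are all assumptions.

```python
def token_denominator_sketch(batch):
    """Count real tokens from a padded or a packed batch (illustrative).

    Padded batches carry a 0/1 attention_mask per sequence; packed batches
    carry cu_seqlens, the cumulative sequence boundaries.
    """
    if "attention_mask" in batch:
        return sum(sum(row) for row in batch["attention_mask"])
    if "cu_seqlens" in batch:
        cu = batch["cu_seqlens"]
        return cu[-1] - cu[0]
    raise KeyError("batch needs 'attention_mask' or 'cu_seqlens'")


if __name__ == "__main__":
    padded = {"attention_mask": [[1, 1, 0], [1, 1, 1]]}
    packed = {"cu_seqlens": [0, 2, 5]}
    print(token_denominator_sketch(padded), token_denominator_sketch(packed))
```

Both layouts in the usage example describe the same two sequences (lengths 2 and 3), so a loss divided by either count is normalized identically.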


Advanced Proximal Policy Approximations

AReaL supports "Decoupled PPO" (off-policy), where the behavior policy ($\pi_{behave}$) that generated the samples may differ from the proximal policy ($\pi_{prox}$) used for the PPO clip areal/trainer/ppo/actor.py89-91.
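The decoupled objective can be sketched per token as below. This is a simplified illustration of the general decoupled-PPO idea, not the loss in areal/trainer/ppo/actor.py: the function name, the argument names, and the absence of any cap on the behavior weight are assumptions.

```python
import math


def decoupled_ppo_token_loss(logp, prox_logp, behave_logp, advantage, eps_clip):
    """Per-token decoupled PPO loss (sketch).

    The clip is applied to pi_theta / pi_prox, while a separate importance
    weight pi_prox / pi_behave corrects for the stale behavior policy.
    """
    ratio = math.exp(logp - prox_logp)
    clipped_ratio = max(1.0 - eps_clip, min(1.0 + eps_clip, ratio))
    # Standard pessimistic PPO surrogate, measured against the proximal policy.
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # Off-policy correction; treated as a constant with respect to theta.
    behave_weight = math.exp(prox_logp - behave_logp)
    return -behave_weight * surrogate


if __name__ == "__main__":
    # When all three policies agree, the loss is the negated advantage.
    print(decoupled_ppo_token_loss(-1.0, -1.0, -1.0, advantage=2.0, eps_clip=0.2))
```

When the proximal and behavior log-probabilities are equal, the weight is 1 and the expression reduces to the ordinary clipped PPO loss.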

The system provides several methods to approximate the proximal log-probabilities:

[Diagram: Proximal Log-Probability Computation]


Sources: areal/trainer/ppo/actor.py80-104 areal/utils/constants.py22-25 areal/trainer/ppo/stats.py10-38