URL: https://deepwiki.com/inclusionAI/AReaL/7.7-advanced-algorithms

Advanced Algorithms | inclusionAI/AReaL | DeepWiki


Last indexed: 7 May 2026 (2e12c1)

Advanced Algorithms

This page details advanced algorithmic techniques implemented in AReaL: M2PO for stable off-policy training, on-policy distillation (KDRL), proximal log-probability approximations, the SAPO loss, advantage normalization strategies, and token-denominator inference for sharded training. These techniques extend standard RL algorithms such as PPO and GRPO to improve stability, efficiency, and knowledge transfer.

Second-Moment Trust Policy Optimization (M2PO)

M2PO is an RL method designed to achieve stable off-policy training even when data is stale by several hundred model updates docs/en/algorithms/m2po.md9-13. It matches on-policy performance by constraining the second moment of the log importance weights, which suppresses extreme outliers while preserving informative updates docs/en/algorithms/m2po.md11-13.

Implementation Logic

  1. Compute Second Moment: Calculate $\hat{M}_2 = \frac{1}{N}\sum_{i=1}^{N} \left(\log \frac{\pi_\theta}{\pi_{behav}}\right)^2$ over $N$ tokens docs/en/algorithms/m2po.md13-17
  2. Masking: Generate a mask $M$ based on the second-moment threshold docs/en/algorithms/m2po.md19-21
  3. Optimization: Apply the mask to the objective: $J_{\text{M2PO}}(\theta) = \mathbb{E}[M \cdot \frac{\pi_\theta}{\pi_{old}} \cdot A]$ docs/en/algorithms/m2po.md24-29

The core parameter for this algorithm is actor.m2_threshold areal/trainer/ppo/actor.py66 docs/en/algorithms/m2po.md43-44
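
The three steps can be sketched in plain Python as below. This is a minimal illustration under stated assumptions, not AReaL's implementation: the function names are hypothetical, log-probabilities are per-token floats, and the masking rule (greedily dropping the tokens with the largest squared log-ratio until the second moment of the remaining tokens is at most the threshold, which plays the role of actor.m2_threshold) is an assumption about how the mask in step 2 is built.

```python
import math
from typing import List


def m2po_mask(logp_current: List[float], logp_behav: List[float], m2_threshold: float) -> List[int]:
    """Hypothetical M2PO-style token mask (1 = keep, 0 = drop).

    Assumption: drop tokens with the largest squared log-ratio, one at a time,
    until the second moment of the kept tokens is <= m2_threshold.
    """
    sq = [(c - b) ** 2 for c, b in zip(logp_current, logp_behav)]
    mask = [1] * len(sq)
    kept_sum, kept = sum(sq), len(sq)
    # Visit tokens from the most extreme squared log-ratio downward.
    for i in sorted(range(len(sq)), key=lambda j: sq[j], reverse=True):
        if kept_sum / kept <= m2_threshold:
            break
        mask[i] = 0
        kept_sum -= sq[i]
        kept -= 1
    return mask


def m2po_objective(logp_current, logp_old, advantages, mask) -> float:
    """Masked surrogate: token average of M * (pi_theta / pi_old) * A."""
    terms = [m * math.exp(c - o) * a for c, o, a, m in zip(logp_current, logp_old, advantages, mask)]
    return sum(terms) / max(len(terms), 1)


logp_behav = [0.0, 0.1, -0.1, 2.0]  # the last token is a stale outlier
logp_cur = [0.0, 0.0, 0.0, 0.0]
mask = m2po_mask(logp_cur, logp_behav, m2_threshold=0.04)
print(mask)  # [1, 1, 1, 0]
```

Only the outlier is dropped: once it is masked, the second moment of the remaining tokens falls well under the threshold, so the moderately off-policy tokens still contribute gradient.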

M2PO Data Flow

[Diagram: M2PO Algorithm Integration, showing how the M2PO mask is integrated into the actor loss calculation]

Sources: areal/trainer/ppo/actor.py43-66 docs/en/algorithms/m2po.md13-30 docs/en/algorithms/m2po.md43-44


On-Policy Distillation (KDRL)

On-policy distillation trains a student policy to mimic a teacher using trajectories sampled from the student's own policy docs/en/algorithms/distillation.md5-8. This reduces distribution mismatch (exposure bias) compared with standard SFT docs/en/algorithms/distillation.md30-34.

Core Mechanisms

  • Reverse KL (RKL): The student minimizes $D_{KL}(\pi_\theta \parallel \pi_T)$. Minimizing it is equivalent to REINFORCE with the per-token reward set to the log-ratio of teacher to student probabilities, $R = \log \pi_T - \log \pi_\theta$ docs/en/algorithms/distillation.md36-48
  • Joint Loss (KDRL): Combines GRPO/PPO objective with an auxiliary KL penalty: $J_{KDRL} = J_{RL} - \beta D_{KL}(\pi_\theta \parallel \pi_T)$ docs/en/algorithms/distillation.md61-66
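
The RKL/REINFORCE equivalence can be checked numerically on a toy categorical distribution. The sketch below uses hypothetical helper names and is not AReaL code: it computes the exact REINFORCE expectation with reward $R = \log \pi_T - \log \pi_\theta$ and compares it with a finite-difference gradient of $D_{KL}(\pi_\theta \parallel \pi_T)$ with respect to the student logits. The two are negatives of each other, because the extra term from the reward depending on $\theta$ has zero expectation.

```python
import math
from typing import Callable, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def reverse_kl(student_logits: List[float], teacher_probs: List[float]) -> float:
    """D_KL(pi_theta || pi_T) for a single categorical distribution."""
    p = softmax(student_logits)
    return sum(pa * (math.log(pa) - math.log(qa)) for pa, qa in zip(p, teacher_probs))


def reinforce_grad(student_logits: List[float], teacher_probs: List[float]) -> List[float]:
    """Exact E_{a~pi_theta}[grad log pi_theta(a) * R(a)] with R = log pi_T - log pi_theta."""
    p = softmax(student_logits)
    grad = [0.0] * len(p)
    for a, pa in enumerate(p):
        reward = math.log(teacher_probs[a]) - math.log(pa)
        for k in range(len(p)):
            # d log softmax(logits)[a] / d logits[k] = 1[a == k] - p[k]
            grad[k] += pa * ((1.0 if a == k else 0.0) - p[k]) * reward
    return grad


def finite_diff_grad(f: Callable[[List[float]], float], x: List[float], eps: float = 1e-6) -> List[float]:
    """Central-difference gradient, used only to check the identity numerically."""
    grad = []
    for k in range(len(x)):
        plus = list(x)
        plus[k] += eps
        minus = list(x)
        minus[k] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad


logits = [0.2, -0.5, 1.0]
teacher = [0.5, 0.3, 0.2]
pg = reinforce_grad(logits, teacher)
kl_grad = finite_diff_grad(lambda z: reverse_kl(z, teacher), logits)
# Ascending the REINFORCE gradient is the same as descending the RKL gradient.
print(all(abs(g + k) < 1e-6 for g, k in zip(pg, kl_grad)))  # True
```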

Implementation Details

The PPOActor handles the distillation logic within the loss function. If rl_loss_weight is set to 0, the implementation estimates the RKL gradient using importance sampling docs/en/algorithms/distillation.md52-55. In the joint-loss case (rl_loss_weight > 0), the RKL is treated as a direct penalty: loss = rl_loss_weight * loss + distill_loss_weight * rkl_penalty docs/en/algorithms/distillation.md72-77.
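
The joint-loss branch can be sketched as follows, with scalars standing in for tensors. The helper names are hypothetical, and the penalty is an assumption: it uses the standard Monte-Carlo estimate of $D_{KL}(\pi_\theta \parallel \pi_T)$, the token-averaged $\log \pi_\theta - \log \pi_T$ over student-sampled tokens. The importance-sampling estimator used when rl_loss_weight is 0 is not shown.

```python
from typing import List


def rkl_penalty(logp_student: List[float], logp_teacher: List[float], loss_mask: List[int]) -> float:
    """Token-averaged log(pi_theta / pi_T) over unmasked, student-sampled tokens."""
    diffs = [s - t for s, t, m in zip(logp_student, logp_teacher, loss_mask) if m]
    return sum(diffs) / len(diffs) if diffs else 0.0


def joint_kdrl_loss(
    rl_loss: float,
    logp_student: List[float],
    logp_teacher: List[float],
    loss_mask: List[int],
    rl_loss_weight: float,
    distill_loss_weight: float,
) -> float:
    """loss = rl_loss_weight * loss + distill_loss_weight * rkl_penalty."""
    penalty = rkl_penalty(logp_student, logp_teacher, loss_mask)
    return rl_loss_weight * rl_loss + distill_loss_weight * penalty


loss = joint_kdrl_loss(
    rl_loss=2.0,
    logp_student=[-1.0, -2.0, -0.5],
    logp_teacher=[-1.5, -2.0, -3.0],
    loss_mask=[1, 1, 0],  # the last token is padding and is ignored
    rl_loss_weight=1.0,
    distill_loss_weight=0.5,
)
print(loss)  # 2.125
```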

Configuration Reference

| Parameter | Description |
| --- | --- |
| teacher.path | Model path for the teacher LLM examples/distillation/gsm8k_grpo_distill.yaml96 |
| teacher.rl_loss_weight | Weight for the RL reward loss examples/distillation/gsm8k_grpo_distill.yaml92 |
| teacher.distill_loss_weight | Weight for the RKL distillation penalty examples/distillation/gsm8k_grpo_distill.yaml93 |

Sources: docs/en/algorithms/distillation.md1-78 examples/distillation/gsm8k_grpo_distill.yaml90-103


Proximal Log-Probability Approximation

To handle stale, off-policy data in asynchronous RL, AReaL implements approximations for the log-probability of the "proximal" policy (a policy version between the behavior policy and the current policy). This is critical for maintaining stability when the training model is several versions ahead of the rollout model examples/experimental/prox_approx/gsm8k_grpo_prox_approx.yaml27-82.

The function compute_prox_logp_approximations provides several methods for this areal/trainer/ppo/actor.py10-26 tests/test_prox_approx.py10-18:

| Method | Implementation | Logic |
| --- | --- | --- |
| loglinear | Log-linear interpolation | Interpolates between old (behavior) and current log-probs based on version distance tests/test_prox_approx.py24-44 |
| linear | Weighted arithmetic mean | Interpolates in probability space (p-space) tests/test_prox_approx.py109-127 |
| rollout | Identity | Returns the behavior log-probability unchanged tests/test_prox_approx.py46-63 |
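
Read as formulas, the three methods correspond to the following per-token approximations, where $\pi_{behave}$ is the behavior (rollout) policy, $\pi_\theta$ the current policy, and $\alpha \in [0, 1]$ the interpolation factor defined below. This is our reading of the table rather than a transcription of the AReaL source: loglinear is a weighted geometric mean of the two probabilities, and linear a weighted arithmetic mean.

```latex
\begin{aligned}
\texttt{loglinear}:&\quad \log \pi_{prox} \approx (1-\alpha)\,\log \pi_{behave} + \alpha\,\log \pi_\theta \\
\texttt{linear}:&\quad \log \pi_{prox} \approx \log\!\big((1-\alpha)\,\pi_{behave} + \alpha\,\pi_\theta\big) \\
\texttt{rollout}:&\quad \log \pi_{prox} \approx \log \pi_{behave}
\end{aligned}
```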

Alpha Calculation

The interpolation factor $\alpha$ is the position of the proximal version between the behavior version and the current version, expressed as a fraction of the behavior-to-current gap: $\alpha = \text{clamp}\left(\frac{v_{proximal} - v_{behave}}{v_{current} - v_{behave}}, 0, 1\right)$ tests/test_prox_approx.py39-40 tests/test_prox_approx.py65-82
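
A minimal Python sketch of the alpha calculation and the three approximation methods follows. The function names are hypothetical (this is not compute_prox_logp_approximations itself), log-probabilities are plain per-token floats, and returning 0 when there is no version gap is an assumption.

```python
import math
from typing import List


def prox_alpha(v_behave: int, v_proximal: int, v_current: int) -> float:
    """alpha = clamp((v_proximal - v_behave) / (v_current - v_behave), 0, 1)."""
    gap = v_current - v_behave
    if gap <= 0:
        # Assumption: with no version gap, behavior and current policies coincide.
        return 0.0
    return min(max((v_proximal - v_behave) / gap, 0.0), 1.0)


def approx_prox_logp(
    logp_behave: List[float], logp_current: List[float], alpha: float, method: str
) -> List[float]:
    pairs = list(zip(logp_behave, logp_current))
    if method == "loglinear":
        # Interpolate linearly in log space (weighted geometric mean of probabilities).
        return [(1 - alpha) * b + alpha * c for b, c in pairs]
    if method == "linear":
        # Interpolate in probability space, then map back to log space.
        return [math.log((1 - alpha) * math.exp(b) + alpha * math.exp(c)) for b, c in pairs]
    if method == "rollout":
        # Use the behavior log-probabilities unchanged.
        return [b for b, _ in pairs]
    raise ValueError(f"unknown method: {method}")


alpha = prox_alpha(v_behave=4, v_proximal=7, v_current=8)  # 0.75
behave, current = [math.log(0.2)], [math.log(0.8)]
for method in ("loglinear", "linear", "rollout"):
    print(method, math.exp(approx_prox_logp(behave, current, alpha, method)[0]))
```

For the same alpha, the linear method gives a larger proximal probability than loglinear, since an arithmetic mean is never smaller than the corresponding geometric mean.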

The mode is configured via actor.prox_logp_method areal/trainer/ppo/actor.py92-101, which supports recompute (ground truth), loglinear (recommended), and metrics (computes all approximations for evaluation) examples/experimental/prox_approx/gsm8k_grpo_prox_approx_eval.yaml82.

Sources: areal/trainer/ppo/actor.py80-113 tests/test_prox_approx.py1-171 examples/experimental/prox_approx/gsm8k_grpo_prox_approx.yaml81-82


Soft Adaptive Policy Optimization (SAPO)

SAPO is supported as an alternative loss function that uses adaptive temperature scaling for positive and negative samples areal/utils/functional.py36
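
The gating idea can be sketched per token as follows. This follows our reading of the SAPO paper rather than areal/utils/functional.py: the sigmoid gate $f(r) = \sigma(\tau (r - 1)) \cdot 4/\tau$, with separate temperatures for positive- and negative-advantage tokens, and the parameter names tau_pos and tau_neg are assumptions and may not match AReaL's parametrization.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sapo_token_objective(ratio: float, advantage: float, tau_pos: float, tau_neg: float) -> float:
    """Soft-gated surrogate f(r) * A, with f(r) = sigmoid(tau * (r - 1)) * 4 / tau."""
    # Tokens with positive and negative advantages get separate temperatures.
    tau = tau_pos if advantage > 0 else tau_neg
    return sigmoid(tau * (ratio - 1.0)) * 4.0 / tau * advantage


def sapo_gradient_weight(ratio: float, advantage: float, tau_pos: float, tau_neg: float) -> float:
    """df/dr = 4 * s * (1 - s) with s = sigmoid(tau * (r - 1))."""
    tau = tau_pos if advantage > 0 else tau_neg
    s = sigmoid(tau * (ratio - 1.0))
    return 4.0 * s * (1.0 - s)


print(sapo_gradient_weight(1.0, 1.0, tau_pos=1.0, tau_neg=1.05))  # 1.0
# Off-policy tokens are down-weighted smoothly; the larger temperature decays faster.
print(sapo_gradient_weight(1.5, -1.0, 1.0, 2.0) < sapo_gradient_weight(1.5, 1.0, 1.0, 2.0))  # True
```

Unlike PPO's hard clip, the gradient weight equals 1 at $r = 1$ and decays smoothly as the ratio drifts, so no token's gradient is cut off abruptly.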

Configuration Parameters

SAPO is enabled in the actor config.

Note: SAPO typically requires use_decoupled_loss to be disabled for standard behavior.

Sources: areal/trainer/ppo/actor.py36 areal/utils/functional.py33-37


Advantage Normalization Strategies

AReaL provides a flexible NormConfig for reward and advantage normalization, supporting both batch-level and group-level (for GRPO) statistics tests/test_adv_norm_config.py7-12 examples/experimental/prox_approx/gsm8k_grpo_prox_approx.yaml73-79

Configuration Reference

The NormConfig class handles the following parameters tests/test_adv_norm_config.py15-42:

| Attribute | Description |
| --- | --- |
| mean_level | Scope for mean calculation (batch, group, or null) tests/test_adv_norm_config.py37 |
| std_level | Scope for standard deviation calculation (batch, group, or null) tests/test_adv_norm_config.py38 |
| group_size | Size of each group for group-level normalization tests/test_adv_norm_config.py39 |
| mean_leave1out | Whether to use leave-one-out mean estimation tests/test_adv_norm_config.py40 |
| std_unbiased | Whether to use unbiased variance estimation tests/test_adv_norm_config.py41 |
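
How these fields could combine is sketched below. NormSketch and normalize are hypothetical stand-ins, not the NormConfig API; the eps field, computing the standard deviation on the raw values of each scope, and treating any non-group level as batch-level are assumptions.

```python
import statistics
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class NormSketch:
    """Hypothetical stand-in mirroring the NormConfig attributes."""
    mean_level: Optional[str] = "batch"  # "batch", "group", or None
    std_level: Optional[str] = "batch"   # "batch", "group", or None
    group_size: int = 1
    mean_leave1out: bool = False
    std_unbiased: bool = True
    eps: float = 1e-5  # assumption: not a documented NormConfig field


def _scopes(values: List[float], level: str, group_size: int) -> List[List[float]]:
    # Group level splits the batch into consecutive groups (e.g. GRPO samples per prompt).
    if level == "group":
        return [values[i:i + group_size] for i in range(0, len(values), group_size)]
    return [values]


def normalize(values: List[float], cfg: NormSketch) -> List[float]:
    out = list(values)
    if cfg.mean_level is not None:
        out = []
        for scope in _scopes(values, cfg.mean_level, cfg.group_size):
            n, total = len(scope), sum(scope)
            for x in scope:
                # Leave-one-out: subtract the mean of the *other* members of the scope.
                mean = (total - x) / (n - 1) if cfg.mean_leave1out and n > 1 else total / n
                out.append(x - mean)
    if cfg.std_level is not None:
        raw = _scopes(values, cfg.std_level, cfg.group_size)
        centered = _scopes(out, cfg.std_level, cfg.group_size)
        out = []
        for raw_scope, cen_scope in zip(raw, centered):
            if cfg.std_unbiased and len(raw_scope) > 1:
                std = statistics.stdev(raw_scope)  # divides by n - 1
            else:
                std = statistics.pstdev(raw_scope)  # divides by n
            out.extend(x / (std + cfg.eps) for x in cen_scope)
    return out


rewards = [1.0, 2.0, 3.0, 4.0]  # two GRPO groups of size 2
print(normalize(rewards, NormSketch(mean_level="group", std_level=None, group_size=2)))
# [-0.5, 0.5, -0.5, 0.5]
```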

Implementation: PPOActor & PPOCritic

The PPOActor and PPOCritic classes initialize these normalization modules during construction areal/trainer/ppo/actor.py55-58. The modules are applied to raw rewards and advantages before the loss is computed areal/trainer/ppo/actor.py172-173.

Normalization Logic Flow


Sources: areal/trainer/ppo/actor.py55-58 tests/test_adv_norm_config.py1-188 areal/api/cli_args.py7


Token Denominator Inference

In distributed training, especially with context parallelism, tensors like loss_mask or model outputs may be sliced. AReaL uses infer_token_denominator to ensure consistent statistics across different sharding strategies areal/trainer/ppo/stats.py10-20

The function reconstructs the full logical sequence mask from micro-batch metadata, checking these sources in priority order areal/trainer/ppo/stats.py23-38:

  1. Attention Mask: If present, used to define valid tokens areal/trainer/ppo/stats.py23-25
  2. Cumulative Sequence Lengths: Used for packed sequences areal/trainer/ppo/stats.py27-29
  3. Input IDs: Used as a fallback if the shape matches the target tensor areal/trainer/ppo/stats.py31-36
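
The priority order can be sketched as follows. infer_full_mask is a hypothetical pure-Python stand-in for infer_token_denominator, with a micro-batch modeled as a dict of flat lists. The key names are assumptions about the real data layout; in particular, cu_seqlens is an assumed name for the cumulative sequence lengths.

```python
from typing import Dict, List, Optional


def infer_full_mask(micro_batch: Dict[str, List[int]], target_len: int) -> Optional[List[int]]:
    """Return a 0/1 mask over the full logical sequence, or None if it cannot be inferred."""
    # 1. An attention mask marks the valid tokens directly.
    if "attention_mask" in micro_batch:
        return [1 if v else 0 for v in micro_batch["attention_mask"]]
    # 2. Packed sequences: the last cumulative length is the total token count.
    if "cu_seqlens" in micro_batch:
        cu = micro_batch["cu_seqlens"]
        return [1] * (cu[-1] - cu[0])
    # 3. Fallback: input_ids, but only if its shape matches the target tensor.
    ids = micro_batch.get("input_ids")
    if ids is not None and len(ids) == target_len:
        return [1] * len(ids)
    return None


# A context-parallel rank may hold only part of the sequence, but the
# denominator should count every valid token in the full logical sequence.
batch = {"cu_seqlens": [0, 3, 8], "input_ids": [5, 6, 7, 8]}
mask = infer_full_mask(batch, target_len=4)
print(sum(mask))  # 8
```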

Statistical Aggregation Architecture

[Diagram: Token Statistics and Loss Masking, mapping token-level normalization onto the code entities that track these metrics across distributed processes]

Sources: areal/trainer/ppo/stats.py10-38 tests/test_ppo_stats.py11-138 areal/trainer/ppo/actor.py15-37