
URL: https://deepwiki.com/inclusionAI/AReaL/3.2-fsdpengine



Last indexed: 7 May 2026 (2e12c1)

FSDPEngine

The FSDPEngine is AReaL's training backend built on PyTorch FSDP2. It supports 3D parallelism (DP × SP × TP: data, sequence, and tensor parallelism), tree attention for efficient training on sequences that share prefixes, and weight synchronization with remote inference engines in reinforcement learning workflows.

Scope: This page covers the FSDP2 training engine implementation, including initialization, parallelism strategies, training loops, tree attention support, weight synchronization, memory management, and the related configuration options.

Sources: areal/engine/fsdp_engine.py:1-1249


Architecture Overview

FSDPEngine implements the TrainEngine interface and orchestrates distributed training using PyTorch's FSDP2 (Fully Sharded Data Parallel). The engine manages process groups, model sharding, and optimizer states, while coordinating with remote inference engines.

Class Structure


Sources: areal/engine/fsdp_engine.py:218-222, areal/engine/fsdp_engine.py:140-170, areal/engine/fsdp_utils/parallel.py:15-100


Initialization Flow

Engine initialization proceeds in four stages: process group creation, model loading, FSDP sharding, and optimizer setup.


Process Group Creation

FSDPEngine builds a 3D device mesh spanning the data, sequence, and tensor parallel dimensions. ParallelHelper resolves the size of each mesh dimension from the configured ParallelStrategy.


Sources: areal/engine/fsdp_engine.py:223-253, areal/engine/fsdp_utils/parallel.py:15-100
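To make the mesh layout concrete, the following pure-Python sketch maps a global rank to its (dp, sp, tp) coordinates and lists the ranks that form a process group along each dimension. It assumes a row-major mesh of shape (dp_size, sp_size, tp_size) with tp as the innermost (fastest-varying) dimension; the helpers mesh_coords and process_groups are hypothetical and are not part of ParallelHelper.

```python
from itertools import product
from typing import Dict, List, Tuple


def mesh_coords(rank: int, dp: int, sp: int, tp: int) -> Tuple[int, int, int]:
    """Return the (dp, sp, tp) coordinates of a global rank.

    Assumption: row-major mesh of shape (dp, sp, tp), so tp varies fastest.
    """
    if not 0 <= rank < dp * sp * tp:
        raise ValueError("rank outside the mesh")
    return rank // (sp * tp), (rank // tp) % sp, rank % tp


def process_groups(dp: int, sp: int, tp: int) -> Dict[str, List[List[int]]]:
    """List the rank groups formed along each mesh dimension."""

    def rank_of(d: int, s: int, t: int) -> int:
        return (d * sp + s) * tp + t

    return {
        # Each group holds ranks that differ only in the named coordinate.
        "dp": [[rank_of(d, s, t) for d in range(dp)] for s, t in product(range(sp), range(tp))],
        "sp": [[rank_of(d, s, t) for s in range(sp)] for d, t in product(range(dp), range(tp))],
        "tp": [[rank_of(d, s, t) for t in range(tp)] for d, s in product(range(dp), range(sp))],
    }


groups = process_groups(dp=2, sp=2, tp=2)
print(mesh_coords(5, 2, 2, 2))  # (1, 0, 1)
print(groups["tp"])  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(groups["dp"])  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Placing tp innermost keeps each TP group on consecutive ranks, which usually means the same node and the fastest interconnect. In PyTorch a mesh of this shape can be created with torch.distributed.device_mesh.init_device_mesh("cuda", (dp, sp, tp), mesh_dim_names=("dp", "sp", "tp")); whether AReaL uses exactly this dimension order and these names is an assumption.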

Model Sharding Strategy

The parallelize_model() function applies FSDP2 with multi-dimensional parallelism:

| Parallelism type | Size parameter | Purpose | Source |
| --- | --- | --- | --- |
| Data Parallel (DP) | dp_size | Shards model parameters and optimizer states across DP ranks | areal/engine/fsdp_utils/parallel.py:115-250 |
| Sequence Parallel (SP) | sp_size | Splits the sequence dimension using Ulysses attention | areal/engine/fsdp_engine.py:323-331 |
| Tensor Parallel (TP) | tp_size | Shards weight matrices across TP ranks | areal/engine/fsdp_utils/parallel.py:115-250 |

Sources: areal/engine/fsdp_utils/parallel.py:115-250, areal/engine/fsdp_engine.py:323-331
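Ulysses-style sequence parallelism keeps a contiguous slice of the sequence on each SP rank outside attention, then uses an all-to-all inside attention so that each rank sees the full sequence for a subset of the attention heads. The shape arithmetic can be sketched as follows; this illustrates the general Ulysses scheme rather than code from fsdp_engine.py, and the function name and divisibility checks are assumptions.

```python
from typing import Tuple

Shape = Tuple[int, int, int]  # (tokens, heads, head_dim) held by one SP rank


def ulysses_local_shapes(seq_len: int, num_heads: int, head_dim: int, sp_size: int) -> Tuple[Shape, Shape]:
    """Per-rank activation shapes before and after the Ulysses all-to-all."""
    if seq_len % sp_size or num_heads % sp_size:
        raise ValueError("seq_len and num_heads must be divisible by sp_size")
    # Outside attention: a slice of the sequence with every head.
    before = (seq_len // sp_size, num_heads, head_dim)
    # Inside attention: the whole sequence with a slice of the heads.
    after = (seq_len, num_heads // sp_size, head_dim)
    return before, after


before, after = ulysses_local_shapes(seq_len=8192, num_heads=32, head_dim=128, sp_size=4)
print(before)  # (2048, 32, 128)
print(after)   # (8192, 8, 128)
```

The per-rank element count is the same on both sides of the all-to-all, which is why the head count (not only the sequence length) must be divisible by sp_size in this scheme.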


Parallelism Configuration

Parallel Strategy to Device Mesh


Process Groups:

Sources: areal/engine/fsdp_engine.py:223-253, areal/api/cli_args.py:230-260


Training Pipeline

Microbatch Processing Flow


Sources: areal/engine/fsdp_engine.py:578-650, areal/engine/fsdp_engine.py:537-577, areal/utils/data.py:380-450


Tree Attention Support

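FSDPEngine integrates tree attention so that sequences sharing a common prefix (for example, several responses sampled from the same prompt) can be trained without recomputing the shared prefix tokens for each sequence.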

Integration Points


Sources: areal/engine/fsdp_engine.py:140-170, areal/models/tree_attn/module.py:97-100, areal/models/tree_attn/tree.py:101-102
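The idea behind tree attention can be illustrated with a small pure-Python sketch: sequences that share a prefix are merged into a trie so that each shared token appears once, and each node may attend only to itself and its ancestors. This is a generic illustration under those assumptions; build_token_tree and tree_attention_mask are hypothetical names and do not mirror the API in areal/models/tree_attn.

```python
from typing import Dict, List, Tuple


def build_token_tree(sequences: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Merge sequences into a prefix tree.

    Returns (tokens, parents): node i holds tokens[i] and its parent is
    parents[i] (-1 for a root). Shared prefixes map to the same nodes.
    """
    tokens: List[int] = []
    parents: List[int] = []
    children: Dict[Tuple[int, int], int] = {}  # (parent node, token) -> node id
    for seq in sequences:
        parent = -1
        for tok in seq:
            key = (parent, tok)
            if key not in children:
                children[key] = len(tokens)
                tokens.append(tok)
                parents.append(parent)
            parent = children[key]
    return tokens, parents


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """mask[i][j] is True if node j is node i itself or one of its ancestors."""
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root
            mask[i][j] = True
            j = parents[j]
    return mask


seqs = [[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 6]]
tokens, parents = build_token_tree(seqs)
mask = tree_attention_mask(parents)
print(len(tokens), sum(map(len, seqs)))  # 6 11
print(mask[4])  # [True, True, True, False, True, False]
```

Because a node's ancestors are exactly its prefix in every sequence passing through it, this mask reproduces per-sequence causal attention while the three sequences need only 6 token positions instead of 11.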


Weight Synchronization

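FSDPEngine synchronizes weights with remote inference engines using one of two protocols: XCCL, which transfers weights directly over a collective communication backend (such as NCCL), and disk-based, which saves the weights to storage for the inference engine to reload.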

Synchronization Architecture


Sources: areal/engine/fsdp_engine.py:837-939, areal/engine/fsdp_engine.py:941-1007
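For a disk-based protocol, one common pattern is a versioned handoff: the trainer writes a complete snapshot into a temporary directory and then renames it, so a polling inference engine never observes a half-written version. The sketch below shows that pattern using only the standard library; the directory layout, the JSON format, and the function names are assumptions for illustration and are not what fsdp_engine.py writes (real weights would be stored in a tensor serialization format).

```python
import json
import os
import tempfile
from typing import Dict, List, Optional

Weights = Dict[str, List[float]]  # stand-in for a model state dict


def publish_weights(root: str, version: int, weights: Weights) -> str:
    """Write a full snapshot to a temp dir, then rename it to v<version>."""
    os.makedirs(root, exist_ok=True)
    # Readers ignore names that are not "v<digits>", so this half-written
    # directory stays invisible to them.
    tmp_dir = tempfile.mkdtemp(prefix=".tmp-", dir=root)
    with open(os.path.join(tmp_dir, "weights.json"), "w") as f:
        json.dump(weights, f)
    final_dir = os.path.join(root, f"v{version}")
    os.rename(tmp_dir, final_dir)  # atomic on POSIX within one filesystem
    return final_dir


def latest_version(root: str) -> Optional[int]:
    """Highest published version number, or None if nothing is published."""
    versions = [int(name[1:]) for name in os.listdir(root)
                if name.startswith("v") and name[1:].isdigit()]
    return max(versions, default=None)


def load_weights(root: str, version: int) -> Weights:
    """What the inference side would do after seeing a new version."""
    with open(os.path.join(root, f"v{version}", "weights.json")) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as root:
    publish_weights(root, 1, {"layer.weight": [0.1, 0.2]})
    publish_weights(root, 2, {"layer.weight": [0.3, 0.4]})
    version = latest_version(root)
    loaded = load_weights(root, version)

print(version, loaded)  # 2 {'layer.weight': [0.3, 0.4]}
```

Disk-based handoff trades transfer speed for simplicity and decoupling, since the trainer and inference engines only need shared storage; a collective-based transfer avoids the storage round trip but requires both sides to join a communication group.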


Memory Management

Per-Layer Optimizer Step

For large models trained with FSDP CPU offloading, running the default Adam update on the CPU is slow. FSDPEngine therefore supports PerLayerOptimWrapper, which groups model parameters by layer and streams each layer's optimizer states to the device, where the update runs faster, while keeping device memory usage low (areal/engine/fsdp_utils/optimizer.py:19-41). Enable it with the per_layer_optim_step configuration option (docs/en/best_practices/handling_oom.md:173-178).
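The per-layer idea can be simulated without a GPU: optimizer states live in a "host" store, and for each layer only that layer's states are copied to a "device" buffer, updated with Adam, and copied back, so the peak device-resident state is bounded by the largest layer rather than the whole model. Plain Python lists stand in for tensors and the host/device split is only simulated; the function names are hypothetical and are not the PerLayerOptimWrapper API, and transfer overlap is not modeled.

```python
import math
from typing import Dict, List


def adam_update(p: List[float], g: List[float], m: List[float], v: List[float],
                step: int, lr: float = 1e-3, b1: float = 0.9, b2: float = 0.999,
                eps: float = 1e-8) -> None:
    """Bias-corrected Adam update, applied in place to flat lists."""
    for i in range(len(p)):
        m[i] = b1 * m[i] + (1 - b1) * g[i]
        v[i] = b2 * v[i] + (1 - b2) * g[i] ** 2
        m_hat = m[i] / (1 - b1 ** step)
        v_hat = v[i] / (1 - b2 ** step)
        p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


def per_layer_step(layers: Dict[str, List[float]], grads: Dict[str, List[float]],
                   host_states: Dict[str, Dict[str, List[float]]], step: int) -> int:
    """Update one layer at a time; return the peak device-resident state size."""
    peak = 0
    for name, params in layers.items():
        # Simulated host-to-device copy of only this layer's Adam moments.
        m = list(host_states[name]["m"])
        v = list(host_states[name]["v"])
        peak = max(peak, len(m) + len(v))
        adam_update(params, grads[name], m, v, step)
        # Simulated device-to-host copy; the device copies are then dropped.
        host_states[name]["m"], host_states[name]["v"] = m, v
    return peak


layers = {"layer0": [1.0] * 4, "layer1": [1.0] * 8, "layer2": [1.0] * 4}
grads = {name: [0.5] * len(p) for name, p in layers.items()}
states = {name: {"m": [0.0] * len(p), "v": [0.0] * len(p)} for name, p in layers.items()}

peak = per_layer_step(layers, grads, states, step=1)
total = sum(2 * len(p) for p in layers.values())
print(peak, total)  # 16 32
```

Here at most 16 state elements are resident at once, against 32 if all optimizer states were moved to the device together.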

AnyPrecision Optimization

The engine supports AnyPrecisionAdamW, which lets optimizer states use lower precision (e.g., BFloat16 momentum and variance) to reduce memory relative to standard FP32 states (areal/engine/fsdp_utils/optimizer.py:44-99).
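The saving is simple arithmetic: AdamW keeps two state tensors per parameter (the first and second moments), so storing them in BF16 (2 bytes) instead of FP32 (4 bytes) halves the optimizer-state footprint. The helper below is a hypothetical back-of-the-envelope calculator; it ignores master weights, gradients, and any extra buffers an implementation may keep, and assumes states are evenly sharded across dp_size ranks as in the DP row of the parallelism table.

```python
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2}


def adam_state_bytes(num_params: int, state_dtype: str = "fp32", dp_size: int = 1) -> int:
    """Approximate per-rank bytes for AdamW's two moment tensors.

    Counts only exp_avg and exp_avg_sq; each of the dp_size ranks is assumed
    to hold 1/dp_size of them (rounded up).
    """
    total = 2 * num_params * BYTES_PER_ELEMENT[state_dtype]
    return -(-total // dp_size)  # ceiling division


n = 7_000_000_000  # e.g., a 7B-parameter model
print(adam_state_bytes(n, "fp32") / 1e9)  # 56.0
print(adam_state_bytes(n, "bf16") / 1e9)  # 28.0
print(adam_state_bytes(n, "bf16", dp_size=8) / 1e9)  # 3.5
```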

Offloading Strategies

FSDPEngine combines several memory optimization techniques, including FSDP CPU offloading and the per-layer optimizer step described above.


Sources: areal/engine/fsdp_engine.py:691-721, areal/engine/fsdp_engine.py:293-350, areal/utils/offload.py:131-132, areal/engine/fsdp_utils/optimizer.py:19-41, docs/en/best_practices/handling_oom.md:167-181


Configuration Reference

MicroBatchSpec

MicroBatchSpec controls how a training batch is split into micro-batches for the forward and backward passes (areal/api/cli_args.py:99-139).

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_mbs | int or None | 1 | Number of micro-batches (the minimum number if max_tokens_per_mb is set). |
| granularity | int | 1 | Adjacent sequences are grouped in units of this size when dividing micro-batches. |
| max_tokens_per_mb | int or None | None | Maximum number of tokens per micro-batch in each forward pass. |
| n_mbs_divisor | int | 1 | Adjusts the final n_mbs to be divisible by this value. |
| packing_algorithm | str | "ffd" | Sequence packing algorithm used for allocation: ffd (first-fit decreasing) or kk. |

Sources: areal/api/cli_args.py:99-139, areal/utils/data.py:18-19
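A minimal sketch of how max_tokens_per_mb, n_mbs, and n_mbs_divisor interact, assuming first-fit decreasing packing for "ffd" and that the final count is the larger of the packed micro-batch count and n_mbs, rounded up to a multiple of n_mbs_divisor. The function names are hypothetical, granularity is ignored, and how sequences are redistributed when extra micro-batches are required is not shown.

```python
from typing import List


def ffd_pack(lengths: List[int], max_tokens_per_mb: int) -> List[List[int]]:
    """First-fit decreasing: place each sequence (longest first) into the first
    micro-batch with room, opening a new one if none fits. Returns index lists."""
    bins: List[List[int]] = []
    loads: List[int] = []
    for i in sorted(range(len(lengths)), key=lambda k: lengths[k], reverse=True):
        if lengths[i] > max_tokens_per_mb:
            raise ValueError(f"sequence {i} exceeds max_tokens_per_mb")
        for b, load in enumerate(loads):
            if load + lengths[i] <= max_tokens_per_mb:
                bins[b].append(i)
                loads[b] += lengths[i]
                break
        else:  # no existing micro-batch has room
            bins.append([i])
            loads.append(lengths[i])
    return bins


def final_n_mbs(n_packed: int, n_mbs: int = 1, n_mbs_divisor: int = 1) -> int:
    """At least n_mbs micro-batches, rounded up to a multiple of n_mbs_divisor."""
    k = max(n_packed, n_mbs)
    return -(-k // n_mbs_divisor) * n_mbs_divisor


lengths = [5, 3, 4, 2, 6]
mbs = ffd_pack(lengths, max_tokens_per_mb=8)
print(mbs)  # [[4, 3], [0, 1], [2]]
print(final_n_mbs(len(mbs), n_mbs=1, n_mbs_divisor=2))  # 4
```

Sorting longest-first lets the large sequences claim space early so short ones fill the remaining gaps, which is why first-fit decreasing usually needs fewer micro-batches than packing in arrival order.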

NormConfig

NormConfig configures reward and advantage normalization during RL training (areal/api/cli_args.py:42-78).

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| mean_level | str or None | "batch" | Level for mean normalization: batch, group, or None. |
| std_level | str or None | "batch" | Level for standard-deviation normalization. |
| std_unbiased | bool | True | Use the unbiased standard deviation. |
| group_size | int | 1 | Size of each group for group-level normalization. |

Sources: areal/api/cli_args.py:42-78, areal/utils/data.py:18-19
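A minimal sketch of how these options could combine, assuming "group" means contiguous blocks of group_size values (for example, the responses sampled for one prompt), that a level of None skips that step, that the standard deviation is taken over the raw values, and that a small epsilon guards the division. The function and the eps value are hypothetical and not AReaL's implementation.

```python
import statistics
from typing import List, Optional, Sequence


def _segments(n: int, level: Optional[str], group_size: int) -> List[range]:
    """Index ranges over which a statistic is computed for a given level."""
    if level is None:
        return []  # skip this normalization step
    if level == "batch":
        return [range(n)]
    if level == "group":
        if n % group_size:
            raise ValueError("batch size must be divisible by group_size")
        return [range(s, s + group_size) for s in range(0, n, group_size)]
    raise ValueError(f"unknown level: {level!r}")


def normalize(values: Sequence[float], mean_level: Optional[str] = "batch",
              std_level: Optional[str] = "batch", std_unbiased: bool = True,
              group_size: int = 1, eps: float = 1e-5) -> List[float]:
    out = list(values)
    for seg in _segments(len(out), mean_level, group_size):
        mu = statistics.fmean(values[i] for i in seg)
        for i in seg:
            out[i] -= mu
    # Assumption: the std is taken over the raw values of each segment.
    std_fn = statistics.stdev if std_unbiased else statistics.pstdev
    for seg in _segments(len(out), std_level, group_size):
        sigma = std_fn([values[i] for i in seg])
        for i in seg:
            out[i] /= sigma + eps
    return out


rewards = [1.0, 2.0, 3.0, 4.0]
grouped = normalize(rewards, mean_level="group", std_level="group", group_size=2)
print([round(x, 4) for x in grouped])  # [-0.7071, 0.7071, -0.7071, 0.7071]
```

statistics.stdev needs at least two values, so group-level std normalization in this sketch requires group_size of 2 or more.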