
URL: https://deepwiki.com/inclusionAI/AReaL/8.2-fsdp-parallelism



Last indexed: 7 May 2026 (2e12c1)

FSDP Parallelism

Purpose and Scope

This page documents the Fully Sharded Data Parallel (FSDP) parallelism implementation in AReaL's FSDPEngine. It covers the N-dimensional parallelism strategy, device mesh construction, process group organization, and memory optimization features unique to the FSDP backend.

The FSDPEngine is a training backend implementation of the TrainEngine API (areal/utils/recover.py:21), designed to scale Large Language Models (LLMs) using PyTorch's native FSDP2 (Fully Sharded Data Parallel version 2) (areal/engine/fsdp_utils/__init__.py:24-29).


Overview

The FSDPEngine leverages FSDP2 to implement flexible N-dimensional parallelism. Unlike traditional data parallelism that replicates the entire model across workers, FSDP shards model parameters, gradients, and optimizer states across devices to enable training of models that exceed single-device memory capacity.

Key Features:

Sources: areal/engine/fsdp_utils/__init__.py:62-107, areal/engine/fsdp_utils/optimizer.py:44-101, docs/en/best_practices/handling_oom.md:167-181


N-D Parallelism Architecture

The FSDPEngine supports multiple parallelism dimensions that can be combined simultaneously. Configuration is driven by backend string patterns, e.g. fsdp:d2c2 for 2-way data parallelism (DP) and 2-way context parallelism (docs/en/best_practices/handling_oom.md:120-124).
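To make the backend string format concrete, here is a minimal sketch of how a pattern such as fsdp:d2c2 could be split into per-dimension sizes. The function name parse_backend and the generic letter-plus-number grammar are assumptions made for illustration; this is not AReaL's actual parser, and only the letters d and c are confirmed by the example above.

```python
import re
from math import prod


def parse_backend(spec: str):
    """Hypothetical helper: split 'fsdp:d2c2' into ('fsdp', {'d': 2, 'c': 2})."""
    name, _, dims = spec.partition(":")
    # Each dimension is assumed to be one lowercase letter followed by its size.
    sizes = {letter: int(n) for letter, n in re.findall(r"([a-z])(\d+)", dims)}
    return name, sizes


name, sizes = parse_backend("fsdp:d2c2")
# With combined parallelism, the number of ranks is the product of all sizes.
world_size = prod(sizes.values())
print(name, sizes, world_size)
```

Under these assumptions, fsdp:d2c2 describes a 2 x 2 layout that needs four ranks in total.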

Parallelism Dimensions and Code Entities

The following diagram maps the logical parallelism dimensions to the internal FSDPEngine and utility abstractions.


Parallelism Dimension Details:

| Dimension | Purpose | Communication Pattern | Primary Use Case |
|---|---|---|---|
| DP | Scale batch size | All-reduce gradients | Standard distributed training |
| TP | Shard large layers | All-reduce/all-gather | Wide models (large hidden dim) |
| SP | Shard sequence length | All-to-all attention | Long-context training (Ulysses) (docs/en/best_practices/handling_oom.md:114-115) |

Sources: areal/engine/fsdp_utils/__init__.py:62-107, docs/en/best_practices/handling_oom.md:112-132


Device Mesh and Process Groups

Device Mesh Construction

Process groups are initialized using init_custom_process_group, which allows independent communication groups between training and inference engines (areal/engine/core/distributed.py:79-88).
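As an illustration of how a device mesh partitions ranks into per-dimension groups, the sketch below enumerates the rank groups of a row-major mesh in pure Python. The row-major layout matches how PyTorch's init_device_mesh numbers ranks; the helper name mesh_groups and the dimension names "dp" and "sp" are hypothetical and may not match the names AReaL uses.

```python
from itertools import product
from math import prod


def mesh_groups(shape, names):
    """Hypothetical helper: list the rank groups along each axis of a row-major mesh."""
    # Row-major strides: the last axis varies fastest.
    strides = [prod(shape[i + 1:]) for i in range(len(shape))]
    groups = {}
    for axis, name in enumerate(names):
        other_axes = [i for i in range(len(shape)) if i != axis]
        axis_groups = []
        # Fix the coordinates on every other axis, then walk along this axis.
        for fixed in product(*(range(shape[i]) for i in other_axes)):
            base = sum(c * strides[i] for c, i in zip(fixed, other_axes))
            axis_groups.append([base + k * strides[axis] for k in range(shape[axis])])
        groups[name] = axis_groups
    return groups


groups = mesh_groups((2, 2), ("dp", "sp"))
print(groups)
```

For a 2 x 2 mesh, ranks that share a row form a group along the second axis, and ranks that share a column form a group along the first axis.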


Key Process Groups:

Sources: areal/engine/core/distributed.py:26-139, tests/test_warmup_process_groups.py:55-72


Model Parallelization Workflow

apply_fsdp2 Function

The core parallelization logic is implemented in apply_fsdp2 (areal/engine/fsdp_utils/__init__.py:62-67), which applies FSDP2 sharding to model modules based on the transformer_layer_cls_to_wrap policy.
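The selection step of a class-name wrap policy can be sketched without PyTorch. The toy Module class and the helper select_wrap_targets below are hypothetical stand-ins, not AReaL code. Sharding the root module after its matched submodules follows the usual FSDP2 pattern; whether apply_fsdp2 does exactly this is an assumption.

```python
class Module:
    """Toy stand-in for torch.nn.Module, only to illustrate tree traversal."""

    def __init__(self, **children):
        self.children = children

    def named_modules(self, prefix=""):
        yield prefix, self
        for name, child in self.children.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


class DecoderLayer(Module):
    pass


class Embedding(Module):
    pass


class Model(Module):
    pass


def select_wrap_targets(root, cls_names):
    """Hypothetical helper: names of modules to shard, with the root last."""
    targets = [
        name for name, module in root.named_modules()
        if type(module).__name__ in cls_names
    ]
    targets.append("<root>")  # assumed: the whole model is sharded last
    return targets


model = Model(embed=Embedding(), layer0=DecoderLayer(), layer1=DecoderLayer())
targets = select_wrap_targets(model, {"DecoderLayer"})
print(targets)
```

In a real FSDP2 setup, each selected module would be passed to PyTorch's fully_shard in this order, so every transformer layer becomes its own sharding unit.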


Sources: areal/engine/fsdp_utils/__init__.py:62-108


Memory Management and Optimization

Per-Layer Optimizer Step

When parameters are offloaded to CPU, standard optimizer updates are slow. AReaL provides PerLayerOptimWrapper (areal/engine/fsdp_utils/optimizer.py:19-20), which groups parameters by transformer layer and streams optimizer states and parameters to the device one layer at a time (tests/test_per_layer_optim_step.py:124-145).
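The grouping step can be illustrated with a small sketch that buckets parameter names by layer index. The ".layers.N." naming pattern, the helper group_by_layer, and the "other" bucket are assumptions for illustration; PerLayerOptimWrapper may identify layers differently.

```python
import re
from collections import defaultdict

# Assumed naming convention, e.g. "model.layers.3.mlp.up_proj.weight".
LAYER_RE = re.compile(r"\.layers\.(\d+)\.")


def group_by_layer(param_names):
    """Hypothetical helper: bucket parameter names by transformer layer index."""
    groups = defaultdict(list)
    for name in param_names:
        match = LAYER_RE.search(name)
        # Parameters outside any layer (embeddings, head) share one bucket.
        key = int(match.group(1)) if match else "other"
        groups[key].append(name)
    return dict(groups)


names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
    "lm_head.weight",
]
groups = group_by_layer(names)
print(groups)
```

With groups like these, an optimizer wrapper can move one bucket to the accelerator, update it, and move it back before touching the next, which keeps peak device memory close to the size of a single layer.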

AnyPrecisionAdamW

The AnyPrecisionAdamW optimizer (areal/engine/fsdp_utils/optimizer.py:44-60) allows using bfloat16 for the momentum (exp_avg) and variance (exp_avg_sq) buffers (areal/engine/fsdp_utils/optimizer.py:144-147). It supports optional Kahan summation (compensation) to offset the precision lost during weight updates (areal/engine/fsdp_utils/optimizer.py:150-153).
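To see why compensation matters, the sketch below simulates bfloat16 rounding in pure Python and applies many small updates to a weight of 1.0. It assumes the weight and the compensation buffer are both stored in bfloat16 and uses a constant update instead of a real AdamW step. The helper names are hypothetical, and the update order is a simplified version of compensated summation rather than a copy of AReaL's code.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def naive_updates(weight, update, steps):
    """Add the update directly; each result is rounded back to bfloat16."""
    for _ in range(steps):
        weight = to_bf16(weight + update)
    return weight


def kahan_updates(weight, update, steps):
    """Accumulate lost low-order bits in a compensation buffer."""
    comp = 0.0
    for _ in range(steps):
        comp = to_bf16(comp + update)              # pending change, incl. earlier losses
        new_weight = to_bf16(weight + comp)        # apply as much as bfloat16 can hold
        comp = to_bf16(comp + (weight - new_weight))  # keep what was not applied
        weight = new_weight
    return weight


naive = naive_updates(1.0, 1e-3, 1000)
kahan = kahan_updates(1.0, 1e-3, 1000)
print(naive, kahan)
```

Near 1.0 the spacing between adjacent bfloat16 values is 2^-7, so an update of 1e-3 is rounded away every time in the naive loop, while the compensated loop ends close to the expected value of 2.0.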

Memory-Efficient Loading

fsdp2_load_full_state_dict enables loading large checkpoints by broadcasting from rank 0 (areal/engine/fsdp_utils/__init__.py:110-123). It supports materializing models from meta tensors via to_empty() before broadcasting weights (areal/engine/fsdp_utils/__init__.py:133-137).

Sources: areal/engine/fsdp_utils/optimizer.py:44-182, areal/engine/fsdp_utils/__init__.py:110-161, tests/test_per_layer_optim_step.py:188-191


Data Flow and Recovery

Checkpointing and Recovery

The RecoverInfo and RecoverHandler classes manage distributed state saving. Because dataloader states differ across ranks, RecoverInfo performs an all_gather_object to collect the state from all ranks before rank 0 dumps it to disk (areal/utils/recover.py:57-66).


Sources: areal/utils/recover.py:41-94, areal/utils/saver.py:122-159

Weight Synchronization

Weight updates between training and inference engines can be performed via NCCL (XCCL) or disk (docs/zh/best_practices/handling_oom.md:200-213). For FSDP, full state dicts are reconstructed using StateDictOptions with full_state_dict=True (areal/engine/fsdp_utils/__init__.py:141-147).

Sources: areal/engine/fsdp_utils/__init__.py:110-161, areal/utils/recover.py:15-24