This page describes the microbatching system used by AReaL's training engines to split large training batches into smaller micro-batches for gradient accumulation. Microbatching enables training with effective batch sizes larger than what fits in GPU memory by processing multiple micro-batches sequentially and accumulating gradients before performing an optimizer step.
The microbatching pipeline consists of:
- Configuration through the MicroBatchSpec dataclass.
- Data structures (MicroBatchList, MicroBatchItem) for managing micro-batch metadata.

This system is shared across all training backends and supports both standard sequence packing and tree training modes.
The MicroBatchSpec dataclass defines how batches are split into micro-batches.
| Field | Type | Description |
|---|---|---|
| n_mbs | `int \| None` | Number of micro-batches to split the batch into. |
| granularity | int | Group adjacent sequences by this size when dividing. |
| max_tokens_per_mb | `int \| None` | Maximum number of tokens allowed in one micro-batch. |
| n_mbs_divisor | int | Final micro-batch count must be divisible by this value. |
| packing_algorithm | str | Algorithm for allocation: ffd (First Fit Decreasing) or kk (Karmarkar-Karp). |
Key behaviors:
- When max_tokens_per_mb=None, the batch is split into exactly n_mbs micro-batches.
- When max_tokens_per_mb is set, the system uses the configured packing_algorithm to create balanced micro-batches that respect the token limit.
- ffd is a greedy heuristic, while kk (Largest Differencing Method) provides near-optimal balance for large-scale RL with variable sequence lengths.

Sources: areal/api/cli_args.py:99-147, areal/utils/seqpack.py:161-187
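As a rough illustration of the fields and key behaviors above, the following sketch mirrors MicroBatchSpec as a plain dataclass. The class name `MicroBatchSpecSketch`, every default value, and the helpers `splitting_mode` and `round_up_to_divisor` are assumptions made for this example, not AReaL's actual definitions. In particular, rounding up is only one plausible way to satisfy n_mbs_divisor.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MicroBatchSpecSketch:
    """Hypothetical stand-in for MicroBatchSpec; all defaults are assumptions."""

    n_mbs: Optional[int] = 1
    granularity: int = 1
    max_tokens_per_mb: Optional[int] = None
    n_mbs_divisor: int = 1
    packing_algorithm: str = "ffd"

    def __post_init__(self) -> None:
        # Only the two algorithm names mentioned on this page are accepted.
        if self.packing_algorithm not in ("ffd", "kk"):
            raise ValueError(f"unknown packing_algorithm: {self.packing_algorithm}")
        if self.granularity < 1 or self.n_mbs_divisor < 1:
            raise ValueError("granularity and n_mbs_divisor must be >= 1")


def splitting_mode(spec: MicroBatchSpecSketch) -> str:
    """Return which of the two documented behaviors applies to this spec."""
    return "fixed_count" if spec.max_tokens_per_mb is None else "token_budget"


def round_up_to_divisor(n_mbs: int, divisor: int) -> int:
    """Smallest multiple of `divisor` that is >= n_mbs (assumed rounding policy)."""
    return -(-n_mbs // divisor) * divisor


fixed = MicroBatchSpecSketch(n_mbs=4)
budgeted = MicroBatchSpecSketch(max_tokens_per_mb=4096, packing_algorithm="kk")
```

Under these assumptions, `fixed` selects the exact-count split and `budgeted` selects token-limited packing with the kk algorithm.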
MicroBatchList is the primary container returned by the splitting pipeline. It encapsulates the split micro-batches along with metadata needed for processing and reconstruction.
Diagram: MicroBatchList and MicroBatchItem structure
Key attributes:
- forward_indices: Mapping from original batch order to micro-batch order.
- backward_indices: Inverse mapping to reconstruct original order.
- padded_mbs: List of padded micro-batch dictionaries ready for model forward.
- old_cu_seqlens_list: Original cumulative sequence lengths before alignment (used for context parallel unpadding).

Sources: areal/utils/data.py:385-471
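The relationship between forward_indices and backward_indices can be sketched with plain lists. This example assumes that `forward_indices[k]` holds the original position of the item placed at position `k` in micro-batch order. The source does not spell out the convention, so treat that choice and the helper name `invert_permutation` as hypothetical.

```python
from typing import List


def invert_permutation(forward_indices: List[int]) -> List[int]:
    """Build the inverse mapping so that reordering can be undone."""
    backward = [0] * len(forward_indices)
    for new_pos, orig_pos in enumerate(forward_indices):
        backward[orig_pos] = new_pos
    return backward


# Assumed allocation: sequences 2 and 0 go to micro-batch 0, sequence 1 to micro-batch 1.
groups = [[2, 0], [1]]
forward_indices = [i for group in groups for i in group]  # [2, 0, 1]
backward_indices = invert_permutation(forward_indices)    # [1, 2, 0]

data = ["a", "b", "c"]
reordered = [data[i] for i in forward_indices]       # micro-batch order
restored = [reordered[i] for i in backward_indices]  # original order again
```

Applying `backward_indices` to per-sequence outputs computed in micro-batch order returns them to the caller's original batch order.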
MicroBatchItem is a NamedTuple yielded when iterating over a MicroBatchList.
| Field | Type | Purpose |
|---|---|---|
| orig_mb | dict | Original micro-batch for loss weight computation. |
| padded_mb | dict | Padded micro-batch for model forward pass. |
| padding_length | int | Batch-level padding added. |
| old_cu_seqlens | Tensor | Pre-alignment cumulative sequence lengths. |

Sources: areal/utils/data.py:367-383
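A minimal sketch of the item shape described in the table, assuming the four fields are simply zipped together during iteration. `MicroBatchItemSketch` and `iter_items` are hypothetical names, and old_cu_seqlens is a plain list here instead of a Tensor so that the example has no dependencies.

```python
from typing import Any, Dict, Iterator, List, NamedTuple


class MicroBatchItemSketch(NamedTuple):
    """Hypothetical stand-in for MicroBatchItem."""

    orig_mb: Dict[str, Any]
    padded_mb: Dict[str, Any]
    padding_length: int
    old_cu_seqlens: List[int]  # a Tensor in the source; a list in this sketch


def iter_items(
    mbs: List[Dict[str, Any]],
    padded_mbs: List[Dict[str, Any]],
    padding_lengths: List[int],
    old_cu_seqlens_list: List[List[int]],
) -> Iterator[MicroBatchItemSketch]:
    """Yield one item per micro-batch, pairing each original with its padded form."""
    for fields in zip(mbs, padded_mbs, padding_lengths, old_cu_seqlens_list):
        yield MicroBatchItemSketch(*fields)


items = list(
    iter_items(
        mbs=[{"input_ids": [11, 12, 13]}],
        padded_mbs=[{"input_ids": [11, 12, 13, 0]}],
        padding_lengths=[1],
        old_cu_seqlens_list=[[0, 3]],
    )
)
```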
The microbatching pipeline transforms a padded batch dictionary into a MicroBatchList through several stages.
Diagram: Microbatching pipeline flow from input batch to MicroBatchList
The pipeline extracts sequence lengths from the input batch's attention_mask. When granularity > 1, sequences are grouped before calculating total lengths.
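The length extraction and grouping step might look like the sketch below. It uses nested lists instead of tensors, and the helper names `sequence_lengths` and `group_lengths` are hypothetical. The requirement that the batch size be divisible by granularity is also an assumption of this example.

```python
from typing import List


def sequence_lengths(attention_mask: List[List[int]]) -> List[int]:
    """Count the non-padding positions in each row of a 0/1 attention mask."""
    return [sum(row) for row in attention_mask]


def group_lengths(lengths: List[int], granularity: int) -> List[int]:
    """Sum the lengths of each run of `granularity` adjacent sequences."""
    if len(lengths) % granularity != 0:
        raise ValueError("batch size must be divisible by granularity in this sketch")
    return [
        sum(lengths[i : i + granularity])
        for i in range(0, len(lengths), granularity)
    ]


mask = [
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 1],
    [1, 0, 0, 0],
]
lengths = sequence_lengths(mask)     # [3, 2, 4, 1]
grouped = group_lengths(lengths, 2)  # [5, 5]
```

With granularity set to 2, each adjacent pair is treated as one unit, so both sequences of a pair always land in the same micro-batch.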
The system dispatches to an allocation function via get_allocate_fn(algorithm):
- ffd (ffd_allocate): Implements First-Fit Decreasing.
- kk (kk_allocate): Implements Karmarkar-Karp partitioning for superior load balance.

In distributed training, all ranks must agree on the number of micro-batches to ensure consistent pipeline scheduling. The allocate_balanced_mbs_synced function performs an all_gather_object to find the maximum n_mbs across all ranks.

Sources: areal/utils/seqpack.py:167-279, areal/utils/data.py:244-271, areal/utils/data.py:477-593
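For intuition, here is a textbook First-Fit Decreasing packer together with a simulated version of the cross-rank agreement step. This is a sketch, not the code of ffd_allocate or allocate_balanced_mbs_synced: the function name `ffd_allocate_sketch`, its signature, and the per-rank counts are assumptions, and the collective call is replaced by a plain list so that the example runs in a single process.

```python
from typing import List


def ffd_allocate_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    """Pack sequence indices into bins of at most `capacity` tokens (classic FFD)."""
    # Visit sequences from longest to shortest.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []
    loads: List[int] = []
    for i in order:
        if lengths[i] > capacity:
            raise ValueError(f"sequence {i} exceeds the token limit")
        for b, load in enumerate(loads):
            if load + lengths[i] <= capacity:  # first bin with enough room
                bins[b].append(i)
                loads[b] += lengths[i]
                break
        else:  # no existing bin fits, so open a new one
            bins.append([i])
            loads.append(lengths[i])
    return bins


lengths = [7, 3, 5, 2, 4, 6]
bins = ffd_allocate_sketch(lengths, capacity=10)  # [[0, 1], [5, 4], [2, 3]]

# Stand-in for the gathered result: each rank reports its local micro-batch
# count, and every rank adopts the maximum so that all ranks agree.
per_rank_n_mbs = [len(bins), 2, 4]  # hypothetical counts from three ranks
synced_n_mbs = max(per_rank_n_mbs)
```

In this toy run the local rank needs three micro-batches but another rank needs four, so the agreed count is four.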
After splitting, micro-batches are padded to enable efficient tensor operations via pad_mb_list().
When context parallelism (Ulysses) is enabled, each sequence length must be a multiple of sp_size (the sequence parallel size). The ulysses_pad function pads each sequence up to the next multiple of align_to.
The pad_to_maximum flag controls whether each micro-batch is padded independently or to a global maximum length. Padding to the maximum reduces memory fragmentation but may increase total computation.

Sources: areal/utils/data.py:620-718
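The arithmetic behind both padding choices is small enough to sketch directly. The helpers below are hypothetical and operate on plain integers. They assume alignment always rounds up, and that pad_to_maximum means every micro-batch is padded to the size of the largest one.

```python
from typing import List


def pad_to_multiple(length: int, align_to: int) -> int:
    """Round `length` up to the next multiple of `align_to`."""
    return length + (-length) % align_to


def aligned_cu_seqlens(seqlens: List[int], align_to: int) -> List[int]:
    """Cumulative sequence lengths after each sequence is aligned."""
    cu = [0]
    for n in seqlens:
        cu.append(cu[-1] + pad_to_multiple(n, align_to))
    return cu


def padded_sizes(mb_sizes: List[int], pad_to_maximum: bool) -> List[int]:
    """Token count of each micro-batch after batch-level padding."""
    if pad_to_maximum:
        return [max(mb_sizes)] * len(mb_sizes)
    return list(mb_sizes)


seqlens = [5, 8, 3]
old_cu = aligned_cu_seqlens(seqlens, 1)  # [0, 5, 13, 16], no alignment
new_cu = aligned_cu_seqlens(seqlens, 4)  # [0, 8, 16, 20]

independent = padded_sizes([20, 12, 16], pad_to_maximum=False)  # 48 tokens in total
uniform = padded_sizes([20, 12, 16], pad_to_maximum=True)       # 60 tokens in total
```

The gap between `old_cu` and `new_cu` shows why the pre-alignment offsets are worth keeping for unpadding, and the 48 versus 60 token totals show the extra computation that uniform shapes can cost.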
All training engines follow a common pattern for consuming micro-batches, either via manual loops or pipeline schedulers.
| Engine | Execution Strategy | Key Function |
|---|---|---|
| FSDPEngine | Sequential loop over micro-batches | forward_backward_batch |
| MegatronEngine | Pipeline-parallel schedule | train_batch / forward_backward_func |
| ArchonEngine | Sequential or Pipeline-parallel | ForwardBackwardRunner.run |
The ArchonEngine uses a ForwardBackwardRunner abstraction to handle micro-batches.
Diagram: ArchonEngine micro-batch execution logic
Key implementation details:
- forward_backward_batch loops over MicroBatchItem entries and calls the model.
- MicroBatchList serves as a data iterator.
- compute_total_loss_weight aggregates weights across the DP group for global normalization (in FSDP/Megatron and in Archon).

Sources: areal/engine/fsdp_engine.py:540-615, areal/engine/megatron_engine.py:637-706, areal/experimental/engine/archon_engine.py:147-210, areal/engine/core/__init__.py:62
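To see why a global loss weight matters during accumulation, consider this dependency-free sketch. It assumes the loss weight of a micro-batch is its token count and that the total weight is already summed over all micro-batches (and, in a real run, over the data-parallel group). The function name and the dictionary layout are hypothetical and do not reflect how compute_total_loss_weight is actually called.

```python
from typing import Dict, List


def accumulate_weighted_loss(
    micro_batches: List[Dict[str, List[float]]], total_weight: float
) -> float:
    """Accumulate per-micro-batch loss sums, each scaled by the global weight."""
    total = 0.0
    for mb in micro_batches:
        mb_loss_sum = sum(mb["token_losses"])
        # Scaling by the global weight makes the accumulated value equal to
        # the mean over all tokens, regardless of how the batch was split.
        total += mb_loss_sum / total_weight
    return total


mbs = [
    {"token_losses": [1.0, 2.0, 3.0]},  # 3 tokens
    {"token_losses": [6.0]},            # 1 token
]
total_weight = float(sum(len(mb["token_losses"]) for mb in mbs))  # 4.0

global_mean = accumulate_weighted_loss(mbs, total_weight)  # 3.0
naive_mean = sum(
    sum(mb["token_losses"]) / len(mb["token_losses"]) for mb in mbs
) / len(mbs)  # 4.0
```

Averaging per-micro-batch means gives 4.0 because it over-weights the short micro-batch, while the globally normalized accumulation gives 3.0, which matches the mean over all four tokens.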
The microbatching pipeline supports tree training, where multiple trajectories share common prefixes.
When enable_tree_training=True:
- build_packed_tree_batch organizes trajectories into a trie structure.
- FSDPTrainContext and ArchonTrainContext carry the trie_node through the pipeline.
- _gather_packed_tree_logprobs and gather_packed_tree_logprobs_entropy are used to extract results from the tree structure.

Sources: areal/engine/fsdp_engine.py:141-168, areal/experimental/engine/archon_engine.py:124-145, areal/models/tree_attn/tree.py:106, areal/models/tree_attn/functional.py:97-99
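The benefit of sharing prefixes can be illustrated with a generic trie over token ids. This is a textbook structure written for this page, not the layout produced by build_packed_tree_batch. The nested-dict representation and both helper names are assumptions.

```python
from typing import Dict, List

Trie = Dict[int, "Trie"]


def build_trie(trajectories: List[List[int]]) -> Trie:
    """Insert each token sequence so that common prefixes share nodes."""
    root: Trie = {}
    for tokens in trajectories:
        node = root
        for token in tokens:
            node = node.setdefault(token, {})
    return root


def count_nodes(node: Trie) -> int:
    """Number of token nodes in the trie, excluding the root."""
    return sum(1 + count_nodes(child) for child in node.values())


trajectories = [
    [1, 2, 3, 4],
    [1, 2, 3, 5],
    [1, 2, 6],
]
trie = build_trie(trajectories)
packed_tokens = sum(len(t) for t in trajectories)  # 11 tokens if packed separately
tree_tokens = count_nodes(trie)                    # 6 tokens when prefixes are shared
```

Here three trajectories totalling 11 tokens collapse to 6 trie nodes, because the prefix `[1, 2]` is stored once instead of three times.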