
URL: https://deepwiki.com/inclusionAI/AReaL/10.6-tree-training

Tree Training | inclusionAI/AReaL | DeepWiki


Last indexed: 7 May 2026 (2e12c1)

Tree Training

Tree Training is an optimization technique in AReaL for efficient reinforcement learning on sequences that share common prefixes (e.g., multiple samples generated from the same prompt, or multi-turn agentic conversations). By organizing sequences into a trie, the system computes the hidden states of each shared prefix once instead of once per sequence, in both the forward and backward passes. This reduces FLOPs and memory usage, and the saving grows with the amount of prefix that is shared (areal/models/tree_attn/tree.py:3-8).

TrieNode Structure

The core data structure for tree training is the TrieNode. It represents a compressed trie where each node contains a contiguous run of tokens shared by a set of sequences (areal/models/tree_attn/tree.py:41-45).

| Attribute | Description |
| --- | --- |
| tokens | List of token IDs stored in the node (areal/models/tree_attn/tree.py:73) |
| sequence_ids | IDs of sequences passing through this node (areal/models/tree_attn/tree.py:74) |
| children | Map of child nodes keyed by the first diverging token (areal/models/tree_attn/tree.py:75) |
| ancestors | List of parent nodes up to the root (areal/models/tree_attn/tree.py:76) |
| start_idx / end_idx | Indices in the flattened tree representation (areal/models/tree_attn/tree.py:71-72) |

Sources: <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/tree.py#L40-L120" min=40 max=120 file-path="areal/models/tree_attn/tree.py">areal/models/tree_attn/tree.py:40-120</FileRef>
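To make the compressed-trie idea concrete, here is a minimal sketch of a node type and an insertion routine that splits a node where a new sequence diverges. The `Node` class, `insert`, and `count_tokens` are hypothetical names for illustration; they mirror the attribute names in the table above but are not AReaL's TrieNode implementation, and the `ancestors` and `start_idx` / `end_idx` fields are omitted.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a compressed-trie node (not AReaL's TrieNode)."""
    tokens: list[int] = field(default_factory=list)            # contiguous token run
    sequence_ids: list[int] = field(default_factory=list)      # sequences through this node
    children: dict[int, "Node"] = field(default_factory=dict)  # keyed by child's first token


def insert(root: Node, seq_id: int, seq: list[int]) -> None:
    """Insert one sequence, splitting a node at the point where it diverges."""
    node, pos = root, 0
    root.sequence_ids.append(seq_id)
    while pos < len(seq):
        child = node.children.get(seq[pos])
        if child is None:
            # Nothing shared from here on: the remainder becomes a new leaf.
            node.children[seq[pos]] = Node(tokens=seq[pos:], sequence_ids=[seq_id])
            return
        # Length of the run shared by the child's tokens and the rest of seq.
        k = 0
        while k < len(child.tokens) and pos + k < len(seq) and child.tokens[k] == seq[pos + k]:
            k += 1
        if k < len(child.tokens):
            # Partial match: keep the shared run in `child`, move the rest to `tail`.
            tail = Node(child.tokens[k:], list(child.sequence_ids), child.children)
            child.tokens = child.tokens[:k]
            child.children = {tail.tokens[0]: tail}
        child.sequence_ids.append(seq_id)
        node, pos = child, pos + k


def count_tokens(node: Node) -> int:
    """Number of tokens stored in the trie rooted at `node`."""
    return len(node.tokens) + sum(count_tokens(c) for c in node.children.values())


root = Node()
for sid, seq in enumerate([[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 6]]):
    insert(root, sid, seq)
```

In this example the three sequences hold 11 tokens in total, while the trie stores 6: the run `[1, 2]` is shared by all three sequences and `[3]` by the first two.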

Tree Training Data Flow

The transition from standard batched sequences to tree-structured training involves building a trie, packing tokens into a flat 1D tensor, and generating specialized attention masks. The build_packed_tree_batch function is the primary entry point for this transformation (areal/models/tree_attn/tree.py:421-430).
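The three steps can be sketched in plain Python. This is an illustration under simplifying assumptions, not the behavior of build_packed_tree_batch: it works at token level rather than with compressed nodes, packs tokens in insertion order, and returns nested lists instead of tensors. The helper names `pack_tree` and `tree_causal_mask` are hypothetical.

```python
def pack_tree(seqs):
    """Pack sequences into one flat token list, storing each shared prefix once.

    Returns (tokens, parent, positions): parent[i] is the packed index of the
    token that precedes token i on its path (-1 at the start of a sequence),
    and positions[s] lists the packed index of every token of sequence s.
    """
    tokens, parent, positions = [], [], []
    child_of = {}  # (parent packed index, token) -> packed index
    for seq in seqs:
        prev, path = -1, []
        for tok in seq:
            key = (prev, tok)
            if key not in child_of:
                child_of[key] = len(tokens)
                tokens.append(tok)
                parent.append(prev)
            prev = child_of[key]
            path.append(prev)
        positions.append(path)
    return tokens, parent, positions


def tree_causal_mask(parent):
    """mask[i][j] is True when packed token j is token i itself or one of its ancestors."""
    n = len(parent)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:
            mask[i][j] = True
            j = parent[j]
    return mask


tokens, parent, positions = pack_tree([[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 6]])
mask = tree_causal_mask(parent)
```

Each packed token attends only to the tokens on its own path from the root, so tokens on sibling branches never see each other even though they sit next to each other in the packed tensor.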

Logical to Physical Mapping

The following diagram illustrates how the TrieNode structure maps to the actual training tensors and backend modules.


Sources: <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/tree.py#L41-L120" min=41 max=120 file-path="areal/models/tree_attn/tree.py">areal/models/tree_attn/tree.py:41-120</FileRef>, <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/tree.py#L421-L450" min=421 max=450 file-path="areal/models/tree_attn/tree.py">areal/models/tree_attn/tree.py:421-450</FileRef>

Tree Attention Implementations

AReaL supports multiple backends for tree-structured attention to handle the tree-causal dependency patterns.

1. PyTorch Flex Attention (Default)

Uses torch.nn.attention.flex_attention to apply a BlockMask generated from the trie structure (areal/models/tree_attn/module_fsdp.py:6-10).

2. Triton Tree Attention (Experimental)

A custom Triton kernel that iterates over sparse blocks of the tree (areal/models/tree_attn/triton_kernel.py:180-221).

Sources: <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/module_fsdp.py#L1-L149" min=1 max=149 file-path="areal/models/tree_attn/module_fsdp.py">areal/models/tree_attn/module_fsdp.py:1-149</FileRef>, <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/triton_kernel.py#L1-L230" min=1 max=230 file-path="areal/models/tree_attn/triton_kernel.py">areal/models/tree_attn/triton_kernel.py:1-230</FileRef>, <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/constants.py#L1-L19" min=1 max=19 file-path="areal/models/tree_attn/constants.py">areal/models/tree_attn/constants.py:1-19</FileRef>
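Both backends rely on the tree-causal mask being sparse at block granularity: a tile of the attention matrix whose mask entries are all False can be skipped entirely. The sketch below, which is an assumption-laden illustration and not code from either backend, computes which key/value blocks each query block must visit. The names `ancestor_mask` and `nonempty_kv_blocks` and the block size of 2 are hypothetical; real kernels use much larger blocks.

```python
def ancestor_mask(parent):
    """mask[i][j] is True when packed token j is token i or one of its ancestors."""
    n = len(parent)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:
            mask[i][j] = True
            j = parent[j]
    return mask


def nonempty_kv_blocks(mask, block):
    """For each query block, list the key/value blocks with at least one True entry."""
    n = len(mask)
    num_blocks = (n + block - 1) // block
    result = []
    for qb in range(num_blocks):
        rows = range(qb * block, min((qb + 1) * block, n))
        kept = []
        for kb in range(num_blocks):
            cols = range(kb * block, min((kb + 1) * block, n))
            if any(mask[i][j] for i in rows for j in cols):
                kept.append(kb)
        result.append(kept)
    return result


# Packed layout: shared prefix at indices 0-1, branch A at 2-3, branch B at 4-5.
parent = [-1, 0, 1, 2, 1, 4]
blocks = nonempty_kv_blocks(ancestor_mask(parent), block=2)
```

Here the query block for branch B visits the prefix block and its own block but skips branch A's block, which a plain causal mask over the same packed tensor would not allow.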

Backend Support

Tree training is integrated across training engines in AReaL via specialized modules and patching mechanisms (areal/models/tree_attn/module.py:4-40).

FSDP Engine

The FSDPEngine utilizes patch_fsdp_for_tree_training to redirect standard flash attention to the tree implementation (areal/models/tree_attn/module_fsdp.py:154-170).

Megatron Engine

The MegatronEngine utilizes patch_bridge_for_tree_training to modify the LLMBridge layer specification (areal/models/tree_attn/module_megatron.py:171-199).

Archon Engine

The ArchonEngine utilizes TreeAttentionWrapper and TreeAttentionMeta to manage tree-based attention within its native parallelism framework (areal/models/tree_attn/module_archon.py:39).

Sources: <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/module_fsdp.py#L154-L172" min=154 max=172 file-path="areal/models/tree_attn/module_fsdp.py">areal/models/tree_attn/module_fsdp.py:154-172</FileRef>, <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/module_megatron.py#L35-L168" min=35 max=168 file-path="areal/models/tree_attn/module_megatron.py">areal/models/tree_attn/module_megatron.py:35-168</FileRef>, <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/module_archon.py#L39-L40" min=39 max=40 file-path="areal/models/tree_attn/module_archon.py">areal/models/tree_attn/module_archon.py:39-40</FileRef>
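The general idea of redirecting attention by patching can be shown with a generic sketch. Everything here is hypothetical: the `ATTENTION_IMPLS` registry, the two placeholder attention functions, and the `patched_attention` context manager are invented for illustration and do not reflect how patch_fsdp_for_tree_training or patch_bridge_for_tree_training are written, nor whether AReaL's patches are temporary or permanent.

```python
from contextlib import contextmanager


def standard_attention(q, k, v, **kwargs):
    """Placeholder for the stock attention path."""
    return "standard"


def tree_attention(q, k, v, **kwargs):
    """Placeholder for a tree-aware attention path."""
    return "tree"


# Hypothetical lookup table; real frameworks resolve their attention function differently.
ATTENTION_IMPLS = {"flash": standard_attention}


@contextmanager
def patched_attention(name, replacement):
    """Temporarily swap the implementation registered under `name`."""
    original = ATTENTION_IMPLS[name]
    ATTENTION_IMPLS[name] = replacement
    try:
        yield
    finally:
        # Restore the original even if the body raises.
        ATTENTION_IMPLS[name] = original


def model_forward(q, k, v):
    # Model code stays unchanged: it always asks the registry for "flash".
    return ATTENTION_IMPLS["flash"](q, k, v)


with patched_attention("flash", tree_attention):
    during = model_forward(None, None, None)
after = model_forward(None, None, None)
```

The point of the pattern is that the model definition does not need to know about tree training; only the lookup it already performs is redirected.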

Logprob and Entropy Computation

Since tokens are packed into a tree, calculating log probabilities for RL (e.g., for PPO or GRPO) requires traversing the trie to associate logits with their corresponding sequence positions (areal/models/tree_attn/functional.py:3-7).

| Function | Role |
| --- | --- |
| _compute_internal_node_logprobs | Computes logprobs for tokens within a single contiguous trie node (areal/models/tree_attn/functional.py:26-34) |
| _compute_transition_logprob | Computes the logprob for the first token of a child node using the last logit of the parent (areal/models/tree_attn/functional.py:130-137) |
| _gather_packed_tree_logprobs | Orchestrates the full traversal to return a map of sequence_id to logprob tensors (areal/models/tree_attn/functional.py:206-213) |

Sequence Reconstruction Diagram


Sources: <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/functional.py#L26-L215" min=26 max=215 file-path="areal/models/tree_attn/functional.py">areal/models/tree_attn/functional.py:26-215</FileRef>, <FileRef file-url="https://github.com/inclusionAI/AReaL/blob/2e12c19b/areal/models/tree_attn/functional.py#L130-L164" min=130 max=164 file-path="areal/models/tree_attn/functional.py">areal/models/tree_attn/functional.py:130-164</FileRef>
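A minimal sketch of the gathering step, assuming a token-level packed layout with a `parent` index per token. In that layout the internal-node case and the transition case collapse into one rule: the logprob of a token is read from the logits at its parent position, which is the previous token inside a node or the last token of the parent node at a branch. The helper names `log_softmax` and `gather_logprobs` and the toy data are hypothetical and do not reproduce the functions listed above.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def gather_logprobs(logits, tokens, parent, positions):
    """Map each sequence id to the logprobs of its tokens after the first.

    logits[i] is the model output at packed position i, i.e. the prediction
    for whichever token follows position i on a path through the tree.
    """
    out = {}
    for seq_id, path in enumerate(positions):
        lps = []
        for idx in path[1:]:
            # The predictor is the parent position, shared by all branches below it.
            lps.append(log_softmax(logits[parent[idx]])[tokens[idx]])
        out[seq_id] = lps
    return out


# Two sequences, [0, 1] and [0, 2], sharing the one-token prefix [0]; vocabulary size 3.
tokens = [0, 1, 2]
parent = [-1, 0, 0]
positions = [[0, 1], [0, 2]]
logits = [
    [0.0, math.log(2.0), 0.0],  # at the shared prefix: probabilities 0.25, 0.5, 0.25
    [0.0, 0.0, 0.0],            # at leaf token 1 (unused here)
    [0.0, 0.0, 0.0],            # at leaf token 2 (unused here)
]
out = gather_logprobs(logits, tokens, parent, positions)
```

Both sequences read their second-token logprob from the same logits row at the shared prefix, which is what lets the prefix be evaluated once while still producing per-sequence logprobs.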