
URL: https://deepwiki.com/inclusionAI/AReaL/14.6-adding-custom-models-to-archonengine

Adding Custom Models to ArchonEngine | inclusionAI/AReaL | DeepWiki


Last indexed: 7 May 2026 (2e12c1)

Adding Custom Models to ArchonEngine

This document provides a step-by-step guide for extending ArchonEngine with support for new HuggingFace model architectures. It covers the model registration system, required components, implementation steps, and verification procedures.

For general ArchonEngine configuration and usage patterns, see ArchonEngine. For parallelism strategies supported by ArchonEngine, see Archon Parallelism.


Purpose and Scope

ArchonEngine is AReaL's PyTorch-native training backend. It favors flexibility by building directly on native distributed primitives such as DTensor, DeviceMesh, and FSDP2 (areal/experimental/models/archon/parallel_dims.py:27-30), and it uses a registry-based system for adding new transformer architectures without depending on Megatron-Core (areal/experimental/models/archon/model_spec.py:98-110).

This guide covers:

  • Understanding the model registration and discovery mechanism.
  • Implementing required components: args, model, state dict adapter, and parallelization.
  • Registering new model types in the global registry.
  • Integrating with Archon's 5D parallelism (DP, TP, PP, CP, EP).

Prerequisites: The target model must be available on HuggingFace with a config.json containing a model_type field, and should use a standard decoder-only transformer architecture.


Model Registration System

Discovery Flow

ArchonEngine uses a registry-based system to map HuggingFace model_type strings to Archon implementations. Each implementation is described by a ModelSpec object, and specs are stored in the global registry _MODEL_SPECS (areal/experimental/models/archon/model_spec.py:98-110).

  1. Lookup: The engine retrieves the ModelSpec for an HF model_type via get_model_spec (areal/experimental/models/archon/model_spec.py:113-118).
  2. Validation: is_supported_model checks whether a model_type has a registered spec (areal/experimental/models/archon/model_spec.py:121-123).
  3. Instantiation: The engine uses the classes defined in the spec (e.g., model_class, model_args_class) to build the local model instance.
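The lookup and validation steps above can be sketched with a plain dictionary keyed by model_type. This is a minimal illustration under stated assumptions, not Archon's code: _MODEL_SPECS, get_model_spec, and is_supported_model mirror the names cited above, but their bodies, the simplified ModelSpec, the register_model_spec helper, and the "my_model" types are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelSpec:
    # Simplified stand-in: only the fields needed to illustrate lookup.
    name: str
    supported_model_types: frozenset


# Global registry keyed by HF model_type (illustrative).
_MODEL_SPECS: dict[str, ModelSpec] = {}


def register_model_spec(spec: ModelSpec) -> None:
    """Hypothetical helper: register a spec under every model_type it supports."""
    for model_type in spec.supported_model_types:
        if model_type in _MODEL_SPECS:
            raise ValueError(f"model_type {model_type!r} is already registered")
        _MODEL_SPECS[model_type] = spec


def is_supported_model(model_type: str) -> bool:
    """Validation: a model is supported iff its model_type has a registered spec."""
    return model_type in _MODEL_SPECS


def get_model_spec(model_type: str) -> ModelSpec:
    """Lookup: return the spec for an HF model_type or fail with a clear error."""
    if not is_supported_model(model_type):
        supported = ", ".join(sorted(_MODEL_SPECS))
        raise ValueError(
            f"Unsupported model_type {model_type!r}; registered types: {supported}"
        )
    return _MODEL_SPECS[model_type]


# Registering a hypothetical new model under two HF model_type strings.
register_model_spec(
    ModelSpec(name="MyModel", supported_model_types=frozenset({"my_model", "my_model_moe"}))
)
```

Keying the registry by model_type rather than by model name lets one implementation serve several HF config variants, which is what the supported_model_types field expresses.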

[Diagram: Model Discovery and Entity Association]


Sources: areal/experimental/models/archon/model_spec.py:86-130

ModelSpec Structure

The ModelSpec dataclass defines all components needed for a model implementation (areal/experimental/models/archon/model_spec.py:86-96):

  • name (str): Human-readable model name.
  • model_class (type[nn.Module]): Model architecture class; must inherit BaseArchonModel.
  • model_args_class (type[BaseModelArgs]): Configuration/hyperparameters class.
  • state_dict_adapter_class (type[BaseStateDictAdapter]): Maps weight keys between HF and Archon formats (HF ↔ Archon).
  • parallelize_fn (ParallelizeFn): Function applying TP, CP, EP, AC, and FSDP.
  • supported_model_types (frozenset[str]): Set of HF model_type strings handled by this spec (e.g., qwen2).
  • pipelining_fn (PipeliningFn | None): Optional function that splits the model for pipeline parallelism.

Sources: areal/experimental/models/archon/model_spec.py:26-96, areal/experimental/models/archon/base.py:144-172
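To make the field list concrete, here is a hedged sketch of how a spec for a new model might be assembled. The placeholder base classes stand in for the real ones in base.py (the real BaseArchonModel is an nn.Module; these are not), the ParallelizeFn/PipeliningFn aliases are assumed callable shapes, and the __post_init__ checks and every MyModel* name are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


# Placeholders for the real base classes in areal/experimental/models/archon/base.py.
class BaseArchonModel:
    pass


class BaseModelArgs:
    pass


class BaseStateDictAdapter:
    pass


# Assumed callable shapes; the real ParallelizeFn / PipeliningFn signatures may differ.
ParallelizeFn = Callable[..., Any]
PipeliningFn = Callable[..., Any]


@dataclass(frozen=True)
class ModelSpec:
    name: str
    model_class: type
    model_args_class: type
    state_dict_adapter_class: type
    parallelize_fn: ParallelizeFn
    supported_model_types: frozenset
    pipelining_fn: Optional[PipeliningFn] = None  # optional: no PP splitting

    def __post_init__(self) -> None:
        # Hypothetical sanity checks that catch wiring mistakes at registration time.
        required_bases = (
            ("model_class", BaseArchonModel),
            ("model_args_class", BaseModelArgs),
            ("state_dict_adapter_class", BaseStateDictAdapter),
        )
        for field_name, base in required_bases:
            cls = getattr(self, field_name)
            if not (isinstance(cls, type) and issubclass(cls, base)):
                raise TypeError(f"{field_name} must be a subclass of {base.__name__}")
        if not self.supported_model_types:
            raise ValueError("supported_model_types must not be empty")


# Hypothetical components for a new architecture called MyModel.
class MyModelArgs(BaseModelArgs):
    pass


class MyModel(BaseArchonModel):
    pass


class MyModelStateDictAdapter(BaseStateDictAdapter):
    pass


def parallelize_my_model(model, *args, **kwargs):
    # A real implementation would apply TP, CP, EP, AC, compile, and FSDP here.
    return model


MY_MODEL_SPEC = ModelSpec(
    name="MyModel",
    model_class=MyModel,
    model_args_class=MyModelArgs,
    state_dict_adapter_class=MyModelStateDictAdapter,
    parallelize_fn=parallelize_my_model,
    supported_model_types=frozenset({"my_model"}),
)
```

A model that does not support pipeline parallelism can simply leave pipelining_fn as None, matching its optional type in the field list.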


Component Architecture

Directory Structure

Each model implementation typically follows this structure, as seen in the Qwen3 implementation:

areal/experimental/models/archon/qwen3/
  model/
    args.py          # Qwen3ModelArgs (BaseModelArgs)
    model.py         # Qwen3 (BaseArchonModel)
  infra/
    parallelize.py   # parallelize_qwen3 implementation

Component Dependencies

[Diagram: Code Entity Relationships]


Sources: areal/experimental/models/archon/base.py:27-172, areal/experimental/models/archon/model_spec.py:86-110


Implementation Process

Step 1: Model Arguments and Configuration

The model arguments class must implement from_hf_config, which converts a HuggingFace PretrainedConfig into Archon-native parameters (areal/experimental/models/archon/base.py:48-54). It typically also carries an attn_type field (e.g., "sdpa", "varlen", "tree") that selects which of AReaL's attention implementations is used for efficient RL training (areal/experimental/models/archon/base.py:30-31). For MoE models, an MoEArgs object is nested inside the model arguments (areal/experimental/models/archon/qwen3/model/args.py:17-53).
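A minimal sketch of the conversion, assuming a dataclass-based args class. The HF attribute names (hidden_size, num_key_value_heads, num_experts_per_tok, moe_intermediate_size, and so on) are real Qwen-family config attributes, but the Archon-side field names, the from_hf_config signature with an attn_type argument, and the MoEArgs layout are hypothetical. A types.SimpleNamespace with the same attributes stands in for PretrainedConfig so the conversion can be exercised without downloading a checkpoint.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional

# Attention types named in the text; validating them here is an assumption.
VALID_ATTN_TYPES = ("sdpa", "varlen", "tree")


@dataclass
class MoEArgs:
    num_experts: int
    top_k: int
    hidden_dim: int  # per-expert FFN width


@dataclass
class MyModelArgs:
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    head_dim: int
    hidden_dim: int
    vocab_size: int
    norm_eps: float
    rope_theta: float
    attn_type: str = "sdpa"
    moe_args: Optional[MoEArgs] = None

    @classmethod
    def from_hf_config(cls, hf_config: Any, attn_type: str = "sdpa") -> "MyModelArgs":
        """Translate HF config attribute names into (hypothetical) Archon field names."""
        if attn_type not in VALID_ATTN_TYPES:
            raise ValueError(f"attn_type must be one of {VALID_ATTN_TYPES}, got {attn_type!r}")
        n_heads = hf_config.num_attention_heads
        # Some configs omit head_dim; fall back to hidden_size // num_attention_heads.
        head_dim = getattr(hf_config, "head_dim", None) or hf_config.hidden_size // n_heads
        # Treat the config as MoE only if it declares a positive expert count.
        num_experts = getattr(hf_config, "num_experts", None) or 0
        moe_args = None
        if num_experts > 0:
            moe_args = MoEArgs(
                num_experts=num_experts,
                top_k=hf_config.num_experts_per_tok,
                hidden_dim=hf_config.moe_intermediate_size,
            )
        return cls(
            dim=hf_config.hidden_size,
            n_layers=hf_config.num_hidden_layers,
            n_heads=n_heads,
            # GQA configs set num_key_value_heads; plain MHA configs may not.
            n_kv_heads=getattr(hf_config, "num_key_value_heads", None) or n_heads,
            head_dim=head_dim,
            hidden_dim=hf_config.intermediate_size,
            vocab_size=hf_config.vocab_size,
            norm_eps=hf_config.rms_norm_eps,
            rope_theta=getattr(hf_config, "rope_theta", 10000.0),
            attn_type=attn_type,
            moe_args=moe_args,
        )


# Usage: a SimpleNamespace with HF attribute names stands in for PretrainedConfig.
example_config = SimpleNamespace(
    hidden_size=1024, num_hidden_layers=28, num_attention_heads=16,
    num_key_value_heads=8, head_dim=128, intermediate_size=3072,
    vocab_size=151936, rms_norm_eps=1e-6, rope_theta=1000000.0,
)
example_args = MyModelArgs.from_hf_config(example_config, attn_type="varlen")
```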

Step 2: Parallelization Logic

The parallelize_fn applies the distributed strategies to the model. For complex models such as Qwen3, the order in which they are applied is critical (areal/experimental/models/archon/qwen3/infra/parallelize.py:102-109):

  1. Dense TP: Apply Tensor Parallelism to attention, norms, and dense layers (areal/experimental/models/archon/qwen3/infra/parallelize.py:134-137).
  2. MoE EP+TP: Apply Expert Parallelism and MoE-specific TP (areal/experimental/models/archon/qwen3/infra/parallelize.py:141-150).
  3. Context Parallelism: Apply Ulysses SP for long-sequence support (areal/experimental/models/archon/qwen3/infra/parallelize.py:153-155).
  4. Activation Checkpointing (AC): Wrap modules to trade recomputation for memory (areal/experimental/models/archon/qwen3/infra/parallelize.py:158-164).
  5. Compilation: Apply torch.compile after AC but before FSDP (areal/experimental/models/archon/qwen3/infra/parallelize.py:167-169).
  6. FSDP: Apply Fully Sharded Data Parallelism as the final wrap (areal/experimental/models/archon/qwen3/infra/parallelize.py:172-185).
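The ordering above can be captured as a small planning function. This is a hedged sketch, not Archon's parallelize_qwen3: ParallelDims and TrainOptions here are hypothetical simplified stand-ins, the enabling conditions for each step are assumptions, and each returned string only marks where a real implementation would call the corresponding PyTorch API (parallelize_module, checkpoint wrapping, torch.compile, fully_shard).

```python
from dataclasses import dataclass


@dataclass
class ParallelDims:
    # Hypothetical, simplified degrees; Archon's real class lives in parallel_dims.py.
    dp_shard: int = 1
    tp: int = 1
    cp: int = 1
    ep: int = 1


@dataclass
class TrainOptions:
    # Hypothetical knobs standing in for the engine's AC and compile settings.
    is_moe: bool = False
    ac_mode: str = "none"  # e.g. "none", "full", "selective" (illustrative values)
    use_compile: bool = False


def plan_parallelization(dims: ParallelDims, opts: TrainOptions) -> list[str]:
    """Return, in application order, the steps a parallelize_fn would apply."""
    steps = []
    # 1. Dense TP on attention, norms, and dense layers.
    if dims.tp > 1:
        steps.append("dense_tp")
    # 2. MoE layers get their own EP (+ MoE-specific TP) plan.
    if opts.is_moe and (dims.ep > 1 or dims.tp > 1):
        steps.append("moe_ep_tp")
    # 3. Ulysses-style context parallelism for long sequences.
    if dims.cp > 1:
        steps.append("cp_ulysses")
    # 4. Activation checkpointing wraps the transformer blocks.
    if opts.ac_mode != "none":
        steps.append(f"ac_{opts.ac_mode}")
    # 5. torch.compile after AC, before FSDP.
    if opts.use_compile:
        steps.append("compile")
    # 6. FSDP last, sharding whatever parameters the earlier steps produced.
    if dims.dp_shard > 1:
        steps.append("fsdp")
    return steps
```

Applying FSDP last follows common FSDP2 practice in PyTorch-native stacks: fully_shard then shards parameters that TP has already converted into DTensors, giving combined FSDP and TP sharding instead of sharding the same weights twice in conflicting layouts.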

Step 3: MoE Expert Parallelism (EP)

For MoE models, Archon supports Expert Parallelism (EP) and Expert Tensor Parallelism (ETP) configurations, whose parallel dimensions are defined in areal/experimental/models/archon/parallel_dims.py:48-57.
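The bookkeeping behind EP and ETP comes down to integer arithmetic. This sketch assumes contiguous expert placement (EP rank r owns experts r·E/ep through (r+1)·E/ep − 1) and that ETP shards each expert's FFN hidden dimension; both are common conventions but assumptions here, since Archon's actual mesh layout is defined in parallel_dims.py. All function names are hypothetical.

```python
def experts_for_rank(num_experts: int, ep_degree: int, ep_rank: int) -> range:
    """Global expert indices owned by one EP rank under contiguous placement."""
    if num_experts % ep_degree != 0:
        raise ValueError(f"num_experts={num_experts} is not divisible by ep_degree={ep_degree}")
    if not 0 <= ep_rank < ep_degree:
        raise ValueError(f"ep_rank must be in [0, {ep_degree}), got {ep_rank}")
    per_rank = num_experts // ep_degree
    start = ep_rank * per_rank
    return range(start, start + per_rank)


def expert_owner(expert_id: int, num_experts: int, ep_degree: int) -> int:
    """EP rank a token routed to expert_id is sent to during the all-to-all dispatch."""
    return expert_id // (num_experts // ep_degree)


def local_expert_ffn_width(moe_hidden_dim: int, etp_degree: int) -> int:
    """Per-rank FFN width of one expert when ETP shards each expert's hidden dim."""
    if moe_hidden_dim % etp_degree != 0:
        raise ValueError(
            f"moe_hidden_dim={moe_hidden_dim} is not divisible by etp_degree={etp_degree}"
        )
    return moe_hidden_dim // etp_degree
```

For example, with 64 experts and an EP degree of 8, each rank holds 8 experts, and rank 3 owns experts 24 through 31. The divisibility checks are the kind of condition a validate_ep_constraints-style function would enforce before any weights are sharded.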

Step 4: Activation Checkpointing (AC)

Archon supports multiple AC modes, configured through ActivationCheckpointConfig (areal/experimental/models/archon/activation_checkpoint.py:37-41).

For selective activation checkpointing (SAC), implementers should provide an op_sac_save_list naming the operators (e.g., aten.mm.default, areal._varlen_attn.default) whose outputs are saved during the forward pass instead of being recomputed during backward (areal/experimental/models/archon/qwen3/infra/parallelize.py:59-83).

Sources: areal/experimental/models/archon/activation_checkpoint.py:37-84, areal/experimental/models/archon/qwen3/infra/parallelize.py:59-185, areal/experimental/models/archon/parallel_dims.py:48-65


Weight Conversion (StateDictAdapter)

The BaseStateDictAdapter maps between HuggingFace parameter names and the internal Archon model structure (areal/experimental/models/archon/base.py:57-73).

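Its key responsibility is converting parameter keys in both directions: HF → Archon when loading HuggingFace weights, and Archon → HF when exporting them.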
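A minimal sketch of key renaming in both directions. The HF keys are the standard names used by Qwen2-style HuggingFace checkpoints (a subset), but the Archon-side names, the class, and its from_hf/to_hf methods are hypothetical. A real adapter may also need to transform tensors rather than just rename them, for example when the Archon model fuses or stacks weights that HF stores separately; this sketch passes values through unchanged.

```python
import re
from typing import Any

# Matches the first ".<digits>." segment, i.e. the layer index in keys like
# "model.layers.3.self_attn.q_proj.weight" or "layers.3.attention.wq.weight".
_LAYER_INDEX = re.compile(r"\.(\d+)\.")


class MyModelStateDictAdapter:
    """Hypothetical HF <-> Archon key mapper; "{}" marks the layer index."""

    HF_TO_ARCHON = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }

    def __init__(self) -> None:
        self.archon_to_hf = {v: k for k, v in self.HF_TO_ARCHON.items()}

    @staticmethod
    def _convert_key(key: str, mapping: dict[str, str]) -> str:
        # Replace the layer index with "{}" to find the template, then restore it.
        match = _LAYER_INDEX.search(key)
        if match is None:
            template = key
        else:
            template = f"{key[:match.start()]}.{{}}.{key[match.end():]}"
        if template not in mapping:
            raise KeyError(f"No mapping for parameter {key!r}")
        target = mapping[template]
        return target if match is None else target.format(match.group(1))

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        return {self._convert_key(k, self.HF_TO_ARCHON): v for k, v in hf_state_dict.items()}

    def to_hf(self, archon_state_dict: dict[str, Any]) -> dict[str, Any]:
        return {self._convert_key(k, self.archon_to_hf): v for k, v in archon_state_dict.items()}
```

Raising on unknown keys rather than silently skipping them surfaces missing mappings at load time instead of as untrained weights later.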

Sources: areal/experimental/models/archon/base.py:57-142, areal/experimental/models/archon/qwen3_5/model/state_dict_adapter.py:21-136


Verification and Best Practices

  1. Parallelism Constraints: Always call validate_tp_constraints, validate_cp_constraints, and validate_ep_constraints at the start of your parallelize_fn (areal/experimental/models/archon/qwen3/infra/parallelize.py:37-41).
  2. Ulysses SP Support: In the model's Attention module, implement set_cp_group so that the module can perform the All-to-All communication that Ulysses sequence parallelism requires (areal/experimental/models/archon/qwen3/model/model.py:161-175).
  3. RMSNorm Precision: Compute the RMSNorm variance in float32 to keep training numerically stable (areal/experimental/models/archon/qwen3/model/model.py:75-96).
  4. Checkpoint Compatibility: Wrap model parts in DCPState for distributed checkpointing, and set flatten_optimizer_state_dict=True so that optimizer state from different pipeline stages does not collide under Pipeline Parallelism (areal/experimental/engine/archon_checkpoint.py:86-102).
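Item 3 can be illustrated with a small RMSNorm module that upcasts to float32 for the variance and normalization, then casts back. This is a sketch, not Archon's implementation: it mirrors the pattern used by HuggingFace's Qwen2 RMSNorm (cast back to the input dtype before multiplying by the weight), and whether Archon's version orders the cast the same way is an assumption. Matching the HF ordering matters when checking numerical parity against HF checkpoints.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Upcast so the mean of squares and rsqrt run in float32, even for bf16 inputs.
        x_fp32 = x.to(torch.float32)
        variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x_fp32 * torch.rsqrt(variance + self.eps)
        # Cast back to the input dtype before scaling by the learned weight.
        return self.weight * x_normed.to(input_dtype)
```

In bfloat16, squaring and averaging thousands of activations loses precision quickly, so computing the variance in float32 avoids noisy normalization at negligible cost.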

Sources: areal/experimental/models/archon/qwen3/infra/parallelize.py:37-41, areal/experimental/models/archon/qwen3/model/model.py:75-175, areal/experimental/engine/archon_checkpoint.py:86-166