# MergeDNA (reimplementation)
Faithful reimplementation of MergeDNA: Context-aware Genome Modeling with Dynamic Tokenization through Token Merging (Li et al., 2025, arXiv:2511.14806). The model is a hierarchical autoencoder for DNA that learns variable-length tokenization via differentiable token merging — the tokenizer and context model are trained jointly under three self-supervised objectives.
This release is the best-performing checkpoint from a single-GPU reimplementation. It is undertrained relative to the paper (≈0.24% of the paper's training tokens) and is intended for reproducibility studies and as a starting point for further pre-training or fine-tuning, not as a state-of-the-art model.
## Model details
| Property | Value |
|---|---|
| Parameters | ~380 M |
| d_model | 1024 |
| Attention heads | 16 |
| Layers | 4 Local Enc / 20 Latent Enc / 4 Latent Dec / 2 Local Dec |
| Local window size | 16 |
| Max sequence length | 2048 (the paper uses 4096; RoPE can be applied to longer inputs, but that is untested here) |
| Vocab | 4 (A, T, C, G; N collapses to A) |
| Building blocks | RMSNorm, RoPE, SwiGLU FFN, pre-norm (LLaMA-style) |
| Precision (training) | fp32 (this checkpoint); a faster bf16 + Triton variant exists in the source repo |
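
As a rough cross-check of the ~380 M figure, the sketch below counts only the transformer weights of the 30 layers in the table. It assumes LLaMA-style blocks (four bias-free d_model × d_model attention projections plus a SwiGLU FFN with hidden width ≈ 8/3 · d_model). The card does not state the FFN width, and norms, the 4-symbol embedding, and the merging projections are ignored.

```python
D_MODEL = 1024
N_LAYERS = 4 + 20 + 4 + 2      # local enc + latent enc + latent dec + local dec
D_FF = int(8 * D_MODEL / 3)    # assumed SwiGLU hidden width (8/3 ratio); not stated in the card


def layer_params(d_model: int, d_ff: int) -> int:
    attn = 4 * d_model * d_model   # Q, K, V, O projections, no biases
    ffn = 3 * d_model * d_ff       # SwiGLU gate, up, and down projections
    return attn + ffn


total = N_LAYERS * layer_params(D_MODEL, D_FF)
print(f"{total / 1e6:.0f} M")  # → 377 M
```

Under these assumptions the layers alone account for roughly 377 M parameters, consistent with the ~380 M in the table.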
The novel architectural pieces are:
- Differentiable token merging inside local-window attention. Within each window, even/odd bipartite matching plus a DTEM-style decoupled grouping projection produces similarity scores; adjacent tokens with highest similarity merge by averaging. A source matrix tracks merge lineage so the unmerge step is exact.
- Three pre-training losses, trained jointly:
  - MTR: full-autoencoder reconstruction. It collapses to ~0 early via a trivial shortcut; this is expected behavior.
  - λ · MTR(θ\{φ}): second-stage reconstruction with the tokenizer parameters φ detached and a tighter K = L/2 latent bottleneck, which forces the latent stack to do real work. λ = 0.25.
  - AMTM: adaptive masked-token modeling, biased toward tokens in small merge groups ("important" tokens).
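
The merge step inside one local window can be sketched as follows. This is a simplified illustration, not the source repo's implementation: it pairs each even position with its right-hand odd neighbour, scores pairs by cosine similarity of the raw features (standing in for the DTEM-style grouping projection), merges the top-r pairs by plain averaging, and records lineage in a source matrix `S` so that `unmerge` can copy each merged token back to exactly the positions it absorbed. `merge_window` and `unmerge` are hypothetical names.

```python
import numpy as np


def merge_window(x: np.ndarray, r: int):
    """Merge the r most similar (even, odd) neighbour pairs in one window.

    x: (L, D) features for a single window, L even.
    Returns merged features (L - r, D) and source matrix S (L - r, L),
    where S[k, j] = 1 if original token j was merged into output token k.
    """
    L = x.shape[0]
    assert L % 2 == 0 and 0 <= r <= L // 2
    even, odd = x[0::2], x[1::2]           # pair i = tokens (2i, 2i + 1)
    en = even / (np.linalg.norm(even, axis=1, keepdims=True) + 1e-8)
    on = odd / (np.linalg.norm(odd, axis=1, keepdims=True) + 1e-8)
    sim = (en * on).sum(axis=1)            # cosine similarity per pair
    chosen = set(np.argsort(-sim)[:r].tolist())

    groups = []                            # kept in sequence order
    for i in range(L // 2):
        if i in chosen:
            groups.append([2 * i, 2 * i + 1])
        else:
            groups.extend([[2 * i], [2 * i + 1]])

    S = np.zeros((len(groups), L))
    for k, members in enumerate(groups):
        S[k, members] = 1.0
    merged = (S @ x) / S.sum(axis=1, keepdims=True)   # plain average per group
    return merged, S


def unmerge(merged: np.ndarray, S: np.ndarray) -> np.ndarray:
    # Each original position receives the features of the group it belongs to.
    return S.T @ merged
```

With the card's window size of 16 and, say, r = 4, each window shrinks from 16 to 12 tokens; because every column of `S` contains a single 1, `S.T @ merged` restores the original length.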
## Intended use
- Feature extraction for DNA classification tasks (frozen-backbone embeddings via forward_classify).
- LoRA fine-tuning on task-specific genomic data (transcription factor binding, splice sites, promoters, etc.); see the protocol below.
- Continued pre-training as a starting point for domain adaptation to a specific organism, tissue, or task family.
Not intended for clinical decision-making, variant interpretation, or any use where errors carry health consequences.
## How to use
```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# from your local copy of the source repo
from MergeDNA import MergeDNA, MergeDNAConfig, encode_dna_sequence

REPO = "Sharmistha-NLP/mergedna-400M"
weights_path = hf_hub_download(REPO, "model.safetensors")

# Must match the architecture in the Model details table
cfg = MergeDNAConfig(
    d_model=1024, n_heads=16,
    local_enc_layers=4, latent_enc_layers=20,
    latent_dec_layers=4, local_dec_layers=2,
    window_size=16, max_seq_len=2048,
)
model = MergeDNA(cfg)
model.load_state_dict(load_file(weights_path), strict=True)
model.eval()

# Frozen-backbone feature extraction
seq = "ACGT" * 256  # 1024 nt example
ids = encode_dna_sequence(seq).unsqueeze(0)  # [1, 1024]
with torch.no_grad():
    features = model.forward_classify(ids)  # mean-pooled latent embedding, shape (1, D)
```
To resume pre-training, download ckpt_best.pt instead; it carries the full optimizer, scheduler, and RNG state.
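
The checkpoint was trained on sequences of at most 2048 nt. One way to embed a longer sequence, which this card has not evaluated, is to split it into overlapping windows of at most 2048 nt, run `forward_classify` on each, and average the results. A minimal splitter, where `split_windows` and its default stride are illustrative choices rather than part of the source repo:

```python
def split_windows(seq: str, max_len: int = 2048, stride: int = 1024) -> list[str]:
    """Split seq into windows of at most max_len nt, one starting every `stride` nt.

    The final window is aligned to the end of seq so no bases are dropped.
    """
    if len(seq) <= max_len:
        return [seq]
    starts = list(range(0, len(seq) - max_len + 1, stride))
    if starts[-1] + max_len < len(seq):
        starts.append(len(seq) - max_len)  # cover the tail
    return [seq[s:s + max_len] for s in starts]


windows = split_windows("ACGT" * 1500)  # 6,000 nt
print(len(windows), {len(w) for w in windows})  # → 5 {2048}
```

Each window can then be encoded with `encode_dna_sequence(w).unsqueeze(0)` and passed to `model.forward_classify`, and the per-window `(1, D)` features averaged, e.g. `torch.cat(feats).mean(0, keepdim=True)`.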
### Sandboxed sanity test
A self-contained script that downloads the weights and runs them on a tiny in-memory random-DNA dataset — no external data needed. Useful for verifying the upload round-tripped before plugging the model into a real pipeline. Save as HF/HF-test/test_hf_model.py inside a clone of the source repo and run:
```bash
uv run python HF/HF-test/test_hf_model.py --repo-id Sharmistha-NLP/mergedna-400M
```
The val loss will be high — random ACGT is out-of-distribution. The point is to confirm weights load, the forward pass produces finite values, and forward_classify returns the expected (B, D) tensor.
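
The pass criteria above reduce to a few assertions. A minimal sketch, assuming the model output has been converted with `features.cpu().numpy()`; `random_dna` and `check_features` are hypothetical helpers, not functions from the source repo's test script.

```python
import numpy as np


def random_dna(n: int, rng: np.random.Generator) -> str:
    # Uniform random ACGT: out-of-distribution for the model, so high loss is expected.
    return "".join(rng.choice(list("ACGT"), size=n))


def check_features(features, batch_size: int, d_model: int = 1024) -> bool:
    """Pass criteria: shape is (B, D) and every value is finite."""
    features = np.asarray(features)
    assert features.shape == (batch_size, d_model), f"unexpected shape {features.shape}"
    assert np.isfinite(features).all(), "non-finite values in features"
    return True
```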
## Training
### Data
- Corpus: Multi-Species Genomes (NCBI RefSeq), 849 species, ~174 B nucleotides total.
- Splits: species-disjoint — 749 species train / 50 validation / 50 test. Train is bacteria-heavy; val/test are weighted toward larger eukaryote genomes.
- Chunking: 6,200 bp chunks (6 kbp config) with 100 bp overlap on each side, randomly cropped at training time so each epoch sees a different window.
- Tokens seen during this training run: ~256 M (vs the paper's ~105 B — this is a major caveat).
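
Each stored chunk is the 6 kbp core plus 100 bp of overlap on each side (6,000 + 2 × 100 = 6,200 bp). A sketch of the per-epoch random crop down to the 2,048-nt training length; the uniform crop offset is an assumption, since the card does not specify how the window is sampled, and `random_crop` is a hypothetical name.

```python
import random
from typing import Optional

CHUNK_CORE = 6_000
OVERLAP = 100
CHUNK_LEN = CHUNK_CORE + 2 * OVERLAP  # 6,200 bp stored per chunk
SEQ_LEN = 2_048                       # training sequence length


def random_crop(chunk: str, seq_len: int = SEQ_LEN, rng: Optional[random.Random] = None) -> str:
    """Return a uniformly random seq_len-long window of chunk (a new one each call)."""
    rng = rng if rng is not None else random.Random()
    if len(chunk) <= seq_len:
        return chunk
    start = rng.randint(0, len(chunk) - seq_len)  # randint is inclusive on both ends
    return chunk[start:start + seq_len]
```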
### Hyperparameters
- Optimizer: AdamW, β = (0.9, 0.95), weight decay 0.1
- LR: base 1e-4, cosine schedule with 10K warmup steps
- Batch size: 4 sequences × 2048 tokens = 8,192 tokens/step
- Steps: 32,000 (the run was terminated early by a dataset URL returning 404; the remaining 18K of the planned 50K steps were not run)
- Hardware: single RTX PRO 6000 Blackwell, 96 GB
- Wall-clock: ~30 hours
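
The step count and batch size above imply the token budget (32,000 × 8,192 ≈ 262 M, in line with the ~256 M reported), and the warmup plus cosine settings imply the learning rate at each step. A sketch under stated assumptions: linear warmup, cosine decay to zero, and a cosine horizon equal to the planned 50K steps (the card does not say whether the schedule was set for 50K or 32K steps). `lr_at` is a hypothetical name.

```python
import math

BASE_LR = 1e-4
WARMUP = 10_000
PLANNED_STEPS = 50_000         # assumption: cosine horizon = planned run length
TOKENS_PER_STEP = 4 * 2_048    # 4 sequences × 2,048 tokens


def lr_at(step: int, base_lr: float = BASE_LR, warmup: int = WARMUP,
          total: int = PLANNED_STEPS, min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr, then cosine decay to min_lr at `total` steps."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


tokens_seen = 32_000 * TOKENS_PER_STEP
print(tokens_seen)              # → 262144000
print(f"{lr_at(31_999):.2e}")   # LR at the last step run, under these assumptions
```

In PyTorch these settings would map onto `torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.95), weight_decay=0.1)` wrapped in `torch.optim.lr_scheduler.LambdaLR(opt, lambda s: lr_at(s) / BASE_LR)`.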
### Training trajectory
- Steps 0–1K: MTR collapses to ~0.01 (the trivial autoencoder shortcut; this is diagnostic of the loss design, not a training failure).
- Steps 1K–10K: Latent MTR and AMTM begin to fall (1.39 → ~1.2).
- Steps 10K–32K: Latent MTR keeps falling, reaching 0.28 at step 31,999 (the last step run). The linear-probe and LoRA results below indicate that the latent stack is learning useful representations.
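
The trajectory above tracks the three losses separately; each step optimizes their combination. A sketch assuming a plain sum in which only the detached-tokenizer MTR term is weighted by λ = 0.25. The AMTM masking distribution shown (selection probability ∝ 1 / merge-group size) is an illustrative assumption about how a small-merge-group bias could be implemented, not the paper's exact rule; all function names are hypothetical.

```python
import numpy as np

LAMBDA = 0.25  # weight on λ · MTR(θ\{φ}), from the card


def total_loss(mtr_full: float, mtr_latent: float, amtm: float, lam: float = LAMBDA) -> float:
    # Assumption: unit weights on MTR and AMTM; only the detached-tokenizer term is scaled.
    return mtr_full + lam * mtr_latent + amtm


def amtm_mask_probs(group_sizes) -> np.ndarray:
    """Hypothetical masking distribution: smaller merge groups are more likely to be masked."""
    w = 1.0 / np.asarray(group_sizes, dtype=float)
    return w / w.sum()


def sample_mask(group_sizes, n_mask: int, rng: np.random.Generator) -> np.ndarray:
    # Draw n_mask distinct merged-token indices according to amtm_mask_probs.
    p = amtm_mask_probs(group_sizes)
    return rng.choice(len(p), size=n_mask, replace=False, p=p)
```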
## Evaluation
Evaluated on the GUE Mouse TF-3 task (DNABERT-2 benchmark; smallest task: 1,904 train / 239 test).

| Setup | val MCC (×100) | test MCC (×100) |
|---|---|---|
| Random baseline | 0 | 0 |
| Linear probe (frozen, sklearn LogisticRegression on mean-pooled features) | ~22 | — |
| LoRA fine-tune (rank 8, α 16, 10 epochs, lr 1e-4) | 69.46 (epoch 9) | 56.34 |
| Paper Mouse TF-3 (full 100K-step backbone) | — | 73.46 |
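The LoRA val MCC (69.46) is about 4 points below the paper's reported Mouse TF-3 test MCC (73.46), despite this checkpoint seeing ~400× less pre-training data. Note that this compares val against test; on the test split the gap is about 17 points (56.34 vs 73.46). The val/test gap here is likely dominated by variance on the 239-sample test split and by mild overfitting by epoch 10.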
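
MCC values in the table are scaled by 100. For a binary task such as Mouse TF-3, MCC = (TP·TN − FP·FN) / √((TP+FP)(TP+FN)(TN+FP)(TN+FN)). A dependency-free sketch that, like scikit-learn's `matthews_corrcoef`, returns 0 when the denominator is zero:

```python
import math


def mcc(y_true, y_pred):
    """Binary Matthews correlation coefficient in [-1, 1]; 0.0 when undefined."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


print(round(100 * mcc([1, 1, 0, 0], [1, 0, 0, 0]), 2))  # table-style (×100) score
```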
Full GUE benchmark (36 tasks): not run. Estimated 6 GPU-days at this checkpoint's throughput. See REPORT.md in the source repo for the tradeoff discussion.
### Reproducing the LoRA result
```bash
uv run --extra data --extra eval python code/lora_finetune.py \
  --checkpoint ckpt_best.pt \
  --task-dir data/GUE/mouse/3 \
  --epochs 10 --max-len 256 --batch-size 32 \
  --rank 8 --lora-alpha 16 --lr 1e-4 \
  --head-hidden 256
```
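
With rank 8 and α 16, the low-rank update is scaled by α / r = 16 / 8 = 2. A NumPy sketch of a LoRA-wrapped linear layer in the standard formulation (frozen W, trainable A and B, with B zero-initialized so fine-tuning starts from the pre-trained output). Which layers `lora_finetune.py` wraps and how it initializes A are not stated here, and `LoRALinear` is a hypothetical name.

```python
from typing import Optional

import numpy as np


class LoRALinear:
    """y = x W^T + (alpha / rank) * x A^T B^T, with W frozen and A, B trainable."""

    def __init__(self, weight: np.ndarray, rank: int = 8, alpha: float = 16,
                 rng: Optional[np.random.Generator] = None):
        if rng is None:
            rng = np.random.default_rng()
        d_out, d_in = weight.shape
        self.weight = weight                               # frozen pre-trained weight
        self.A = rng.normal(0.0, 0.01, size=(rank, d_in))  # down-projection (trainable)
        self.B = np.zeros((d_out, rank))                   # up-projection, zero-init (trainable)
        self.scale = alpha / rank                          # 16 / 8 = 2.0

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x @ self.weight.T + self.scale * (x @ self.A.T) @ self.B.T
```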
## Limitations and known issues
- Undertrained. ~256 M tokens vs paper's ~105 B (≈0.24%). Expect a gap on most GUE tasks vs the published numbers.
- Max sequence length 2048 at training; RoPE permits longer inputs at inference, but this has not been validated.
- N nucleotides become A. The byte-LUT tokenizer maps anything not A/T/C/G to 0 (= A). Assembly gaps and ambiguous bases are silently lost; downstream interpretation on N-rich regions should be done carefully.
- Bacteria-heavy training distribution. 667 of 749 training species are bacteria. Performance on plant/viral genomes (excluded from training) is unknown; on large eukaryotic genomes it is plausible but undertrained.
- No safety or bias evaluation. This is a pre-training base; downstream classifiers must do their own validation.
- Single-GPU reimplementation. Architecture follows the paper but optimization details, data ordering, and tokenization quirks may differ in subtle ways from the original release.
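
The N-to-A behaviour described above follows from a 256-entry byte lookup table. A sketch that reproduces it, plus a helper to flag N-rich inputs before encoding. The index order A=0, T=1, C=2, G=3 (taken from the vocab order in Model details) and the treatment of lowercase bases are assumptions, and `encode` / `non_acgt_fraction` are hypothetical names.

```python
import numpy as np

# Assumption: index order A=0, T=1, C=2, G=3 (vocab order in Model details).
# Every other byte, including N and lowercase letters, stays 0 (= A).
LUT = np.zeros(256, dtype=np.int64)
for idx, base in enumerate("ATCG"):
    LUT[ord(base)] = idx


def encode(seq: str) -> np.ndarray:
    return LUT[np.frombuffer(seq.encode("ascii"), dtype=np.uint8)]


def non_acgt_fraction(seq: str) -> float:
    """Fraction of positions that the LUT silently turns into A (excluding real A's)."""
    if not seq:
        return 0.0
    b = np.frombuffer(seq.encode("ascii"), dtype=np.uint8)
    return float(np.mean(~np.isin(b, np.frombuffer(b"ACGT", dtype=np.uint8))))
```

If inputs may be soft-masked (lowercase), upper-case them first; whether the source repo's tokenizer does so is not stated here.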
## Files in this repo
| File | Size (approx) | Purpose |
|---|---|---|
| model.safetensors | ~1.5 GB | Weights only. Use for inference and fine-tuning. |
| ckpt_best.pt | ~4.5 GB | Full training state (model + optimizer + scheduler + RNG + step + cfg). Use for resuming pre-training. |
| best.json | <1 KB | Per-checkpoint validation loss history. |
| README.md | — | This card. |
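
The file sizes follow from the parameter count. A back-of-the-envelope check, assuming fp32 storage (4 bytes per parameter, as in Model details) and AdamW's two per-parameter moment tensors in the full checkpoint:

```python
N_PARAMS = 380e6   # ~380 M parameters (Model details)
BYTES_FP32 = 4

weights_gb = N_PARAMS * BYTES_FP32 / 1e9   # model.safetensors
# Assumption: AdamW stores two fp32 moments per parameter (exp_avg, exp_avg_sq);
# scheduler, RNG, step, and cfg entries are negligible in size.
full_state_gb = weights_gb * 3             # weights + 2 optimizer moments
print(round(weights_gb, 2), round(full_state_gb, 2))  # → 1.52 4.56
```

This matches the ~1.5 GB and ~4.5 GB listed in the table.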
## Citation
If you use this checkpoint, please cite the original paper:
```bibtex
@article{li2025mergedna,
  title   = {MergeDNA: Context-aware Genome Modeling with Dynamic Tokenization through Token Merging},
  author  = {Li and others},
  journal = {arXiv preprint arXiv:2511.14806},
  year    = {2025}
}
```
And acknowledge this reimplementation:
```bibtex
@misc{jat2026mergedna_reimpl,
  title        = {MergeDNA reimplementation (single-GPU)},
  author       = {Jat, Sharmistha},
  year         = {2026},
  howpublished = {Hugging Face model repository},
  url          = {https://huggingface.co/Sharmistha-NLP/mergedna-400M}
}
```
## License
Apache 2.0. Data is sourced from NCBI RefSeq (public domain) via the InstaDeepAI multi-species genomes dataset.
