URL: https://huggingface.co/Sharmistha-NLP/mergedna-400M

MergeDNA (reimplementation)

Faithful reimplementation of MergeDNA: Context-aware Genome Modeling with Dynamic Tokenization through Token Merging (Li et al., 2025, arXiv:2511.14806). The model is a hierarchical autoencoder for DNA that learns variable-length tokenization via differentiable token merging — the tokenizer and context model are trained jointly under three self-supervised objectives.

This release is the best-performing checkpoint from a single-GPU reimplementation. It is undertrained relative to the paper (≈0.24% of the paper's training tokens) and is intended for reproducibility studies and as a starting point for further pre-training or fine-tuning, not as a state-of-the-art model.

Model details

| Property | Value |
| --- | --- |
| Parameters | ~380 M |
| d_model | 1024 |
| Attention heads | 16 |
| Layers | 4 Local Enc / 20 Latent Enc / 4 Latent Dec / 2 Local Dec |
| Local window size | 16 |
| Max sequence length | 2048 (paper uses 4096; RoPE extrapolates but untested here) |
| Vocab | 4 (A, T, C, G; N collapses to A) |
| Building blocks | RMSNorm, RoPE, SwiGLU FFN, pre-norm (LLaMA-style) |
| Precision (training) | fp32 (this checkpoint); a faster bf16 + Triton variant exists, see source repo |

The novel architectural pieces are:

  1. Differentiable token merging inside local-window attention. Within each window, even/odd bipartite matching plus a DTEM-style decoupled grouping projection produces similarity scores; adjacent tokens with highest similarity merge by averaging. A source matrix tracks merge lineage so the unmerge step is exact.
  2. Three pre-training losses trained jointly:
    • MTR: full-autoencoder reconstruction (collapses to ~0 early via a trivial shortcut — expected behavior).
    • λ · MTR(θ\{φ}): second-stage reconstruction with the tokenizer detached and a tighter K=L/2 latent bottleneck, which forces the latent stack to do real work. λ = 0.25.
    • AMTM: adaptive masked-token modeling biased toward small-merge-group ("important") tokens.
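The merge, source-matrix, and unmerge mechanics from item 1, and the small-group bias behind AMTM, can be illustrated with a deliberately simplified sketch. It is not the reimplementation's code. It uses a hard top-r selection instead of a differentiable one and scores pairs by cosine similarity on the raw token vectors in place of the learned grouping projection. The names `merge_window`, `unmerge`, and `mask_weights`, and the 1/group-size weighting, are assumptions made for illustration.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + 1e-8)

def merge_window(tokens, r):
    """Merge the r most similar adjacent (even, odd) pairs of one window.

    Returns (merged, source), where source[j][i] == 1 iff input token i
    was folded into output token j.
    """
    n = len(tokens)
    pairs = [(i, i + 1) for i in range(0, n - 1, 2)]
    scores = [cosine(tokens[a], tokens[b]) for a, b in pairs]
    ranked = sorted(range(len(pairs)), key=lambda k: scores[k], reverse=True)
    chosen = {pairs[k][0] for k in ranked[:r]}  # even index of each merged pair

    merged, source = [], []
    i = 0
    while i < n:
        group = [i, i + 1] if i in chosen else [i]
        dim = len(tokens[i])
        # Merge by averaging the members of the group.
        merged.append([sum(tokens[g][d] for g in group) / len(group) for d in range(dim)])
        source.append([1 if k in group else 0 for k in range(n)])
        i += len(group)
    return merged, source

def unmerge(merged, source):
    """Copy every merged vector back to each input position it came from."""
    out = [None] * len(source[0])
    for j, row in enumerate(source):
        for i, flag in enumerate(row):
            if flag:
                out[i] = list(merged[j])
    return out

def mask_weights(source):
    """Assumed AMTM-style bias: sampling weight proportional to 1 / group size."""
    w = [0.0] * len(source[0])
    for row in source:
        size = sum(row)
        for i, flag in enumerate(row):
            if flag:
                w[i] = 1.0 / size
    total = sum(w)
    return [x / total for x in w]

tokens = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
merged, source = merge_window(tokens, r=1)   # 4 tokens in, 3 tokens out
restored = unmerge(merged, source)           # back to 4 positions
weights = mask_weights(source)               # unmerged tokens get more mass
```

Because the source matrix records exactly which input positions fed each output token, unmerging restores the original sequence length and ordering; the values at merged positions are the group average, not the original vectors.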

Intended use

  • Feature extraction for DNA classification tasks (frozen-backbone embeddings via forward_classify).
  • LoRA fine-tuning on task-specific genomic data (transcription factor binding, splice sites, promoters, etc.) — protocol below.
  • Continued pre-training as a starting point for domain adaptation to a specific organism, tissue, or task family.

Not intended for clinical decision-making, variant interpretation, or any use where errors carry health consequences.

How to use

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# from your local copy of the source repo
from MergeDNA import MergeDNA, MergeDNAConfig, encode_dna_sequence

REPO = "Sharmistha-NLP/mergedna-400M"

# Download the weights-only file and rebuild the model with the released config
weights_path = hf_hub_download(REPO, "model.safetensors")
cfg = MergeDNAConfig(
    d_model=1024, n_heads=16,
    local_enc_layers=4, latent_enc_layers=20,
    latent_dec_layers=4, local_dec_layers=2,
    window_size=16, max_seq_len=2048,
)
model = MergeDNA(cfg)
model.load_state_dict(load_file(weights_path), strict=True)
model.eval()

# Frozen-backbone feature extraction
seq = "ACGT" * 256  # 1024 nt example
ids = encode_dna_sequence(seq).unsqueeze(0)  # [1, 1024]
with torch.no_grad():
    features = model.forward_classify(ids)  # mean-pooled latent embedding

To resume pre-training, download ckpt_best.pt instead; it carries the full optimizer, scheduler, and RNG state.

Sandboxed sanity test

A self-contained script that downloads the weights and runs them on a tiny in-memory random-DNA dataset — no external data needed. Useful for verifying the upload round-tripped before plugging the model into a real pipeline. Save as HF/HF-test/test_hf_model.py inside a clone of the source repo and run:

uv run python HF/HF-test/test_hf_model.py --repo-id Sharmistha-NLP/mergedna-400M

The val loss will be high — random ACGT is out-of-distribution. The point is to confirm weights load, the forward pass produces finite values, and forward_classify returns the expected (B, D) tensor.
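The checks described above can be sketched without the model itself. The helpers below are hypothetical (they are not the contents of test_hf_model.py): one builds a tiny random-DNA dataset in memory, the other verifies that a nested list, for example the result of calling `.tolist()` on the `forward_classify` output, has shape (B, D) and contains only finite values.

```python
import math
import random

def random_dna(n_seqs, length, seed=0):
    """Build an in-memory list of random ACGT strings (out-of-distribution by design)."""
    rng = random.Random(seed)
    return ["".join(rng.choice("ACGT") for _ in range(length)) for _ in range(n_seqs)]

def check_features(rows, batch, dim):
    """Raise ValueError unless rows is a finite (batch, dim) nested list."""
    if len(rows) != batch:
        raise ValueError(f"expected {batch} rows, got {len(rows)}")
    for row in rows:
        if len(row) != dim:
            raise ValueError(f"expected {dim} features per row, got {len(row)}")
        if not all(math.isfinite(v) for v in row):
            raise ValueError("non-finite value in features")
    return True

seqs = random_dna(n_seqs=4, length=64)
# Stand-in for features.tolist(); with the real model D would be the embedding width.
fake_features = [[0.0] * 8 for _ in seqs]
ok = check_features(fake_features, batch=4, dim=8)
```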

Training

Data

  • Corpus: Multi-Species Genomes (NCBI RefSeq), 849 species, ~174 B nucleotides total.
  • Splits: species-disjoint — 749 species train / 50 validation / 50 test. Train is bacteria-heavy; val/test are weighted toward larger eukaryote genomes.
  • Chunking: 6,200 bp chunks (6 kbp config) with 100 bp overlap on each side, randomly cropped at training time so each epoch sees a different window.
  • Tokens seen during this training run: ~256 M (vs the paper's ~105 B — this is a major caveat).
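A minimal sketch of the chunk-then-crop scheme, under one reading of the description: a 6,000 bp core with a 100 bp flank on each side gives 6,200 bp chunks, and a random window is cut from each chunk at training time. The function names and the 2,048 nt crop length (taken from the model's max sequence length) are assumptions; the source repo's data loader may differ.

```python
import random

def chunk_genome(seq, core=6000, flank=100):
    """Split seq into core-sized chunks, each extended by flank bp on both sides."""
    chunks = []
    for start in range(0, len(seq), core):
        lo = max(0, start - flank)                  # clipped at the sequence start
        hi = min(len(seq), start + core + flank)    # clipped at the sequence end
        chunks.append(seq[lo:hi])
    return chunks

def random_crop(chunk, crop_len=2048, rng=random):
    """Pick a random window so repeated epochs see different offsets."""
    if len(chunk) <= crop_len:
        return chunk
    start = rng.randrange(len(chunk) - crop_len + 1)
    return chunk[start:start + crop_len]

genome = "ACGT" * 5000                              # 20,000 nt toy sequence
chunks = chunk_genome(genome)
window = random_crop(chunks[1], rng=random.Random(0))
```

Interior chunks come out at the full 6,200 bp; the first and last chunks are shorter because their flanks are clipped at the sequence boundaries.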

Hyperparameters

  • Optimizer: AdamW, β = (0.9, 0.95), weight decay 0.1
  • LR: base 1e-4, cosine schedule with 10K warmup steps
  • Batch size: 4 sequences × 2048 tokens = 8,192 tokens/step
  • Steps: 32,000 (run was terminated early due to a dataset URL 404; the next 18K of the planned 50K were not executed)
  • Hardware: single RTX PRO 6000 Blackwell, 96 GB
  • Wall-clock: ~30 hours
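The schedule and batch arithmetic above can be written out as a short sketch. The card only states a cosine schedule with 10K warmup and a planned 50K steps; linear warmup, decay to zero, and the function name `lr_at` are assumptions here.

```python
import math

def lr_at(step, base_lr=1e-4, warmup=10_000, total=50_000, min_lr=0.0):
    """Assumed shape: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup:
        return base_lr * step / warmup
    progress = min(1.0, (step - warmup) / (total - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

tokens_per_step = 4 * 2048                  # 8,192, as listed under batch size
tokens_seen = tokens_per_step * 32_000      # steps actually executed
lr_when_stopped = lr_at(32_000)             # still mid-decay under these assumptions
```

Under these assumptions the run stopped with the learning rate at a little over 40% of its peak, so the checkpoint never saw the low-LR tail of the schedule. The step-count product comes to about 262 M tokens, in the same range as the ~256 M quoted under Data.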

Training trajectory

  • Steps 0–1K: MTR collapses to ~0.01 (trivial autoencoder shortcut — diagnostic of the loss, not a failure).
  • Steps 1K–10K: Latent MTR and AMTM begin moving (1.39 → ~1.2).
  • Steps 10K–32K: Latent MTR drops to 0.28 at step 31,999. Representation learning is real (verified by the linear probe and LoRA results below).
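The starting value of 1.39 has a simple reading if the losses are mean per-token cross-entropy in nats over the four-letter vocabulary (an assumption; the card does not state the units): it is the loss of a uniform guess, ln 4. The short calculation below converts the reported values to perplexity for intuition.

```python
import math

uniform_ce = math.log(4)                     # uniform guess over A, T, C, G
final_latent_mtr = 0.28                      # reported at step 31,999

perplexity_start = math.exp(uniform_ce)      # 4 equally likely bases
perplexity_end = math.exp(final_latent_mtr)  # effective choices per base at the end
```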

Evaluation

Verified on the GUE Mouse TF-3 task (DNABERT-2 benchmark, smallest task: 1,904 train / 239 test).

Setup val MCC test MCC
Random baseline 0 0
Linear probe (frozen, sklearn LogisticRegression on mean-pooled features) ~22
LoRA fine-tune (rank 8, α 16, 10 epochs, lr 1e-4) 69.46 (epoch 9) 56.34
Paper Mouse TF-3 (full 100K-step backbone) 73.46

The LoRA val MCC is within ~4 percentage points of the paper despite this checkpoint seeing ~400× less data. The test/val gap is dominated by the 239-sample test split's variance and mild overfitting by epoch 10.
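For reference, the Matthews correlation coefficient used above can be computed from binary labels as in the sketch below. The table values appear to be MCC scaled by 100 (an assumption based on their magnitude, since MCC itself lies in [-1, 1]); the helper name `mcc` is illustrative and not taken from the source repo.

```python
import math

def mcc(y_true, y_pred):
    """Matthews correlation coefficient for binary labels in {0, 1}."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Convention: return 0 when a row or column of the confusion matrix is empty.
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom

perfect = 100 * mcc([1, 1, 0, 0], [1, 1, 0, 0])   # all correct
chance = 100 * mcc([1, 0, 1, 0], [1, 1, 0, 0])    # no better than guessing
gap_to_paper = 73.46 - 69.46                      # val MCC gap quoted above
```

With only 239 test examples, a handful of flipped predictions moves the score by several points, which is why a single-split test number should be read with wide error bars.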

Full GUE benchmark (36 tasks): not run. Estimated 6 GPU-days at this checkpoint's throughput. See REPORT.md in the source repo for the tradeoff discussion.

Reproducing the LoRA result

uv run --extra data --extra eval python code/lora_finetune.py \
 --checkpoint ckpt_best.pt \
 --task-dir data/GUE/mouse/3 \
 --epochs 10 --max-len 256 --batch-size 32 \
 --rank 8 --lora-alpha 16 --lr 1e-4 \
 --head-hidden 256
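What `--rank 8 --lora-alpha 16` means for a single linear layer can be sketched in plain Python. This is the generic LoRA update, y = Wx + (alpha / rank) · B(Ax), not the contents of lora_finetune.py. Which layers receive adapters is not stated in the card, so the 1024 × 1024 layer used for the parameter count is a hypothetical example based on d_model.

```python
def matvec(M, x):
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(x, W, A, B, rank, alpha):
    """Frozen weight W plus a low-rank update B @ A scaled by alpha / rank."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]

def lora_extra_params(d_in, d_out, rank):
    """Trainable parameters added by A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)

# Toy layer: d_in = d_out = 2, rank = 1, alpha = 2 (same alpha / rank ratio as 16 / 8).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, 0.5]]
B_zero = [[0.0], [0.0]]        # standard LoRA init: B starts at zero
B_trained = [[1.0], [0.0]]     # made-up values standing in for a trained adapter
x = [2.0, 4.0]

y_init = lora_forward(x, W, A, B_zero, rank=1, alpha=2)
y_tuned = lora_forward(x, W, A, B_trained, rank=1, alpha=2)
extra = lora_extra_params(1024, 1024, rank=8)
```

With B initialised to zero the adapted layer reproduces the frozen layer exactly, so fine-tuning starts from the pre-trained behaviour. At rank 8 a 1024 × 1024 projection gains 16,384 trainable parameters, about 1.6% of the frozen matrix.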

Limitations and known issues

  • Undertrained. ~256 M tokens vs paper's ~105 B (≈0.24%). Expect a gap on most GUE tasks vs the published numbers.
  • Max sequence length 2048 at training; RoPE allows longer at inference, but this has not been validated.
  • N nucleotides become A. The byte-LUT tokenizer maps anything not A/T/C/G to 0 (= A). Assembly gaps and ambiguous bases are silently lost; downstream interpretation on N-rich regions should be done carefully.
  • Bacteria-heavy training distribution. 667 of 749 training species are bacteria. Performance on plant/viral genomes (excluded from training) is unknown; on large eukaryotic genomes it is plausible but undertrained.
  • No safety or bias evaluation. This is a pre-training base; downstream classifiers must do their own validation.
  • Single-GPU reimplementation. Architecture follows the paper but optimization details, data ordering, and tokenization quirks may differ in subtle ways from the original release.
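Because non-ACGT characters silently become A, it can help to measure how much of an input is affected before trusting downstream predictions. The sketch below mimics a byte-LUT encoder and reports that fraction. Only the A = 0 fallback is stated in this card; the ids for T, C, and G (taken from the listed vocab order), the lowercase handling, and the function name `encode_with_audit` are assumptions, and this is not the repo's `encode_dna_sequence`.

```python
# 256-entry lookup table; every byte defaults to 0, i.e. "A".
LUT = bytearray(256)
for base, idx in (("A", 0), ("T", 1), ("C", 2), ("G", 3)):  # assumed id order
    LUT[ord(base)] = idx
    LUT[ord(base.lower())] = idx   # assumption: soft-masked lowercase is accepted

def encode_with_audit(seq):
    """Return (ids, fraction of characters that were not A/T/C/G)."""
    ids = list(seq.encode("ascii").translate(bytes(LUT)))
    n_other = sum(1 for ch in seq.upper() if ch not in "ACGT")
    return ids, n_other / max(len(seq), 1)

ids, frac_collapsed = encode_with_audit("ACGTN")
```

A caller could skip or flag windows where the collapsed fraction exceeds some threshold, since the model cannot distinguish an assembly gap from a run of adenines.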

Files in this repo

| File | Size (approx) | Purpose |
| --- | --- | --- |
| model.safetensors | ~1.5 GB | Weights only. Use for inference and fine-tuning. |
| ckpt_best.pt | ~4.5 GB | Full training state (model + optimizer + scheduler + RNG + step + cfg). Use for resuming pre-training. |
| best.json | <1 KB | Per-checkpoint validation loss history. |
| README.md | | This card. |
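The two large file sizes follow from the parameter count, as the rough calculation below shows. It assumes fp32 storage for the weights and both AdamW moment buffers, uses 1 GB = 10^9 bytes, and ignores the small scheduler, RNG, and config entries.

```python
params = 380e6            # ~380 M parameters
bytes_per_value = 4       # fp32

weights_gb = params * bytes_per_value / 1e9
# AdamW keeps two extra tensors per parameter (first and second moment estimates).
full_state_gb = 3 * weights_gb
```

This gives roughly 1.5 GB for the weights alone and roughly 4.6 GB for weights plus optimizer state, in line with the approximate sizes listed.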

Citation

If you use this checkpoint, please cite the original paper:

@article{li2025mergedna,
 title = {MergeDNA: Context-aware Genome Modeling with Dynamic Tokenization through Token Merging},
 author = {Li, et al.},
 journal= {arXiv preprint arXiv:2511.14806},
 year = {2025}
}

And acknowledge this reimplementation:

@misc{jat2026mergedna_reimpl,
 title = {MergeDNA reimplementation (single-GPU)},
 author = {Jat, Sharmistha},
 year = {2026},
 howpublished = {Hugging Face model repository},
 url = {https://huggingface.co/Sharmistha-NLP/mergedna-400M}
}

License

Apache 2.0. Data is sourced from NCBI RefSeq (public domain) via the InstaDeepAI multi-species genomes dataset.
