
URL: https://huggingface.co/mlnomad/gelu-d12-chinchilla-261M



GELU d=12 Chinchilla (261M)

A 261M-parameter nanochat-architecture GPT with GELU activation in the MLP, trained on English C4 to the Chinchilla-optimal token budget (20 tokens per parameter, ≈ 5.22 B tokens) on a single TPU v6e-8.
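The token budget and step count follow from simple arithmetic on figures stated elsewhere in this card (the parameter count, 32 sequences per device on 8 devices, 1,024-token sequences). A minimal sketch; turning the budget into whole steps by floor division is an assumption, though rounding gives the same 19,920 here.

```python
# Chinchilla-style token budget, using numbers stated in this model card.
N_PARAMS = 261_096_338     # parameter count from the Training table
TOKENS_PER_PARAM = 20      # Chinchilla-optimal ratio used for this run
BATCH_PER_DEVICE = 32
N_DEVICES = 8
SEQ_LEN = 1024

token_budget = TOKENS_PER_PARAM * N_PARAMS
tokens_per_step = BATCH_PER_DEVICE * N_DEVICES * SEQ_LEN
# Assumption: whole optimizer steps obtained by floor division of the budget.
n_steps = token_budget // tokens_per_step
tokens_seen = n_steps * tokens_per_step

print(f"token budget:    {token_budget:,}")           # 5,221,926,760
print(f"tokens per step: {tokens_per_step:,}")        # 262,144
print(f"steps:           {n_steps:,}")                # 19,920
print(f"tokens seen:     {tokens_seen / 1e9:.2f} B")  # 5.22 B
```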

This repository hosts the seed 0 checkpoint (model + optimizer state, resumable). Training was also run for seeds 1 and 2 under identical compute to quantify seed variance; see the summary below.

Final result — seed 0 (this checkpoint)

| Metric | Value |
|---|---|
| Final loss | 3.0745 |
| Smooth final loss | 3.1106 |
| Tokens | 5.22 B |
| Steps | 19,920 |
| Wall time | 2.06 h |
| Throughput | 703 K tok/s |

3-seed variance (same arch, same data, same compute)

| Seed | Final loss | Smooth final loss | Wall time | Throughput |
|---|---|---|---|---|
| 0 (this repo) | 3.0745 | 3.1106 | 2.06 h | 703 K tok/s |
| 1 | 3.1355 | 3.1097 | 2.02 h | 717 K tok/s |
| 2 | 3.0765 | 3.1261 | 2.02 h | 717 K tok/s |
| mean ± std (population, n = 3) | 3.0955 ± 0.028 | 3.1155 ± 0.008 | 2.03 ± 0.02 h | 712 ± 7 K tok/s |
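The summary row can be reproduced from the three per-seed rows. In this minimal NumPy sketch, the ± values in the table match the population standard deviation (NumPy's default ddof=0). With only three seeds, the sample standard deviation (ddof=1) is about 22% larger (a factor of √(3/2)).

```python
import numpy as np

# Per-seed values copied from the 3-seed table.
results = {
    "final loss": np.array([3.0745, 3.1355, 3.0765]),
    "smooth final loss": np.array([3.1106, 3.1097, 3.1261]),
    "wall time (h)": np.array([2.06, 2.02, 2.02]),
    "throughput (K tok/s)": np.array([703.0, 717.0, 717.0]),
}

summary = {}
for name, values in results.items():
    mean = values.mean()
    pop_std = values.std()           # ddof=0: what the table reports
    sample_std = values.std(ddof=1)  # ddof=1: square root of the unbiased variance
    summary[name] = (mean, pop_std, sample_std)
    print(f"{name:>22}: {mean:.4f} ± {pop_std:.4f} (sample std {sample_std:.4f})")
```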

Architecture

Full nanochat stack with a GELU MLP; everything else is identical to the default nanochat GPT:

  • RoPE (base 100K), no learned positional embeddings
  • MHA (n_head = n_kv_head = 12, head_dim = 64; GQA-capable but all d=12 models use full MHA)
  • QK norm with 1.2× scaling
  • Sliding window attention with "SSSL" pattern
  • Tied embeddings (wte ↔ lm_head)
  • Parameterless RMSNorm post-embedding and per block
  • Value embeddings (ResFormer-style, alternating layers)
  • Per-layer learnable residual scalars
  • Smear (learnable bigram gate on the first 24 dims of the token embedding)
  • Backout (subtract mid-layer residual)
  • Logit soft-capping via tanh(x/15)·15
  • No biases in any Linear (attention Q/K/V/proj, MLP fc/proj)
  • MLP FFN: Linear(n → 4n) → jax.nn.gelu → Linear(4n → n)
    • Applied via the --mlp gelu flag in scripts/train_d12_chinchilla.py, which monkey-patches MLP.__call__ in flaxchat.gpt
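Several of the listed pieces are short enough to write out directly. This is a minimal JAX sketch of three of them, plus the per-layer window layout. It is not flaxchat's code. The RMSNorm epsilon, the toy weight shapes and init scale, and the reading of "SSSL" as a pattern repeated across the 12 layers ('S' short sliding window, 'L' long) are all assumptions. jax.nn.gelu is called with its default (tanh-approximate) form, as in the Loading snippet.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    """Parameterless RMSNorm over the last axis (no learnable gain). eps is an assumed value."""
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def softcap(logits, cap=15.0):
    """Logit soft-capping, cap * tanh(logits / cap): smooth and bounded by cap."""
    return cap * jnp.tanh(logits / cap)


def gelu_mlp(x, w_fc, w_proj):
    """Bias-free FFN: Linear(n -> 4n) -> GELU -> Linear(4n -> n)."""
    return jax.nn.gelu(x @ w_fc) @ w_proj  # jax.nn.gelu defaults to approximate=True


def layer_window_kinds(pattern="SSSL", n_layer=12):
    """Assumed layout: repeat the pattern over the layers ('S' = short window, 'L' = long)."""
    return "".join(pattern[i % len(pattern)] for i in range(n_layer))


n = 768  # n_embd
k_x, k_fc, k_proj = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (2, 5, n))                  # (batch, seq, n_embd)
w_fc = 0.02 * jax.random.normal(k_fc, (n, 4 * n))      # hypothetical init scale
w_proj = 0.02 * jax.random.normal(k_proj, (4 * n, n))

h = gelu_mlp(rms_norm(x), w_fc, w_proj)
print(h.shape)                                   # (2, 5, 768)
print(softcap(jnp.array([-100.0, 0.0, 100.0])))  # values stay within [-15, 15]
print(layer_window_kinds())                      # SSSLSSSLSSSL
```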

Training

| Field | Value |
|---|---|
| Architecture | Nanochat-style GPT with GELU MLP |
| Parameters | 261,096,338 |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL window |
| Data | allenai/c4 (English split, streamed) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (tokenizer vocab 32,000; model vocab_size 32,768) |
| Optimizer | plain AdamW, β=(0.9, 0.999), wd=0.01, global-norm grad clip 1.0 |
| LR schedule | warmup-cosine-decay, peak 0.01, 500 warmup steps, decays to 5% of peak |
| Batch | 32/device × 8 devices = 256 sequences global (262,144 tokens/step) |
| Seq length | 1024 |
| Tokens | Chinchilla 20× params ≈ 5.22 B tokens (19,920 steps) |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Seed | 0 (seeds 1 and 2 trained under the identical config for the variance table above) |
| Wandb | irf-sic/flaxchat, run gelu-d12-chinchilla-lr0.01-seed0 |

Because the optimizer state is preserved, the checkpoint can be resumed for further training.

Contents

This repository contains the full Orbax checkpoint (model weights + AdamW optimizer state), the frozen config, and code snapshots sufficient to load the model:

.
├── 19920/                      # final Orbax checkpoint (~2.86 GB total)
│   ├── _CHECKPOINT_METADATA
│   ├── metadata/               # JSON: {"final_loss": 3.07, "smooth": 3.11}
│   ├── model/                  # nnx.Param state (architecture weights)
│   └── optimizer/              # optax AdamW state (Adam m/v moments + step + LR state)
├── config.json
├── README.md
└── code/                       # snapshots of flaxchat code (source of truth: GitHub)
    ├── gpt.py
    ├── checkpoint.py
    ├── config.py
    ├── train_d12_chinchilla.py # the exact training script (use --mlp gelu)
    └── load_model.py

Loading

git clone https://github.com/mlnomadpy/flaxchat && cd flaxchat
pixi install # or pip install with the deps in code/load_model.py
huggingface-cli download mlnomad/gelu-d12-chinchilla-261M --local-dir ./hf_model

# Inside your flaxchat clone, adapt code/load_model.py for GELU:
import os

import jax
from flax import nnx
from flaxchat.gpt import GPT, GPTConfig, MLP
from flaxchat.checkpoint import restore_model_from_checkpoint

# Patch MLP to use GELU instead of ReLU² (matches training with --mlp gelu).
# Keep the original so the patch can be undone later if needed.
_orig = MLP.__call__

def _gelu_call(self, x):
    x = self.c_fc(x)    # Linear(n -> 4n), no bias
    x = jax.nn.gelu(x)  # GELU in place of ReLU²
    x = self.c_proj(x)  # Linear(4n -> n), no bias
    return x

MLP.__call__ = _gelu_call

# Must match the training config (see config.json)
config = GPTConfig(
    sequence_len=1024, vocab_size=32768,
    n_layer=12, n_head=12, n_kv_head=12, n_embd=768,
    window_pattern="SSSL", tie_embeddings=True,
)
model = GPT(config, rngs=nnx.Rngs(0))

# Orbax may reject relative paths, so resolve the checkpoint dir to an absolute one.
ckpt_path = os.path.abspath("./hf_model/19920")
meta = restore_model_from_checkpoint(model, ckpt_path)
print(f"Loaded, final loss: {meta['final_loss']:.4f}")  # expect 3.0745
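The loading snippet saves _orig but never restores it, so the GELU patch stays active for the rest of the process. If you also need a ReLU² model in the same session, one option is to scope the patch with the standard library's unittest.mock.patch.object. The sketch below uses a hypothetical ToyMLP stand-in rather than flaxchat's MLP. With the real model, keep every forward pass inside the with block, including the first call of any jitted function, because jax.jit traces whichever __call__ is installed at that moment and caches the result.

```python
import math
from unittest import mock


class ToyMLP:
    """Hypothetical stand-in for flaxchat.gpt.MLP, acting on a single float."""

    def __call__(self, x: float) -> float:
        return max(x, 0.0) ** 2  # ReLU², the default activation


def gelu_call(self, x: float) -> float:
    """Exact (erf-based) GELU on a scalar, standing in for the patched __call__."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


mlp = ToyMLP()
with mock.patch.object(ToyMLP, "__call__", gelu_call):
    patched_out = mlp(-1.0)  # GELU(-1) ≈ -0.1587 while the patch is active
restored_out = mlp(-1.0)     # ReLU²(-1) = 0.0 once the patch is undone

print(patched_out, restored_out)
```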

Weights and optimizer state

Both the model parameters and the full AdamW optimizer state (Adam m/v moments, step counter, LR schedule state) are stored in the final checkpoint directory 19920/, in its model/ and optimizer/ subdirectories respectively.

The checkpoint format is Orbax (OCDBT PyTree). flaxchat.checkpoint.restore_model_from_checkpoint(model, ckpt_path, optimizer=optimizer) restores both in one call.

License

Apache 2.0.
