URL: https://huggingface.co/mlnomad/gelu-d12-chinchilla-261M-pytorch



GELU d=12 Chinchilla (261M) — PyTorch / HuggingFace Transformers

A 261M-parameter nanochat-architecture GPT with GELU MLP, originally trained in JAX/Flax on a TPU v6e-8 and ported to PyTorch for easy inference via the HuggingFace transformers API.

Weights are bit-exact with the Flax checkpoint (mlnomad/gelu-d12-chinchilla-261M). Forward-pass parity with the Flax model was validated at max |Δ logits| = 1.3e-5 on CPU in fp32.

Quick start

pip install torch transformers safetensors

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/gelu-d12-chinchilla-261M-pytorch",
    trust_remote_code=True,  # the model class ships in the repo
    dtype=torch.float32,     # older transformers releases name this argument torch_dtype
).eval()

# The model was trained with the Mistral-7B-v0.1 tokenizer (model vocab size 32,768)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    out = model.generate(
        ids,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id or 0,  # the tokenizer has no pad token, so reuse EOS
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))

Greedy completion examples from this checkpoint:

"The meaning of life is" → to live in a place where you can live. The meaning of life is to live in a

"Once upon a time," → the first thing I would do was to make a list of all the things I would like to do

Model details

Parameters: 261,096,338
Architecture: Nanochat-style GPT with GELU MLP (ported from JAX/Flax NNX)
Config: d=12 (12 layers), n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window
Training data: allenai/c4 (English split), 5.22B tokens (Chinchilla ratio of 20 tokens per parameter)
Tokenizer: mistralai/Mistral-7B-v0.1 (model vocab size 32,768)
Optimizer: plain AdamW, peak LR 0.01, warmup followed by cosine decay
Hardware: TPU v6e-8 (TRC), europe-west4-a
Final loss: 3.0745 (smoothed: 3.1106)
3-seed spread: mean final loss 3.0955 ± 0.028 (smoothed: 3.1155 ± 0.008)
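The headline numbers above can be cross-checked with a few lines of arithmetic. This is a back-of-envelope sketch, not the author's accounting: it assumes the attention block is four bias-free n_embd × n_embd projections, the MLP is n → 4n → n, and each of the 6 value-embedding layers (every other layer of 12) owns its own vocab × n_embd table. Under those assumptions the large matrices account for all but 914 of the reported parameters, which would be left for the per-layer scalars and gates.

```python
# Back-of-envelope checks for the model-details numbers.
# Assumptions (not stated on the card): attention = 4 square n_embd x n_embd
# projections, MLP = n -> 4n -> n, and each value-embedding layer has its own
# vocab x n_embd table on every other layer.

N_PARAMS = 261_096_338   # reported parameter count
N_LAYER = 12             # d=12
N_EMBD = 768
N_HEAD = 12
VOCAB = 32_768           # model vocab size from the card
TOKENS_PER_PARAM = 20    # Chinchilla ratio


def chinchilla_tokens(n_params, ratio=TOKENS_PER_PARAM):
    """Compute-optimal token budget under the ~20 tokens/parameter rule."""
    return n_params * ratio


def fp32_bytes(n_params):
    """Size of the weights alone at 4 bytes per fp32 parameter."""
    return n_params * 4


def dominant_param_terms(n_layer=N_LAYER, n=N_EMBD, vocab=VOCAB):
    """Large matrices only; small scalars and gates are ignored."""
    wte = vocab * n                              # tied with lm_head, counted once
    attn = 4 * n * n                             # q, k, v, o (n_kv_head == n_head, no biases)
    mlp = 2 * n * (4 * n)                        # n -> 4n -> n, no biases
    value_embeds = (n_layer // 2) * vocab * n    # alternating layers (assumed table shape)
    return wte + n_layer * (attn + mlp) + value_embeds


if __name__ == "__main__":
    print(f"head_dim: {N_EMBD // N_HEAD}")
    print(f"tokens: {chinchilla_tokens(N_PARAMS) / 1e9:.2f}B")
    print(f"fp32 weights: {fp32_bytes(N_PARAMS) / 1e9:.2f} GB")
    approx = dominant_param_terms()
    print(f"dominant terms: {approx:,} (remainder {N_PARAMS - approx:,})")
```

The token budget (20 × 261,096,338 ≈ 5.22B) and the fp32 weight size (4 bytes per parameter ≈ 1.04 GB) both agree with the figures on this card.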

Architecture features

Full nanochat stack, faithfully ported to PyTorch:

  • RoPE (base 100,000), split-half layout
  • MHA (n_head = n_kv_head = 12; the code supports GQA via n_kv_head < n_head, but all d=12 models use full MHA)
  • QK-norm with 1.2× scaling (after RoPE)
  • Parameterless RMSNorm (no learnable gain), applied after the embedding and in every block
  • Sliding-window attention with the "SSSL" layer pattern (short, short, short, long)
  • Tied embeddings (lm_head = wte.T)
  • Value embeddings on alternating layers (ResFormer-style)
  • Per-layer learnable residual scalars (resid_lambdas, x0_lambdas)
  • Smear: a learnable gate on the first 24 dims of the token embedding mixes in the previous token's embedding
  • Backout: the mid-layer residual is subtracted from the late layers
  • Logit soft-cap: 15 · tanh(logits / 15)
  • No biases in any Linear
  • MLP: Linear(n → 4n) → F.gelu(approximate="tanh") → Linear(4n → n)
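Several of the features listed above are small enough to write out directly. The sketch below uses plain Python lists so it runs without PyTorch; it illustrates the formulas and is not a copy of torch_gpt.py. The soft-cap constant 15, the RoPE base 100,000, the split-half layout, and the tanh GELU come from the card; the RMSNorm eps, the RoPE rotation sign, and reading S/L as short/long windows are assumptions.

```python
import math


def rms_norm(x, eps=1e-6):
    """Parameterless RMSNorm: rescale to unit RMS, no learnable gain (eps is assumed)."""
    ms = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(ms + eps) for v in x]


def softcap(logit, cap=15.0):
    """Logit soft-cap from the card: cap * tanh(logit / cap)."""
    return cap * math.tanh(logit / cap)


def gelu_tanh(x):
    """GELU, tanh approximation (the formula behind F.gelu(..., approximate="tanh"))."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def rope_split_half(x, pos, base=100_000.0):
    """Split-half RoPE: rotate pairs (x[i], x[i + d/2]) by pos * base^(-2i/d).

    Uses one common sign convention; the repo's code may rotate the other way,
    which does not change the relative-position property.
    """
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x1 * s + x2 * c
    return out


def layer_windows(n_layer, pattern="SSSL"):
    """Tile the window pattern across layers (S = short window, L = long window, assumed)."""
    return "".join(pattern[i % len(pattern)] for i in range(n_layer))


if __name__ == "__main__":
    print(layer_windows(12))
    print(softcap(40.0), gelu_tanh(1.0))
```

For the d=12 model, tiling "SSSL" over 12 layers gives three full repeats, so under this reading layers 3, 7 and 11 (0-indexed) are the long-window layers.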

KV cache

The GeluGPTForCausalLM class implements a smear-aware KV cache for fast autoregressive generation. KV-cache parity against the full forward pass is validated at max |Δ| < 3e-5. Caching is enabled with use_cache=True, which is the default for .generate().
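One plausible reason the cache has to be smear-aware: during single-token decoding the new token must still be mixed with the previous token's embedding, which is no longer part of the input. The sketch below uses a hypothetical smear of the form cur + sigmoid(w · cur[:24]) · prev (only the 24-dim gate input comes from the card; `smear_step` and `SmearCache` are illustrative names) and shows that carrying the previous raw embedding across steps reproduces the full-sequence result exactly.

```python
import math
import random

GATE_DIMS = 24  # from the card: the gate looks at the first 24 embedding dims


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def smear_step(cur, prev, w):
    """Hypothetical smear: cur + sigmoid(w . cur[:24]) * prev."""
    if prev is None:  # position 0 has no previous token
        return list(cur)
    g = sigmoid(sum(wi * xi for wi, xi in zip(w, cur[:GATE_DIMS])))
    return [c + g * p for c, p in zip(cur, prev)]


def smear_full(embs, w):
    """Full forward pass: each position reads its predecessor from the sequence."""
    return [smear_step(e, embs[t - 1] if t > 0 else None, w) for t, e in enumerate(embs)]


class SmearCache:
    """Decode-time state: the previous token's raw (unsmeared) embedding."""

    def __init__(self):
        self.prev = None

    def step(self, cur, w):
        out = smear_step(cur, self.prev, w)
        self.prev = cur  # store the raw embedding, not the smeared output
        return out


if __name__ == "__main__":
    rng = random.Random(0)
    dim, n_tok = 32, 6
    w = [rng.gauss(0, 0.1) for _ in range(GATE_DIMS)]
    embs = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_tok)]
    cache = SmearCache()
    incremental = [cache.step(e, w) for e in embs]
    print(incremental == smear_full(embs, w))
```

Without the stored `prev`, a one-token forward pass would treat every new token as position 0 and skip the mix, the kind of divergence a cached-vs-full parity check is designed to catch.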

Files in this repo

.
├── config.json # HF config with auto_map pointing to the classes below
├── generation_config.json
├── model.safetensors # 1.04 GB, fp32 weights + persistent RoPE buffers
├── torch_gpt.py # pure PyTorch GPT module (GELU_GPT)
├── configuration_gelu_gpt.py # PretrainedConfig subclass
├── modeling_gelu_gpt.py # PreTrainedModel + GenerationMixin wrapper with KV cache
└── README.md


Wikitext-103 evaluation

Wikitext-103 test loss: 3.840
Wikitext-103 test perplexity (PPL): 46.52

Evaluated on ~330K tokens from the Wikitext-103 test set. The model was trained on C4 only, so this is a zero-shot transfer metric.
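Perplexity is the exponential of the mean per-token cross-entropy, so the two Wikitext-103 rows are the same measurement on different scales. A minimal check, assuming the loss is reported in nats:

```python
import math


def perplexity(mean_loss_nats):
    """Perplexity = exp(mean per-token cross-entropy in nats)."""
    return math.exp(mean_loss_nats)


def loss_from_perplexity(ppl):
    """Inverse mapping: mean cross-entropy (nats) = ln(perplexity)."""
    return math.log(ppl)


if __name__ == "__main__":
    print(f"PPL from loss 3.840: {perplexity(3.840):.2f}")
    print(f"loss from PPL 46.52: {loss_from_perplexity(46.52):.4f}")
```

exp(3.840) ≈ 46.53; the reported 46.52 corresponds to a loss of about 3.8399, which rounds to the 3.840 shown, so the two rows are consistent.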

License

Apache 2.0.
