Qwen3-0.6B Lambda Gates — Chat-Template Variants
Lambda gates trained with the chat template enabled (`enable_thinking=False`), for deployments that use the base model through its chat interface. The four variants span two main axes:

- Activation scaling after gating (`scale_mode`): `mean` vs. `energy`
- NKE retention loss weight (`unmasked_retain_weight`): 0.0, 0.05, 0.1

`chat_energy_optB/` additionally raises λ_f from 0.1 to 1.0 ("strong forget").
Variants
| Folder | `scale_mode` | λ_f | λ_r | `unmasked_retain_weight` | Notes |
|---|---|---|---|---|---|
| `chat_energy/` | energy | 0.1 | 0.5 | 0.0 | Energy rescale, no NKE |
| `chat_energy_optA/` | energy | 0.1 | 0.5 | 0.05 | Energy + light NKE |
| `chat_energy_optB/` | energy | 1.0 | 0.5 | 0.1 | Energy + strong forget + moderate NKE |
| `chat_mean/` | mean | 0.1 | 0.5 | 0.0 | Mean rescale, no NKE |
All variants share:

- Base: `Qwen/Qwen3-0.6B`
- Forget data: PopQA-mini entity-masked knowledge text
- Reasoning data: NuminaMath-CoT (10k seed subset), chat-templated
- β=4.0, distill T=2.0, `forget_retain_ratio=1:2`, lr=1e-2 (cosine schedule), 3 epochs, bf16, `init_logit_std=0.1`
- `use_chat_template=True`, `enable_thinking=False` for forget/oracle/conflict evals
Gate Statistics (chat_energy)
| Metric | Value |
|---|---|
| Total gates | 86,016 |
| Mean sigmoid gate | ≈ 0.500 |
| Std | ≈ 0.03 |
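The statistics in the table can be recomputed from a variant's logits. This is a minimal sketch, assuming `lambda_logits.pt` holds either a tensor or a dict with a single tensor value (the actual key name is not documented here). The total of 86,016 equals 28 decoder layers × 3,072 MLP intermediate neurons in Qwen3-0.6B, so the sketch also assumes a layer-major layout when reshaping; both the storage format and the layout are assumptions.

```python
import torch

NUM_LAYERS = 28      # Qwen3-0.6B num_hidden_layers
INTERMEDIATE = 3072  # Qwen3-0.6B intermediate_size; 28 * 3072 = 86,016


def as_flat_logits(gate_state) -> torch.Tensor:
    """Return a flat float32 tensor of logits (storage format is assumed)."""
    if isinstance(gate_state, torch.Tensor):
        return gate_state.flatten().float()
    if isinstance(gate_state, dict):
        tensors = [v for v in gate_state.values() if isinstance(v, torch.Tensor)]
        if len(tensors) == 1:
            return tensors[0].flatten().float()
    raise ValueError("unrecognized gate_state format")


def gate_stats(logits: torch.Tensor) -> dict:
    """Count, mean and std of the sigmoid gates."""
    gates = torch.sigmoid(logits)
    return {
        "total_gates": gates.numel(),
        "mean": gates.mean().item(),
        "std": gates.std().item(),
    }


def per_layer(logits: torch.Tensor) -> torch.Tensor:
    """Reshape to (layers, neurons), assuming a layer-major layout."""
    return logits.view(NUM_LAYERS, INTERMEDIATE)


# Example:
# gate_state = torch.load("chat_energy/lambda_logits.pt", map_location="cpu")
# logits = as_flat_logits(gate_state)
# print(gate_stats(logits))
# print(per_layer(logits).sigmoid().mean(dim=1))  # mean gate per layer
```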
The selected thresholds (listed in `selected_thresholds.txt`) target off-fractions of 5, 25, 50, 75, and 95% of gates.
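One way to obtain thresholds for given off-fractions is to take quantiles of the gate values. This sketch assumes a gate counts as "off" when its sigmoid value falls below the threshold and that thresholds are expressed in sigmoid space; the card does not say exactly how `selected_thresholds.txt` was produced.

```python
import torch


def thresholds_for_off_fractions(logits, fractions=(0.05, 0.25, 0.50, 0.75, 0.95)):
    """One threshold per target off-fraction, via quantiles of sigmoid(logits).

    Assumption: a gate is 'off' when sigmoid(logit) < threshold.
    """
    gates = torch.sigmoid(logits.flatten().float())
    q = torch.tensor(fractions, dtype=gates.dtype)  # quantile() needs matching dtype
    return torch.quantile(gates, q).tolist()


def off_fraction(logits, threshold):
    """Fraction of gates whose sigmoid value is below `threshold`."""
    gates = torch.sigmoid(logits.flatten().float())
    return (gates < threshold).float().mean().item()
```

Comparing `off_fraction(logits, t)` for each value in `selected_thresholds.txt` against the targets is a quick check of whether the assumed convention matches the released files.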
Contents per variant
```
<variant>/
  lambda_logits.pt         # 86,016 per-neuron logits
  neuron_indices.json      # Knowledge neurons at threshold 0.5
  gate_stats.json          # Statistics + selected thresholds
  selected_thresholds.txt  # Comma-separated thresholds
```
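A small loader for one variant folder could look like the following. The file names come from the listing above; the internal structure of the two JSON files and of `lambda_logits.pt` is not documented here, so this sketch returns them as loaded and only parses `selected_thresholds.txt`, assumed to hold plain floats separated by commas. `load_variant` and `parse_thresholds` are hypothetical helpers.

```python
import json
from pathlib import Path

import torch


def parse_thresholds(text: str) -> list[float]:
    """Parse comma-separated thresholds, tolerating whitespace and newlines."""
    return [float(tok) for tok in text.replace("\n", ",").split(",") if tok.strip()]


def load_variant(folder):
    """Load the four files of one variant folder (JSON schemas left as-is)."""
    folder = Path(folder)
    return {
        "lambda_logits": torch.load(folder / "lambda_logits.pt", map_location="cpu"),
        "neuron_indices": json.loads((folder / "neuron_indices.json").read_text()),
        "gate_stats": json.loads((folder / "gate_stats.json").read_text()),
        "thresholds": parse_thresholds((folder / "selected_thresholds.txt").read_text()),
    }


# Example:
# variant = load_variant("chat_energy_optA")
# print(variant["thresholds"])
```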
Usage
```python
import torch, json
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)

# Apply chat template at inference
messages = [{"role": "user", "content": "What is 2+2?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)

# Load a variant and apply gates (see baseline README for full recipe)
gate_state = torch.load("chat_energy_optA/lambda_logits.pt", map_location="cpu")
```
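The full gating recipe is in the baseline README. As a rough sketch only, assuming the 86,016 logits map layer-major onto the 3,072 MLP intermediate neurons of each of the 28 decoder layers and that the gates act on the input to each layer's `mlp.down_proj`, the gates and the variant's `scale_mode` could be applied with standard PyTorch forward pre-hooks. The placement, the per-layer normalization, and the helpers `apply_lambda_gates` / `gate_and_rescale` are all hypothetical, not the released code.

```python
import torch


def gate_and_rescale(h, lam, scale_mode="energy"):
    """Multiply activations by gates, then rescale per the card's scale modes."""
    if scale_mode == "mean":
        scale = lam.mean()  # preserve mean activation magnitude
    elif scale_mode == "energy":
        scale = lam.pow(2).mean().sqrt()  # preserve activation energy
    else:
        raise ValueError(f"unknown scale_mode: {scale_mode}")
    return h * lam / scale.clamp(min=1e-8)  # clamp guards an all-off layer


def apply_lambda_gates(model, logits, scale_mode="energy", threshold=None):
    """Gate the input of every layer's mlp.down_proj (assumed placement).

    Assumes `logits` is a flat, layer-major tensor of length
    num_layers * intermediate_size. If `threshold` is given, gates are
    hardened to a 0/1 mask (sigmoid value >= threshold stays on).
    Returns hook handles; call handle.remove() to restore the base model.
    """
    layers = model.model.layers
    lam_all = torch.sigmoid(logits.float()).view(len(layers), -1)
    if threshold is not None:
        lam_all = (lam_all >= threshold).float()
    handles = []
    for layer, lam in zip(layers, lam_all):
        def hook(module, args, lam=lam):
            (h,) = args  # down_proj receives a single positional input
            gated = gate_and_rescale(h.float(), lam.to(h.device), scale_mode)
            return (gated.to(h.dtype),)  # returned tuple replaces the input
        handles.append(layer.mlp.down_proj.register_forward_pre_hook(hook))
    return handles


# Example (continuing from the snippet above; assumes gate_state is a flat tensor):
# handles = apply_lambda_gates(model, gate_state, scale_mode="energy")
# inputs = tokenizer(prompt, return_tensors="pt")
# out = model.generate(**inputs, max_new_tokens=32)
# print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
# for handle in handles:
#     handle.remove()
```

Pass `scale_mode="mean"` for `chat_mean/` and `"energy"` for the other three variants, matching the table above.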
Scale Mode
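After gating the hidden activations $h$ with $\lambda = \sigma(\text{logits})$, the result is rescaled according to `scale_mode`:

- `mean`: $h' = \dfrac{h \odot \lambda}{\operatorname{mean}(\lambda)}$, rescaling to preserve mean activation magnitude
- `energy`: $h' = \dfrac{h \odot \lambda}{\sqrt{\operatorname{mean}(\lambda^2)}}$, rescaling to preserve activation energy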
The energy mode tends to be slightly more stable at high off-fractions.
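One way to see why `energy` can be gentler at high off-fractions is the idealized case of hard 0/1 gates with a fraction $p$ left on (an assumption for illustration; the trained gates are continuous). Then $\operatorname{mean}(\lambda) = \operatorname{mean}(\lambda^2) = p$, so the surviving activations are amplified by $1/p$ under `mean` but only by $1/\sqrt{p}$ under `energy`:

$$
\text{mean: } \frac{1}{p}, \qquad \text{energy: } \frac{1}{\sqrt{p}}, \qquad p = 0.05\ (95\%\ \text{off}) \;\Rightarrow\; 20\times \ \text{vs.}\ \approx 4.47\times
$$

The $1/p$ factor grows much faster as $p \to 0$, which plausibly contributes to the stability difference noted above.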
Related Checkpoints
- qwen3-0.6b-lambda-gates-baseline — non-chat baseline
- qwen3-0.6b-lambda-gates-nke — NKE variants without chat template
- qwen3-1.7b-lambda-gates-chat — 1.7B scale-up with matching variants
