
URL: https://huggingface.co/mkurman/NeuroBLAST-V3-SYNTH-EC-150000



NeuroBLAST-V3-SYNTH-EC-150000

⚠️ EXPERIMENTAL EARLY CHECKPOINT ⚠️

This is an Early Checkpoint (EC) of the NeuroBLAST V3 architecture, a novel hybrid model designed with a biologically inspired "cortical" structure.

This specific checkpoint (150k steps) represents the "pre-decay" phase of training. It has been trained on short contexts with a high learning rate and is intended for architectural evaluation and research purposes.

Model Details

  • Architecture: NeuroBLAST V3 (Custom Hybrid Architecture)
  • Checkpoint Step: 150,000
  • Parameters: 596,728,320
  • Num layers: 72
    • Sensory layers: 24
    • Associative layers: 32
    • Motor layers: 16
  • Hidden size: 512
  • Vocab size: 65538
  • Intermediate size: 3072
  • Num attention heads: 16
  • Num kv heads: 8
  • Head dim: 128
  • Tie word embeddings: False
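
The listed values imply a few derived sizes, which the sketch below computes. It assumes that the attention projections follow the common convention of num_heads × head_dim output features, and that "Tie word embeddings: False" means the LM head is a separate matrix with the same shape as the input embedding. Under those assumptions, the query width (2048) is four times the hidden size (512), and each key/value head is shared by two query heads (grouped-query attention).

```python
# Values copied from the "Model Details" list.
SENSORY_LAYERS, ASSOCIATIVE_LAYERS, MOTOR_LAYERS = 24, 32, 16
TOTAL_PARAMS = 596_728_320
HIDDEN_SIZE = 512
VOCAB_SIZE = 65_538
NUM_HEADS = 16
NUM_KV_HEADS = 8
HEAD_DIM = 128
TIE_WORD_EMBEDDINGS = False

# The three cortical stages should add up to "Num layers".
num_layers = SENSORY_LAYERS + ASSOCIATIVE_LAYERS + MOTOR_LAYERS

# Assumption: projection widths are (number of heads) * head_dim.
q_width = NUM_HEADS * HEAD_DIM
kv_width = NUM_KV_HEADS * HEAD_DIM
queries_per_kv_head = NUM_HEADS // NUM_KV_HEADS  # grouped-query attention ratio

# Assumption: untied embeddings = input embedding + separate LM head of equal shape.
embedding_params = VOCAB_SIZE * HIDDEN_SIZE * (1 if TIE_WORD_EMBEDDINGS else 2)
embedding_share = embedding_params / TOTAL_PARAMS

print(f"layers={num_layers} q_width={q_width} kv_width={kv_width} "
      f"queries_per_kv_head={queries_per_kv_head}")
print(f"embedding + LM head params: {embedding_params:,} ({embedding_share:.1%} of total)")
```

With a vocabulary of 65,538 tokens and untied weights, the embedding and LM-head matrices alone account for about 11% of the reported parameter count.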

Architecture Highlights

NeuroBLAST differs from standard Transformers by using a three-stage cortical design:

  1. Sensory Cortex: Hybrid layers alternating between Attention and Dilated Causal 2D Convolutions.
  2. Associative Cortex: Hybrid layers with alternating RoPE usage.
  3. Motor Cortex: Pure Attention layers.

The stages are linked by Deep Residual Bridges: long-range residual connections that inject the original embeddings (and their negations) between cortical stages to improve signal propagation.

[Figure: NeuroBLAST V3 architecture diagram]
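
The three-stage layout and the bridges can be pictured as the toy forward pass below. It is a schematic, not the code in modeling_neuroblast.py: each stage is just a list of callables (each standing in for a full layer), and the bridge placement (after the sensory and after the associative stage) and signs (+1 for the embeddings, -1 for their negation) are assumptions chosen only to illustrate re-injecting the original embeddings between stages.

```python
from typing import Callable, Sequence, TypeVar

# Any array-like type supporting + and scalar *, e.g. a float, NumPy array, or JAX array.
T = TypeVar("T")


def cortical_forward(
    embeddings: T,
    sensory: Sequence[Callable[[T], T]],
    associative: Sequence[Callable[[T], T]],
    motor: Sequence[Callable[[T], T]],
    bridge_signs: tuple[float, float] = (1.0, -1.0),  # HYPOTHETICAL signs and placement
) -> T:
    """Toy three-stage pass with deep residual bridges (illustrative only)."""
    h = embeddings
    for layer in sensory:  # stage 1: sensory cortex
        h = layer(h)
    h = h + bridge_signs[0] * embeddings  # bridge: re-inject the original embeddings
    for layer in associative:  # stage 2: associative cortex
        h = layer(h)
    h = h + bridge_signs[1] * embeddings  # bridge: inject the negated embeddings
    for layer in motor:  # stage 3: motor cortex
        h = layer(h)
    return h


def double(x):
    return 2 * x


# One "layer" per stage: 1 -> 2 -> 3 -> 6 -> 5 -> 10
print(cortical_forward(1.0, [double], [double], [double]))
```

Because the bridges add the unmodified embeddings, later stages receive a direct copy of the input signal regardless of how many layers lie in between, which is the propagation benefit the card describes.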

Training Details

This model is currently being trained using the Google TPU Research Cloud (TRC).

  • Dataset: PleIAs/SYNTH
  • Tokens Processed: ~118 Billion
  • Hardware: TPUv4-16
  • Training Time: ~8 Days
  • Effective Batch Size: 1024
  • Context Length: 768 tokens (Current phase)
  • Learning rate: 4e-3
  • Weight decay: 0.0
  • Optimizer: AdamW
  • Precision: BFloat16
  • Current State: Pre-decay phase (no weight decay applied yet).
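
The reported ~118 billion tokens is consistent with the other settings if every step processes a full batch of 1024 sequences of 768 tokens each (an assumption: no padding or packing losses). The arithmetic is below, together with the average throughput implied by the ~8-day training time.

```python
steps = 150_000        # "Checkpoint Step"
batch_size = 1024      # "Effective Batch Size", assumed to count sequences
context_length = 768   # "Context Length" in tokens
training_days = 8      # "~8 Days" (approximate)

tokens_per_step = batch_size * context_length
tokens_processed = steps * tokens_per_step
tokens_per_second = tokens_processed / (training_days * 24 * 3600)

print(f"{tokens_per_step:,} tokens/step")
print(f"{tokens_processed / 1e9:.1f}B tokens")
print(f"~{tokens_per_second:,.0f} tokens/s on average")
```

This gives 786,432 tokens per step and roughly 118.0B tokens in total. Since the training time is approximate, the throughput figure (about 170k tokens/s on the TPUv4-16) is only a rough average.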

[Figure: evaluation loss curve]

Roadmap

This checkpoint marks the end of the initial warmup/learning phase. The next steps in training are:

  1. Significantly extending the context length.
  2. Lowering the learning rate.
  3. Introducing weight decay for convergence.
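
The roadmap does not give concrete values for the next phase. The sketch below shows one way the two optimizer changes could be expressed as a per-step function: the learning rate stays at the current 4e-3 until step 150,000 and then decays linearly, while weight decay switches from 0.0 to a nonzero value. Only the peak learning rate and the step boundary come from this card; decay_steps, final_lr, and decay_phase_wd are hypothetical placeholders, and the actual schedule may use a different decay shape. Context-length extension is a data-pipeline change and is not modeled here.

```python
def hyperparams_at(step, peak_lr=4e-3, decay_start=150_000,
                   decay_steps=20_000, final_lr=4e-4, decay_phase_wd=0.1):
    """Return (learning_rate, weight_decay) for a given optimizer step.

    peak_lr and decay_start come from the model card; decay_steps,
    final_lr and decay_phase_wd are HYPOTHETICAL placeholder values.
    """
    if step < decay_start:
        # Current phase: constant high learning rate, no weight decay.
        return peak_lr, 0.0
    # Decay phase: linear interpolation from peak_lr to final_lr, then hold.
    progress = min((step - decay_start) / decay_steps, 1.0)
    lr = peak_lr + progress * (final_lr - peak_lr)
    return lr, decay_phase_wd


for s in (0, 150_000, 160_000, 170_000):
    print(s, hyperparams_at(s))
```

In AdamW, the weight decay is applied separately from the adaptive gradient update, so enabling it only in the decay phase leaves the earlier optimization unchanged.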

Usage

Note: You must use trust_remote_code=True as this model utilizes custom modeling code (modeling_neuroblast.py).

```python
import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM

model_id = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model; trust_remote_code=True is required for the custom architecture
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
).eval()

# Stream decoded text to stdout as it is generated (prompt included).
# Extra keyword arguments to TextStreamer are forwarded to tokenizer.decode().
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)

# Prepare input with the tokenizer's chat template
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "what is hypertension?"}],
    tokenize=True,
    return_tensors="pt",
    add_generation_prompt=True,
)

print(f"Input IDs: {input_ids}")

# Generate
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids.to(model.device),
        max_new_tokens=128,
        streamer=streamer,
        use_cache=True,
        # Important: keep repetition_penalty at 1.0 for this early checkpoint
        repetition_penalty=1.0,
    )
```

The underlying JAX implementation is in the neuroblastv3_jax folder. The JAX weights are hosted in a separate repository, mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX, which the example below loads.

```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from neuroblast3_jax.modeling_neuroblast_jax import NeuroBLASTForCausalLM as NeuroBLASTForCausalLMJax


def generate_text(model, tokenizer, text, max_new_tokens=50, temperature=0.7, top_k=50):
    # Build the prompt with the tokenizer's chat template (as in the PyTorch example)
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": text}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="np", add_special_tokens=False)
    original_input_ids = inputs["input_ids"]
    batch_size, prompt_len = original_input_ids.shape
    total_len = prompt_len + max_new_tokens

    # Pad input_ids to a fixed total_len so the jitted step compiles only once
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = jnp.full((batch_size, total_len), pad_id, dtype=jnp.int32)
    input_ids = input_ids.at[:, :prompt_len].set(original_input_ids)

    # With causal attention, padding after current_len cannot affect earlier logits
    attention_mask = jnp.ones((batch_size, total_len), dtype=jnp.int32)
    params = model.params

    @jax.jit
    def model_step(params, input_ids, attention_mask):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, params=params, train=False)
        return outputs.logits

    rng = jax.random.PRNGKey(0)

    print("Generating...")
    current_len = prompt_len
    printed_len = 0

    for _ in range(max_new_tokens):
        rng, step_rng = jax.random.split(rng)

        # Full forward pass over the padded sequence (no KV cache)
        logits = model_step(params, input_ids, attention_mask)

        # Logits for the last valid token (position current_len - 1)
        next_token_logits = logits[:, current_len - 1, :].astype(jnp.float32)

        # Temperature scaling, then top-k filtering, then sampling
        scaled_logits = next_token_logits / temperature
        if top_k is not None and top_k > 0:
            kth_largest = jax.lax.top_k(scaled_logits, top_k)[0][:, -1:]
            scaled_logits = jnp.where(scaled_logits < kth_largest, -jnp.inf, scaled_logits)
        next_token = jax.random.categorical(step_rng, scaled_logits, axis=-1)

        # Write the sampled token into the next free position
        input_ids = input_ids.at[:, current_len].set(next_token)
        current_len += 1

        # Streaming output: print only the newly decoded text
        current_text = tokenizer.decode(input_ids[0, :current_len].tolist(), skip_special_tokens=False)
        new_text = current_text[printed_len:]
        if new_text:
            print(new_text, end="", flush=True)
            printed_len += len(new_text)

        # Stop at the end-of-sequence token
        if int(next_token[0]) == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0, :current_len].tolist(), skip_special_tokens=False)


if __name__ == "__main__":
    checkpoint = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX"

    print(f"Loading model from {checkpoint}...")
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        use_fast=True,
        trust_remote_code=True,
    )

    print(f"Available devices: {jax.devices()}")

    model = NeuroBLASTForCausalLMJax.from_pretrained(
        checkpoint,
        dtype=jnp.bfloat16,
        trust_remote_code=True,
        is_decoder=True,
    )

    generated_text = generate_text(model, tokenizer, "what is hypertension?", max_new_tokens=128)

    print("\nGenerated Text:")
    print("-" * 20)
    print(generated_text)
    print("-" * 20)
```

Acknowledgments

This model was trained using Cloud TPUs provided by Google's TPU Research Cloud (TRC) program.

Special thanks to Pierre-Carl Langlais and the PleIAs team for the high-quality SYNTH dataset.

Repo

GitHub: https://github.com/mkurman/neuroblast-v3

Safetensors weights: 0.6B params, tensor type F32.