Smol Mixtral
Project Overview
A PyTorch implementation of a Mixtral-inspired transformer model with Mixture of Experts (MoE), designed for text generation and understanding tasks. The model follows the Mixtral architecture and uses Flash Attention, SwiGLU activation, and Liger kernels for optimized performance.
Technical Details
- Type: SmolHub Playground Project
- Framework: PyTorch
- Category: Experimental AI/ML
- Repository: Smol Mixtral
SmolMixtral - Mixtral Inspired Model
- I trained a 124M-parameter (8x12M) MoE-based architecture that I coded from the ground up.
- Trained on the TinyStories dataset from Hugging Face, consisting of 1M texts, for a total of 14,000 steps.
Examples
Provided under the generated_data/ directory, these examples showcase the model's capabilities in text generation and understanding.
Training Results & Model Weights
View Training Report: SmolMixtral Training Results on WandB
Download Pre-trained Weights:
- Hugging Face Model: YuvrajSingh9886/SmolMixtral
- WandB Checkpoints: Check the WandB report above for additional trained model checkpoints
Features
- Flash Attention: Efficient attention mechanism with memory optimization
- Mixture of Experts (MoE): 8 experts with top-2 routing and noisy top-k support
- SwiGLU Activation: Gated activation function used in the expert layers
- Rotary Positional Embeddings: Position encoding for sequence understanding
- Liger Kernels: Optimized kernels for faster training (optional)
- Distributed Training: Support for multi-GPU training with DDP
- Advanced Optimizer: AdamW optimizer with custom learning rate scheduling
- Gradio Interface: Interactive web interface for text generation
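To make the rotary positional embeddings bullet concrete, here is a minimal pure-Python sketch of the idea, not the code from model.py. The function name `apply_rope`, the base of 10000, and the choice to rotate consecutive pairs are assumptions for illustration; real implementations operate on tensors and may pair dimensions differently.

```python
import math

def apply_rope(vec, position, base=10000.0):
    """Rotate consecutive pairs of `vec` by angles that grow with `position`."""
    dim = len(vec)
    assert dim % 2 == 0, "head dimension must be even"
    out = []
    for i in range(0, dim, 2):
        # Each pair gets its own frequency; later pairs rotate more slowly.
        theta = position * base ** (-i / dim)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * cos_t - y * sin_t, x * sin_t + y * cos_t])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [1.0, 0.5, -0.3, 2.0]
k = [0.2, -1.0, 0.7, 0.1]
# Both pairs of positions are 3 apart, so the attention scores should match.
score_a = dot(apply_rope(q, 5), apply_rope(k, 2))
score_b = dot(apply_rope(q, 13), apply_rope(k, 10))
```

The useful property is that the query-key dot product depends only on the relative offset between the two positions, which is why the rotation is applied to queries and keys rather than added to the token embeddings.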
Model Architecture
Default Configuration
- Embedding Dimensions: 512
- Decoder Layers: 8
- Attention Heads: 8
- MoE Experts: 8 (top-2 routing)
- Block Size: 1024 tokens
- Vocabulary Size: Based on Llama-2-7b tokenizer (~32,000 tokens)
- Batch Size: 16
Full Parameter List
Model Architecture Parameters
- epochs: Number of training epochs (default: 4)
- block_size: Maximum sequence length (default: 1024)
- batch_size: Training batch size (default: 16)
- embeddings_dims: Model embedding dimensions (default: 512)
- no_of_heads: Number of attention heads (default: 8)
- no_of_decoder_layers: Number of decoder layers (default: 8)
- attn_dropout: Attention dropout rate (default: 0.1)
- dropout: General dropout rate (default: 0.1)
Mixture of Experts (MoE) Parameters
- experts: Number of MoE experts (default: 8)
- top_experts: Number of experts to route to (default: 2)
- noisy_topk: Use noisy top-k routing (default: False)
Training Hyperparameters
- max_lr: Maximum learning rate (default: 6e-4)
- weight_decay_optim: Weight decay for optimizer (default: 0.01)
- beta_1: Beta1 for optimizer (default: 0.9)
- beta_2: Beta2 for optimizer (default: 0.95)
- eps: Epsilon for optimizer (default: 1e-8)
- clip: Gradient clipping value (default: 1.0)
System Configuration
- device: Device to use (default: 'cuda:9')
- use_checkpointing: Use gradient checkpointing (default: False)
- use_liger: Use Liger kernels for optimization (default: True)
- use_flash_attention: Use Flash Attention (default: True)
- use_compile: Use torch.compile (default: True)
Data Configuration
- vocab_size: Vocabulary size (default: based on tokenizer + 768)
- val_epochs: Validation frequency (default: 2)
Quick Start
Installation
chmod +x install.sh
./install.sh
Important: Hugging Face Token Setup
Since this model uses the Llama-2 tokenizer, you'll need a Hugging Face token to access the gated model.
- Get a Hugging Face Token:
  - Go to Hugging Face Settings
  - Create a new token with "Read" permissions
  - Accept the Llama-2 license at meta-llama/Llama-2-7b-hf
- Set your token in config.py:
TOKEN = 'your_token_here'
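If you would rather not hard-code the token, one option is to read it from the environment and fall back to the value in config.py. The helper below is a hypothetical sketch, not part of the repository; it assumes the environment variable is named HF_TOKEN, as in the Gradio instructions.

```python
import os

PLACEHOLDER = "your_token_here"

def resolve_hf_token(config_token=PLACEHOLDER):
    """Prefer HF_TOKEN from the environment, else use the config.py value."""
    token = os.environ.get("HF_TOKEN", "").strip() or config_token
    if not token or token == PLACEHOLDER:
        raise RuntimeError("No Hugging Face token found: export HF_TOKEN or edit config.py")
    return token

# In config.py this could replace the literal assignment:
# TOKEN = resolve_hf_token()
```

Failing early with a clear message is easier to debug than the 401 error that appears later when the tokenizer download is rejected.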
Using Pre-trained Weights
- Download Model Weights:
  - Option 1: Download from Hugging Face - YuvrajSingh9886/SmolMixtral
  - Option 2: Visit the WandB Training Report for additional checkpoints
  - Place downloaded files in the checkpoints/ directory
- Load Pre-trained Model for Inference:
# Using the Gradio web interface
cd gradio
python app.py
# Or use in your own code
python inference.py
Training Examples
Basic Training (Single GPU)
python trainer.py
Training with Custom Parameters
# Train with larger model (modify config.py)
python trainer.py
# Train with different dataset (modify data.py)
python trainer.py
Multi-GPU Distributed Training
# 2 GPUs
torchrun --nproc_per_node=2 trainer.py
# 4 GPUs
torchrun --nproc_per_node=4 trainer.py
# 8 GPUs
torchrun --nproc_per_node=8 trainer.py
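Adding GPUs changes how many tokens go into each optimizer step, which matters when comparing runs. The arithmetic below is a rough sketch that assumes `batch_size` is per process (the usual DDP setup) and uses a hypothetical `grad_accum_steps` argument for gradient accumulation.

```python
def tokens_per_optimizer_step(batch_size=16, block_size=1024, world_size=1, grad_accum_steps=1):
    """Tokens consumed across all GPUs before one optimizer update."""
    return batch_size * block_size * world_size * grad_accum_steps

for gpus in (1, 2, 4, 8):
    print(f"{gpus} GPU(s): {tokens_per_optimizer_step(world_size=gpus)} tokens per step")
```

With the default configuration a single GPU processes 16,384 tokens per step and eight GPUs process 131,072, so the learning rate or accumulation steps may need adjusting when the GPU count changes.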
Inference with Gradio
The Gradio interface needs your Hugging Face token. Set it in config.py and also export it as an environment variable:
export HF_TOKEN=<TOKEN_HERE>
# Run the Gradio app
cd gradio
python app.py
# With custom checkpoint (edit app.py to point to your checkpoint)
cd gradio
python app.py
File Structure
SmolMixtral/
├── config.py # Model configuration and hyperparameters
├── model.py # Model architecture (Mixtral, MoE, Attention, etc.)
├── data.py # Data loading and preparation
├── inference.py # Inference functions and text generation
├── trainer.py # Main training loop with DDP support
├── install.sh # Setup script
├── requirements.txt # Python dependencies
├── model_summary.py # Model architecture summary
├── gradio/
│   └── app.py # Gradio web interface
├── checkpoints/ # Model checkpoints
├── generated_data/ # Generated text outputs
├── images/ # Project images
└── old/ # Original files
Training Features
- Gradient Accumulation: Configurable batch size scaling
- Learning Rate Scheduling: Cosine decay with warmup
- Gradient Clipping: Prevents gradient explosion
- Wandb Integration: Experiment tracking and logging
- Checkpointing: Regular model checkpoints during training
- Loss Calculation: Optimized cross-entropy with padding token handling
- Distributed Training: Multi-GPU support with DDP
- Memory Optimization: Gradient checkpointing support
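The "cosine decay with warmup" schedule listed above can be sketched as a plain function of the step number. Only `max_lr` (6e-4) and the 14,000-step run length come from this README; `min_lr`, `warmup_steps`, and the linear warmup shape are assumptions for illustration, so check trainer.py for the values actually used.

```python
import math

def lr_at(step, max_lr=6e-4, min_lr=6e-5, warmup_steps=700, total_steps=14000):
    """Learning rate at a given step: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Ramp up linearly so the first updates are small.
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    # Fraction of the decay phase completed, from 0.0 to 1.0.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

schedule = [lr_at(s) for s in range(0, 14001, 1000)]
```

The rate peaks at `max_lr` once warmup ends and then falls smoothly toward `min_lr` by the final step.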
Generation Methods
- Top-k Sampling: Traditional sampling with temperature control
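As a rough illustration of top-k sampling with temperature, the sketch below picks the next token id from a list of logits. It is not the code in inference.py; the function name, the default `k`, and the `rng` argument are hypothetical.

```python
import math
import random

def sample_top_k(logits, k=50, temperature=1.0, rng=random):
    """Sample one token id from the k highest-scoring logits."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    k = min(k, len(logits))
    # Indices of the k largest logits; every other token is excluded.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Lower temperatures sharpen the distribution, higher ones flatten it.
    scaled = [logits[i] / temperature for i in top]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]

next_token = sample_top_k([0.1, 5.0, 0.2, 4.0], k=2, temperature=0.8)
```

Setting `k=1` reduces this to greedy decoding, while a larger `k` with a temperature near 1.0 gives more varied text.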
Advanced Usage
Configuration
All parameters can be configured by modifying config.py:
@dataclass
class ModelArgs:
    epochs = 4
    block_size = 1024
    batch_size = 16
    embeddings_dims = 512
    # ... other parameters
Custom Dataset Training
Modify data.py to use different datasets:
# TinyStories (default)
tinystories = True
fw = False
# FineWeb
tinystories = False
fw = True
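Because the two flags are independent booleans, it is easy to leave both on or both off by accident. A small guard like the following could catch that. This is a hypothetical sketch, not code from data.py, and the Hub dataset ids are assumptions about which copies of TinyStories and FineWeb are intended.

```python
TINYSTORIES_ID = "roneneldan/TinyStories"  # assumed Hub id for TinyStories
FINEWEB_ID = "HuggingFaceFW/fineweb"       # assumed Hub id for FineWeb

def select_dataset(tinystories, fw):
    """Return the dataset id for the enabled flag; exactly one must be True."""
    if tinystories == fw:
        raise ValueError("enable exactly one of `tinystories` and `fw`")
    return TINYSTORIES_ID if tinystories else FINEWEB_ID

dataset_id = select_dataset(tinystories=True, fw=False)
```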
Monitoring and Logging
Training automatically logs to WandB with project name "Mixtral-DDP-Pretrain-10-billion-tokens"
Performance Tips
- Use Liger Kernels: Keep use_liger = True for optimized operations
- Flash Attention: Keep use_flash_attention = True for memory efficiency
- Gradient Checkpointing: Use use_checkpointing = True for memory-constrained setups
- Batch Size Tuning: Start with smaller batch sizes and increase gradually
- Block Size: Larger block sizes improve quality but require more memory
Troubleshooting
Common Issues
Authentication Error (401)
# Make sure you have accepted the Llama-2 license and have a valid token
# Visit: https://huggingface.co/meta-llama/Llama-2-7b-hf
# Then set your token in config.py
Out of Memory (OOM)
# Reduce batch size and enable checkpointing in config.py
batch_size = 8
use_checkpointing = True
Slow Training
# Enable optimizations in config.py
use_liger = True
use_flash_attention = True
use_compile = True
Contributing
Feel free to contribute improvements, bug fixes, or new features!
Requirements
- Python 3.8+
- PyTorch 2.0+
- Transformers
- Datasets
- Gradio
- Wandb
- Liger-kernel (optional)
License
MIT License
Source Code
GitHub Repository: Smol Mixtral
View the complete implementation, documentation, and examples on GitHub.
Interactive Features
Web Interface: This project includes a Gradio-based web interface for easy interaction and experimentation.
User-Friendly: Simple, intuitive interface perfect for testing and learning.
This project is part of the SmolHub Playground collection - a space for experimental AI models and proof-of-concept implementations.