Story Mixtral
Published:
Project Overview
A PyTorch implementation of a Mixtral-inspired transformer model with Mixture of Experts (MoE), Flash Attention, and other advanced features.
Technical Details
- Type: SmolHub Playground Project
- Framework: PyTorch
- Category: Experimental AI/ML
- Repository: Story Mixtral
StoryMixtral - Mixtral Inspired Model
Examples
Provided under the generated_data/ directory, these examples showcase the model's capabilities in text generation and understanding.
Training Results & Model Weights
View Training Report: StoryMixtral Training Results on WandB
Download Pre-trained Weights:
- Hugging Face Model: YuvrajSingh9886/StoryMixtral
- WandB Checkpoints: Check the WandB report above for additional trained model checkpoints
Features
- Flash Attention: Efficient attention mechanism with memory optimization
- Mixture of Experts (MoE): 8 experts with top-2 routing and noisy top-k support (a minimal routing sketch follows this list)
- SwiGLU Activation: Gated activation function used in the expert feed-forward layers
- Rotary Positional Embeddings: Position encoding for sequence understanding
- Liger Kernels: Optimized kernels for faster training (optional)
- Distributed Training: Support for multi-GPU training with DDP
- Advanced Optimizer: AdamW optimizer with custom learning rate scheduling
- Gradio Interface: Interactive web interface for text generation
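For intuition, top-2 routing scores each token with a small router, keeps the two highest-scoring experts, and mixes their SwiGLU outputs using the renormalised router weights; noisy top-k simply perturbs the router logits before selection. The sketch below illustrates the idea only; it is not the project's model.py code, and the expert hidden width and module names are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Minimal sketch of top-k expert routing; layer sizes and names are illustrative."""

    def __init__(self, dim=512, num_experts=8, top_k=2, hidden=2048, noisy_topk=False):
        super().__init__()
        self.top_k, self.noisy_topk = top_k, noisy_topk
        self.router = nn.Linear(dim, num_experts, bias=False)
        # SwiGLU-style experts: gate / up / down projections
        self.experts = nn.ModuleList([
            nn.ModuleDict({
                "gate": nn.Linear(dim, hidden, bias=False),
                "up":   nn.Linear(dim, hidden, bias=False),
                "down": nn.Linear(hidden, dim, bias=False),
            })
            for _ in range(num_experts)
        ])

    def forward(self, x):                                   # x: (batch, seq, dim)
        logits = self.router(x)                             # (batch, seq, num_experts)
        if self.noisy_topk:                                 # optionally perturb routing logits
            logits = logits + torch.randn_like(logits)
        weights, chosen = logits.topk(self.top_k, dim=-1)   # keep the top-2 experts per token
        weights = F.softmax(weights, dim=-1)                # renormalise over the chosen experts
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = chosen[..., slot] == e               # tokens routed to expert e in this slot
                if mask.any():
                    h = x[mask]                             # (n_tokens, dim)
                    y = expert["down"](F.silu(expert["gate"](h)) * expert["up"](h))
                    out[mask] += weights[..., slot][mask].unsqueeze(-1) * y
        return out

moe = TopKMoE()
print(moe(torch.randn(2, 16, 512)).shape)                   # torch.Size([2, 16, 512])
```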
Model Architecture
Default Configuration
- Embedding Dimensions: 512
- Decoder Layers: 8
- Attention Heads: 8
- MoE Experts: 8 (top-2 routing)
- Block Size: 1024 tokens
- Vocabulary Size: Based on Llama-2-7b tokenizer (~32,000 tokens)
- Batch Size: 16
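As a back-of-envelope check on model size, the defaults above imply a parameter count on the order of a couple hundred million. The estimate below is only a sketch: the SwiGLU expert width and the untied output head are assumptions, so the real count may differ.

```python
# Rough parameter estimate for the default configuration (assumptions noted inline)
d, layers, experts, vocab = 512, 8, 8, 32_768       # vocab: ~32k tokenizer + 768 (see Data Configuration)
ffn_hidden = 4 * d                                   # assumed SwiGLU expert width

embedding = vocab * d                                # token embedding table
attention = layers * 4 * d * d                       # Q, K, V, O projections per layer
router    = layers * d * experts                     # MoE gating network per layer
moe_ffn   = layers * experts * 3 * d * ffn_hidden    # gate/up/down projections per expert
lm_head   = vocab * d                                # assumed untied output projection

total = embedding + attention + router + moe_ffn + lm_head
print(f"~{total / 1e6:.0f}M parameters")             # roughly 240M under these assumptions
```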
Full Parameter List
Model Architecture Parameters
- epochs: Number of training epochs (default: 4)
- block_size: Maximum sequence length (default: 1024)
- batch_size: Training batch size (default: 16)
- embeddings_dims: Model embedding dimensions (default: 512)
- no_of_heads: Number of attention heads (default: 8)
- no_of_decoder_layers: Number of decoder layers (default: 8)
- attn_dropout: Attention dropout rate (default: 0.1)
- dropout: General dropout rate (default: 0.1)
Mixture of Experts (MoE) Parameters
- experts: Number of MoE experts (default: 8)
- top_experts: Number of experts to route to (default: 2)
- noisy_topk: Use noisy top-k routing (default: False)
Training Hyperparameters
- max_lr: Maximum learning rate (default: 6e-4)
- weight_decay_optim: Weight decay for optimizer (default: 0.01)
- beta_1: Beta1 for optimizer (default: 0.9)
- beta_2: Beta2 for optimizer (default: 0.95)
- eps: Epsilon for optimizer (default: 1e-8)
- clip: Gradient clipping value (default: 1.0)
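A minimal sketch of how these values would typically be wired into AdamW with warmup, cosine decay, and gradient clipping; the warmup and total step counts here are assumptions, not values taken from trainer.py.

```python
import math
import torch

max_lr, weight_decay, betas, eps, clip = 6e-4, 0.01, (0.9, 0.95), 1e-8, 1.0
warmup_steps, total_steps = 1_000, 100_000            # assumed schedule lengths

model = torch.nn.Linear(512, 512)                     # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr,
                              betas=betas, eps=eps, weight_decay=weight_decay)

def lr_lambda(step):
    # linear warmup, then cosine decay towards zero
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# one illustrative optimisation step
loss = model(torch.randn(4, 512)).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)   # clip = 1.0
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```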
System Configuration
- device: Device to use (default: 'cuda:9')
- use_checkpointing: Use gradient checkpointing (default: False)
- use_liger: Use Liger kernels for optimization (default: True)
- use_flash_attention: Use Flash Attention (default: True)
- use_compile: Use torch.compile (default: True)
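Roughly, these flags map onto the following PyTorch features; how trainer.py wires them internally is an assumption, and Liger kernels are not shown because their integration is library-specific.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

device = "cuda:0" if torch.cuda.is_available() else "cpu"   # config default is 'cuda:9'

# use_flash_attention: the fused SDPA kernel dispatches to Flash Attention when available
q = k = v = torch.randn(1, 8, 16, 64, device=device)        # (batch, heads, seq, head_dim)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# use_compile: compile the model for kernel fusion and graph-level optimisation
block = torch.nn.Sequential(torch.nn.Linear(512, 2048), torch.nn.GELU(),
                            torch.nn.Linear(2048, 512)).to(device)
compiled_block = torch.compile(block)

# use_checkpointing: recompute activations in backward to trade compute for memory
x = torch.randn(4, 512, device=device, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
```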
Data Configuration
- vocab_size: Vocabulary size (default: based on tokenizer + 768)
- val_epochs: Validation frequency (default: 2)
Quick Start
Installation
chmod +x install.sh
./install.sh
Important: Hugging Face Token Setup
Since this model uses the Llama-2 tokenizer, you'll need a Hugging Face token to access the gated model.
- Get a Hugging Face Token:
  - Go to Hugging Face Settings
  - Create a new token with "Read" permissions
  - Accept the Llama-2 license at meta-llama/Llama-2-7b-hf
- Set your token in config.py:
TOKEN = 'your_token_here'
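With the token in place, loading the gated Llama-2 tokenizer looks roughly like this (a sketch; how data.py actually consumes the token may differ slightly):

```python
from transformers import AutoTokenizer

from config import TOKEN  # the value you set above

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=TOKEN)
tokenizer.pad_token = tokenizer.eos_token      # Llama-2 ships without a pad token
print(len(tokenizer))                          # ~32,000; vocab_size adds 768 on top (see above)
```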
Using Pre-trained Weights
- Download Model Weights:
  - Option 1: Download from Hugging Face - YuvrajSingh9886/StoryMixtral
  - Option 2: Visit the WandB Training Report for additional checkpoints
- Place downloaded files in the checkpoints/ directory
- Load Pre-trained Model for Inference:
# Using the Gradio web interface
cd gradio
python app.py
# Or use in your own code
python inference.py
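To script the download instead of fetching files by hand, the huggingface_hub client can mirror the whole model repo into checkpoints/ (a sketch; no assumptions are made here about the checkpoint filenames inside the repo):

```python
from huggingface_hub import snapshot_download

# Pull every file from the model repo into the local checkpoints/ directory
snapshot_download(repo_id="YuvrajSingh9886/StoryMixtral", local_dir="checkpoints")
```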
Training Examples
Basic Training (Single GPU)
python trainer.py
Training with Custom Parameters
# Train with larger model (modify config.py)
python trainer.py
# Train with different dataset (modify data.py)
python trainer.py
Multi-GPU Distributed Training
# 2 GPUs
torchrun --nproc_per_node=2 trainer.py
# 4 GPUs
torchrun --nproc_per_node=4 trainer.py
# 8 GPUs
torchrun --nproc_per_node=8 trainer.py
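Under torchrun, each spawned process follows the standard DDP pattern, roughly as below; the details of trainer.py may differ, and the model here is a stand-in.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets RANK / LOCAL_RANK / WORLD_SIZE for every process it launches
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(512, 512).cuda(local_rank)   # stand-in for the real model
model = DDP(model, device_ids=[local_rank])

# ... training loop goes here: each rank reads a different shard (DistributedSampler)
# and gradients are all-reduced automatically during backward()

dist.destroy_process_group()
```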
Inference with Gradio
HF_TOKEN should be set in config.py to use the Gradio interface. In addition, export the token in your shell:
export HF_TOKEN=<TOKEN_HERE>
# Run the Gradio app
cd gradio
python app.py
# With custom checkpoint (edit app.py to point to your checkpoint)
cd gradio
python app.py
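For reference, the Gradio app reduces to a pattern like the one below; the real app.py loads the StoryMixtral checkpoint and exposes more generation controls, so treat this as a sketch with a placeholder generate function.

```python
import gradio as gr

def generate(prompt, max_new_tokens=200):
    # placeholder: the real app calls the model's top-k sampling loop here
    return prompt + " ..."

demo = gr.Interface(
    fn=generate,
    inputs=[gr.Textbox(label="Prompt"),
            gr.Slider(16, 1024, value=200, step=1, label="Max new tokens")],
    outputs=gr.Textbox(label="Generated story"),
    title="StoryMixtral",
)
demo.launch()
```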
File Structure
StoryMixtral/
├── config.py           # Model configuration and hyperparameters
├── model.py            # Model architecture (Mixtral, MoE, Attention, etc.)
├── data.py             # Data loading and preparation
├── inference.py        # Inference functions and text generation
├── trainer.py          # Main training loop with DDP support
├── install.sh          # Setup script
├── requirements.txt    # Python dependencies
├── model_summary.py    # Model architecture summary
├── gradio/
│   └── app.py          # Gradio web interface
├── checkpoints/        # Model checkpoints
├── generated_data/     # Generated text outputs
├── images/             # Project images
└── old/                # Original files
Training Features
- Gradient Accumulation: Configurable batch size scaling
- Learning Rate Scheduling: Cosine decay with warmup
- Gradient Clipping: Prevents gradient explosion
- Wandb Integration: Experiment tracking and logging
- Checkpointing: Regular model checkpoints during training
- Loss Calculation: Optimized cross-entropy with padding-token handling (sketched together with gradient accumulation after this list)
- Distributed Training: Multi-GPU support with DDP
- Memory Optimization: Gradient checkpointing support
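Two of those features in one minimal sketch: cross-entropy that ignores padding tokens, plus gradient accumulation with clipping. The accumulation step count, pad-token id, and tensor sizes are assumptions chosen to keep the example small.

```python
import torch
import torch.nn.functional as F

accum_steps, clip, pad_id, vocab = 4, 1.0, 0, 1000

model = torch.nn.Linear(64, vocab)                        # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)

for step in range(8):                                     # stand-in training loop
    hidden = torch.randn(2, 128, 64)                      # (batch, seq, dim) activations
    targets = torch.randint(0, vocab, (2, 128))           # some positions would hold pad_id
    logits = model(hidden)                                # (batch, seq, vocab)
    loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1),
                           ignore_index=pad_id)           # padding tokens contribute no loss
    (loss / accum_steps).backward()                       # scale so accumulated grads average out
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        optimizer.zero_grad()
```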
Generation Methods
- Top-k Sampling: Traditional sampling with temperature control
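At generation time, top-k sampling with temperature reduces to roughly this loop (a sketch; inference.py contains the project's actual implementation, and the assumption here is that the model returns (batch, seq, vocab) logits):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_top_k(model, ids, max_new_tokens=200, temperature=0.8, top_k=50, block_size=1024):
    for _ in range(max_new_tokens):
        ctx = ids[:, -block_size:]                             # crop to the context window
        logits = model(ctx)[:, -1, :] / temperature            # last-position logits, scaled
        kth = torch.topk(logits, top_k).values[:, [-1]]        # k-th largest logit per row
        logits = logits.masked_fill(logits < kth, float("-inf"))  # drop everything outside top-k
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)
    return ids
```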
Advanced Usage
Configuration
All parameters can be configured by modifying config.py:
from dataclasses import dataclass

@dataclass
class ModelArgs:
    epochs = 4
    block_size = 1024
    batch_size = 16
    embeddings_dims = 512
    # ... other parameters
Custom Dataset Training
Modify data.py to use different datasets:
# TinyStories (default)
tinystories = True
fw = False
# FineWeb
tinystories = False
fw = True
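Both flags correspond to public Hugging Face datasets; loading them looks roughly like this. The dataset ids below are the community TinyStories and FineWeb repos and are an assumption about what data.py actually points at.

```python
from datasets import load_dataset

tinystories, fw = True, False          # mirror the flags in data.py

if tinystories:
    dataset = load_dataset("roneneldan/TinyStories", split="train")
elif fw:
    # FineWeb is huge, so stream a sampled subset rather than downloading everything
    dataset = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT",
                           split="train", streaming=True)
```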
Monitoring and Logging
Training automatically logs to WandB with the project name "Mixtral-DDP-Pretrain-10-billion-tokens".
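A sketch of the corresponding logging calls (the metric names and values here are illustrative):

```python
import wandb

wandb.init(project="Mixtral-DDP-Pretrain-10-billion-tokens")
wandb.log({"train/loss": 2.31, "val/loss": 2.45, "lr": 6e-4})  # illustrative values
wandb.finish()
```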
Performance Tips
- Use Liger Kernels: Keep use_liger = True for optimized operations
- Flash Attention: Keep use_flash_attention = True for memory efficiency
- Gradient Checkpointing: Use use_checkpointing = True for memory-constrained setups
- Batch Size Tuning: Start with smaller batch sizes and increase gradually
- Block Size: Larger block sizes improve quality but require more memory
Troubleshooting
Common Issues
Authentication Error (401)
# Make sure you have accepted the Llama-2 license and have a valid token
# Visit: https://huggingface.co/meta-llama/Llama-2-7b-hf
# Then set your token in config.py
Out of Memory (OOM)
# Reduce batch size and enable checkpointing in config.py
batch_size = 8
use_checkpointing = True
Slow Training
# Enable optimizations in config.py
use_liger = True
use_flash_attention = True
use_compile = True
Contributing
Feel free to contribute improvements, bug fixes, or new features!
Requirements
- Python 3.8+
- PyTorch 2.0+
- Transformers
- Datasets
- Gradio
- Wandb
- Liger-kernel (optional)
License
MIT License
Source Code
GitHub Repository: Story Mixtral
View the complete implementation, documentation, and examples on GitHub.
Interactive Features
Web Interface: This project includes a Gradio-based web interface for easy interaction and experimentation.
User-Friendly: Simple, intuitive interface perfect for testing and learning.
This project is part of the SmolHub Playground collection - a space for experimental AI models and proof-of-concept implementations.