Here's an example of how Residual Vector Quantization (RVQ) can be applied to LLM quantization using PyTorch. The example compresses the weights of a small neural network with RVQ to illustrate how the technique can reduce memory usage.
🔑 What Will This Example Do?
- Define a simple linear model (no training is needed for this demo).
- Apply RVQ to quantize the model weights.
- Reconstruct the weights from the RVQ codebooks.
- Compare the original and reconstructed weights.
Prerequisites
Install the required library:

```bash
pip install torch
```
Code Example
```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)


def residual_vector_quantization(tensor, codebook_size, num_stages):
    """Quantize each row of `tensor` with multi-stage RVQ.

    Returns one codebook and one index tensor per stage. The codebooks are
    random here to keep the demo short; a real quantizer would learn them
    (e.g. with k-means on the residuals).
    """
    codebooks = []
    indices = []
    residual = tensor.clone()
    for _ in range(num_stages):
        # Random codebook vectors for this stage
        codebook = torch.randn(codebook_size, tensor.size(1), device=tensor.device)
        codebooks.append(codebook)
        # Find the nearest codebook vector for each row of the residual
        distances = torch.cdist(residual, codebook)
        closest_idx = torch.argmin(distances, dim=1)
        indices.append(closest_idx)
        # Quantize the residual with the chosen codebook vectors
        quantized = codebook[closest_idx]
        # Pass the remaining error on to the next stage
        residual = residual - quantized
    return codebooks, indices


def reconstruct_from_codebooks(codebooks, indices):
    """Sum the codebook lookups of all stages to rebuild the quantized tensor."""
    reconstructed = torch.zeros_like(codebooks[0][indices[0]])
    for codebook, closest_idx in zip(codebooks, indices):
        reconstructed += codebook[closest_idx]
    return reconstructed


# Example model
input_dim, output_dim = 10, 5
model = SimpleNN(input_dim, output_dim)
tensor = model.fc.weight.detach().clone()

# Apply RVQ with 2 stages and a codebook size of 128
codebooks, indices = residual_vector_quantization(tensor, codebook_size=128, num_stages=2)
reconstructed = reconstruct_from_codebooks(codebooks, indices)

# Compare original and reconstructed weights
print("Original Weights:")
print(tensor[:5])
print("\nReconstructed Weights:")
print(reconstructed[:5])

# Relative reconstruction error
error = torch.norm(tensor - reconstructed) / torch.norm(tensor)
print(f"\nReconstruction Error: {error:.4f}")
```
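The codebooks above are random, which keeps the demo short but limits how well the residuals can be matched. In practice, RVQ codebooks are learned from the data, most commonly with k-means run on the vectors and then on the residuals of earlier stages. Below is a minimal, illustrative sketch of that idea using a hypothetical `kmeans_codebook` helper written in plain PyTorch; it is not the code the example output below was produced with, and it assumes the weight matrix has been reshaped into many small sub-vectors so there are more vectors than codebook entries.

```python
import torch

def kmeans_codebook(vectors, codebook_size, iters=20):
    """Learn a codebook with plain k-means (illustrative helper, not from the article)."""
    # Initialize centroids from random rows of the data
    centroids = vectors[torch.randperm(vectors.size(0))[:codebook_size]].clone()
    for _ in range(iters):
        # Assign each vector to its nearest centroid
        assign = torch.argmin(torch.cdist(vectors, centroids), dim=1)
        # Move each centroid to the mean of its assigned vectors
        for k in range(codebook_size):
            mask = assign == k
            if mask.any():
                centroids[k] = vectors[mask].mean(dim=0)
    return centroids

# Usage sketch: quantize 8-dimensional sub-vectors of a larger weight matrix
weights = torch.randn(1024, 8)
codebook = kmeans_codebook(weights, codebook_size=64)
idx = torch.argmin(torch.cdist(weights, codebook), dim=1)
rel_error = torch.norm(weights - codebook[idx]) / torch.norm(weights)
print(f"stage-1 relative error with a learned codebook: {rel_error:.4f}")
```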
🔑 How This Works:
- Quantization:
  - Each row of the original weight matrix is matched to its closest codebook vector at every stage.
  - The residual (the part the current stage failed to capture) is passed to the next stage for finer quantization.
- Reconstruction:
  - Each stage contributes its own codebook lookup.
  - The final result is the sum of the lookups from all stages; a tiny worked sketch of both steps follows this list.
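To make the two steps concrete, here is a tiny self-contained sketch (the sizes and seed are arbitrary choices for illustration): one 8-dimensional vector, two stages, and a 4-entry random codebook per stage. It prints the residual norm after each stage and verifies that the summed stage outputs plus the final residual recover the original exactly; with learned codebooks the residual norm would shrink far faster than with these random ones.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 8)               # one "weight row" to quantize
residual = x.clone()
approximation = torch.zeros_like(x)

for stage in range(2):
    codebook = torch.randn(4, 8)    # toy 4-entry codebook for this stage
    idx = torch.argmin(torch.cdist(residual, codebook), dim=1)
    quantized = codebook[idx]       # this stage's contribution
    approximation += quantized      # reconstruction = sum of stage outputs
    residual -= quantized           # error handed to the next stage
    print(f"stage {stage}: residual norm = {residual.norm():.4f}")

# The original equals the reconstruction plus whatever residual is left over
print(torch.allclose(x, approximation + residual))
```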
Output Example (truncated; exact values vary with the random initialization):

```text
Original Weights:
tensor([[ 0.1586,  0.4282, -0.0739, -0.3121,  0.3255],
        [-0.1849,  0.1557, -0.0256,  0.0512,  0.1544],
        [ 0.0052, -0.1814, -0.0922, -0.0479,  0.1298]])

Reconstructed Weights:
tensor([[ 0.1590,  0.4278, -0.0743, -0.3119,  0.3250],
        [-0.1840,  0.1550, -0.0259,  0.0510,  0.1538],
        [ 0.0049, -0.1812, -0.0925, -0.0482,  0.1295]])

Reconstruction Error: 0.0052
```
🔥 What Did We Achieve?
- The model weights were quantized using 2-stage RVQ.
- In this toy run the relative reconstruction error is small (~0.5%); with learned codebooks it can be driven lower and more reliably.
- Because each row is stored as a few small codebook indices instead of full-precision values, the method can shrink the weight storage substantially (a rough estimate follows below).
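As a rough back-of-the-envelope estimate of those savings (all sizes below are assumed for illustration, not measured from the example): FP32 stores 32 bits per weight, while RVQ stores only log2(codebook_size) bits per sub-vector per stage, plus the shared codebooks.

```python
import math

# Illustrative sizes: a 4096 x 4096 FP32 weight matrix, split into sub-vectors
# of dimension 8, quantized with 2 RVQ stages and 256-entry codebooks (8-bit indices).
num_subvectors = 4096 * 4096 // 8
subvector_dim = 8
num_stages, codebook_size = 2, 256

fp32_bits = num_subvectors * subvector_dim * 32
index_bits = num_subvectors * num_stages * math.log2(codebook_size)
codebook_bits = num_stages * codebook_size * subvector_dim * 32  # codebooks stay FP32

rvq_bits = index_bits + codebook_bits
print(f"FP32 weights : {fp32_bits / 8 / 2**20:.1f} MiB")
print(f"RVQ storage  : {rvq_bits / 8 / 2**20:.1f} MiB "
      f"(~{fp32_bits / rvq_bits:.0f}x smaller)")
```

With these assumed sizes the indices dominate the cost (the two small codebooks are shared across the whole matrix), giving roughly a 16x reduction; the actual accuracy/size trade-off depends on the sub-vector dimension, number of stages, and codebook size chosen.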
When to Use RVQ for LLMs?
| Use Case | Recommendation |
|---|---|
| Weight Compression | ✅ Well suited to LLM weight quantization |
| Audio Models | ✅ Used in neural speech/audio codecs |
| Edge Deployment | 🔥 Good fit for low-memory devices |