Enable native ModelOpt quantization support (3/3) (#10154)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
@@ -110,6 +110,157 @@ python3 -m sglang.launch_server \
|
|||||||
--port 30000 --host 0.0.0.0
|
--port 30000 --host 0.0.0.0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Using [NVIDIA ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
|
||||||
|
|
||||||
|
NVIDIA Model Optimizer (ModelOpt) provides advanced quantization techniques optimized for NVIDIA hardware. SGLang includes a streamlined workflow for quantizing models with ModelOpt and automatically exporting them for deployment.
|
||||||
|
|
||||||
|
##### Installation
|
||||||
|
|
||||||
|
First, install ModelOpt. You can either install it directly or as an optional SGLang dependency:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Option 1: Install ModelOpt directly
|
||||||
|
pip install nvidia-modelopt
|
||||||
|
|
||||||
|
# Option 2: Install SGLang with ModelOpt support (recommended)
|
||||||
|
pip install sglang[modelopt]
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Quantization and Export Workflow
|
||||||
|
|
||||||
|
SGLang provides an example script that demonstrates the complete ModelOpt quantization and export workflow:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Quantize and export a model using ModelOpt FP8 quantization
|
||||||
|
python examples/usage/modelopt_quantize_and_export.py quantize \
|
||||||
|
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
|
||||||
|
--export-dir ./quantized_tinyllama_fp8 \
|
||||||
|
--quantization-method modelopt_fp8
|
||||||
|
|
||||||
|
# For FP4 quantization
|
||||||
|
python examples/usage/modelopt_quantize_and_export.py quantize \
|
||||||
|
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
|
||||||
|
--export-dir ./quantized_tinyllama_fp4 \
|
||||||
|
--quantization-method modelopt_fp4
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Available Quantization Methods
|
||||||
|
|
||||||
|
- `modelopt_fp8`: FP8 quantization with optimal performance on NVIDIA Hopper and Blackwell GPUs
|
||||||
|
- `modelopt_fp4`: FP4 quantization with optimal performance on NVIDIA Blackwell GPUs
|
||||||
|
|
||||||
|
##### Python API Usage
|
||||||
|
|
||||||
|
You can also use ModelOpt quantization programmatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.model_loader.loader import get_model_loader
|
||||||
|
|
||||||
|
# Configure model with ModelOpt quantization and export
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
quantization="modelopt_fp8", # or "modelopt_fp4"
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
load_config = LoadConfig(
|
||||||
|
modelopt_export_path="./exported_model",
|
||||||
|
modelopt_checkpoint_save_path="./checkpoint.pth", # optional, fake quantized checkpoint
|
||||||
|
)
|
||||||
|
device_config = DeviceConfig(device="cuda")
|
||||||
|
|
||||||
|
# Load and quantize the model (export happens automatically)
|
||||||
|
model_loader = get_model_loader(load_config, model_config)
|
||||||
|
quantized_model = model_loader.load_model(
|
||||||
|
model_config=model_config,
|
||||||
|
device_config=device_config,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Deploying Quantized Models
|
||||||
|
|
||||||
|
After quantization and export, you can deploy the model with SGLang:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Deploy the exported quantized model
|
||||||
|
python -m sglang.launch_server \
|
||||||
|
--model-path ./quantized_tinyllama_fp8 \
|
||||||
|
--quantization modelopt \
|
||||||
|
--port 30000 --host 0.0.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
Or using the Python API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
# Deploy exported ModelOpt quantized model
|
||||||
|
llm = sgl.Engine(
|
||||||
|
model_path="./quantized_tinyllama_fp8",
|
||||||
|
quantization="modelopt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
prompts = ["Hello, how are you?", "What is the capital of France?"]
|
||||||
|
sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 100}
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for i, output in enumerate(outputs):
|
||||||
|
print(f"Prompt: {prompts[i]}")
|
||||||
|
print(f"Output: {output['text']}")
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Advanced Features
|
||||||
|
|
||||||
|
**Checkpoint Management**: Save and restore fake quantized checkpoints for reuse:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Save the fake quantized checkpoint during quantization
|
||||||
|
python examples/usage/modelopt_quantize_and_export.py quantize \
|
||||||
|
--model-path meta-llama/Llama-3.2-1B-Instruct \
|
||||||
|
--export-dir ./quantized_model \
|
||||||
|
--quantization-method modelopt_fp8 \
|
||||||
|
--checkpoint-save-path ./my_checkpoint.pth
|
||||||
|
|
||||||
|
# The saved checkpoint can be reused in future quantization runs to skip calibration
|
||||||
|
```
|
||||||
|
|
||||||
|
**Export-only Workflow**: If you have a pre-existing fake quantized ModelOpt checkpoint, you can export it directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.model_loader.loader import get_model_loader
|
||||||
|
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model_path="meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
quantization="modelopt_fp8",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
load_config = LoadConfig(
|
||||||
|
modelopt_checkpoint_restore_path="./my_checkpoint.pth",
|
||||||
|
modelopt_export_path="./exported_model",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load and export the model
|
||||||
|
model_loader = get_model_loader(load_config, model_config)
|
||||||
|
model_loader.load_model(model_config=model_config, device_config=DeviceConfig())
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Benefits of ModelOpt
|
||||||
|
|
||||||
|
- **Hardware Optimization**: Specifically optimized for NVIDIA GPU architectures
|
||||||
|
- **Advanced Quantization**: Supports cutting-edge FP8 and FP4 quantization techniques
|
||||||
|
- **Seamless Integration**: Automatic export to HuggingFace format for easy deployment
|
||||||
|
- **Calibration-based**: Uses calibration datasets for optimal quantization quality
|
||||||
|
- **Production Ready**: Enterprise-grade quantization with NVIDIA support
|
||||||
|
|
||||||
## Online Quantization
|
## Online Quantization
|
||||||
|
|
||||||
To enable online quantization, you can simply specify `--quantization` in the command line. For example, you can launch the server with the following command to enable `FP8` quantization for model `meta-llama/Meta-Llama-3.1-8B-Instruct`:
|
To enable online quantization, you can simply specify `--quantization` in the command line. For example, you can launch the server with the following command to enable `FP8` quantization for model `meta-llama/Meta-Llama-3.1-8B-Instruct`:
|
||||||
@@ -148,5 +299,6 @@ python3 -m sglang.launch_server \
|
|||||||
|
|
||||||
- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
|
- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
|
||||||
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
|
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
|
||||||
|
- [NVIDIA Model Optimizer (ModelOpt)](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
|
||||||
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
|
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
|
||||||
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
|
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
|
||||||
|
|||||||
303
examples/usage/modelopt_quantize_and_export.py
Executable file
303
examples/usage/modelopt_quantize_and_export.py
Executable file
@@ -0,0 +1,303 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Example: ModelOpt Quantization and Export with SGLang
|
||||||
|
|
||||||
|
This example demonstrates the streamlined workflow for quantizing a model with
|
||||||
|
ModelOpt and automatically exporting it for deployment with SGLang.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.distributed.parallel_state import (
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_loader.loader import get_model_loader
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_export(export_dir: str) -> bool:
    """Return True when *export_dir* looks like a complete exported model.

    A directory passes when it exists, contains both ``config.json`` and
    ``tokenizer_config.json``, and holds at least one weight file matching
    ``model*.safetensors`` or ``pytorch_model*.bin`` (the wildcards also
    cover sharded file names).
    """
    from glob import glob

    if not os.path.exists(export_dir):
        return False

    # Metadata files that every export must provide.
    metadata_names = ("config.json", "tokenizer_config.json")
    if not all(
        os.path.exists(os.path.join(export_dir, name)) for name in metadata_names
    ):
        return False

    # At least one weight file, single or sharded, in either format.
    weight_globs = ("model*.safetensors", "pytorch_model*.bin")
    return any(glob(os.path.join(export_dir, pat)) for pat in weight_globs)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_export_info(export_dir: str) -> Optional[dict]:
    """Summarize an exported model directory from its ``config.json``.

    Args:
        export_dir: Directory expected to hold an exported model.

    Returns:
        A dict with the keys ``model_type``, ``architectures``,
        ``quantization_config`` and ``export_dir``, or ``None`` when the
        directory fails ``_validate_export`` or its ``config.json`` cannot be
        read as a JSON object.
    """
    import json

    if not _validate_export(export_dir):
        return None

    config_path = os.path.join(export_dir, "config.json")
    try:
        # Read as UTF-8 explicitly instead of relying on the locale default.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # Unreadable file, malformed JSON or undecodable bytes
        # (json.JSONDecodeError and UnicodeDecodeError subclass ValueError).
        return None

    # A JSON array/scalar has no ``.get``; treat it as "no usable info".
    if not isinstance(config, dict):
        return None

    return {
        "model_type": config.get("model_type", "unknown"),
        "architectures": config.get("architectures", []),
        "quantization_config": config.get("quantization_config", {}),
        "export_dir": export_dir,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_and_export_model(
    model_path: str,
    export_dir: str,
    quantization_method: str = "modelopt_fp8",
    checkpoint_save_path: Optional[str] = None,
    device: str = "cuda",
) -> None:
    """Quantize a model with ModelOpt and export it for SGLang deployment.

    Progress and failures are reported on stdout; the function returns
    ``None`` in every case and does not re-raise loader errors.

    Args:
        model_path: Path to the original model.
        export_dir: Directory to export the quantized model to.
        quantization_method: Quantization method ("modelopt_fp8" or
            "modelopt_fp4").
        checkpoint_save_path: Optional path to save the ModelOpt checkpoint.
        device: Device to use for quantization.
    """
    for banner in (
        "🚀 Starting ModelOpt quantization and export workflow",
        f"📥 Input model: {model_path}",
        f"📤 Export directory: {export_dir}",
        f"⚙️ Quantization method: {quantization_method}",
    ):
        print(banner)

    # Set up a minimal single-process distributed environment unless one
    # has already been initialized.
    if not torch.distributed.is_initialized():
        print("🔧 Initializing distributed environment...")
        os.environ.update(
            {
                "RANK": "0",
                "WORLD_SIZE": "1",
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": "12355",  # Use a different port than tests
                "LOCAL_RANK": "0",
            }
        )

        dist_backend = "gloo"
        if device == "cuda":
            dist_backend = "nccl"

        init_distributed_environment(
            world_size=1,
            rank=0,
            local_rank=0,
            backend=dist_backend,
        )
        initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
        )

    # Model, load and device configuration for quantize-and-export.
    model_config = ModelConfig(
        model_path=model_path,
        quantization=quantization_method,  # Use unified quantization flag
        trust_remote_code=True,
    )
    load_config = LoadConfig(
        modelopt_checkpoint_save_path=checkpoint_save_path,
        modelopt_export_path=export_dir,
    )
    device_config = DeviceConfig(device=device)

    # Loading the model performs the quantization; export happens automatically.
    print("🔄 Loading and quantizing model...")
    loader = get_model_loader(load_config, model_config)

    try:
        loader.load_model(
            model_config=model_config,
            device_config=device_config,
        )
        print("✅ Model quantized successfully!")

        if not _validate_export(export_dir):
            print("❌ Export validation failed!")
            return

        print("✅ Export validation passed!")
        export_info = _get_export_info(export_dir)
        if export_info:
            print("📋 Model info:")
            print(f" - Type: {export_info['model_type']}")
            print(f" - Architecture: {export_info['architectures']}")
            print(f" - Quantization: {export_info['quantization_config']}")
    except Exception as exc:
        print(f"❌ Quantization failed: {exc}")
        return

    # Success summary plus usage hints for the exported model.
    print("\n🎉 Workflow completed successfully!")
    print(f"📁 Quantized model exported to: {export_dir}")
    print("\n🚀 To use the exported model:")
    print(
        f" python -m sglang.launch_server --model-path {export_dir} --quantization modelopt"
    )
    print("\n # Or in Python:")
    print(" import sglang as sgl")
    print(f" llm = sgl.Engine(model_path='{export_dir}', quantization='modelopt')")
    print(" # Note: 'modelopt' auto-detects FP4/FP8 from model config")
|
||||||
|
|
||||||
|
|
||||||
|
def deploy_exported_model(
    export_dir: str,
    host: str = "127.0.0.1",
    port: int = 30000,
) -> None:
    """Deploy an exported ModelOpt quantized model with SGLang.

    Validates the export directory, starts an ``sgl.Engine`` on it, runs two
    example prompts and prints the results. Failures are reported on stdout
    rather than raised. The engine is always shut down before returning so
    its resources are released.

    Args:
        export_dir: Directory containing the exported model.
        host: Host to bind the server to.
        port: Port to bind the server to.
    """
    print(f"🚀 Deploying exported model from: {export_dir}")

    # Validate export first
    if not _validate_export(export_dir):
        print("❌ Invalid export directory!")
        return

    llm = None
    try:
        # Launch SGLang engine with the exported model.
        # Using generic "modelopt" for auto-detection of FP4/FP8.
        llm = sgl.Engine(
            model_path=export_dir,
            quantization="modelopt",
            host=host,
            port=port,
        )

        print("✅ Model deployed successfully!")
        print(f"🌐 Server running at http://{host}:{port}")

        # Example inference
        prompts = ["Hello, how are you?", "What is the capital of France?"]
        sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 100}

        print("\n🧪 Running example inference...")
        outputs = llm.generate(prompts, sampling_params)

        for i, output in enumerate(outputs):
            print(f"Prompt {i+1}: {prompts[i]}")
            print(f"Output: {output['text']}")
            print()

    except Exception as e:
        print(f"❌ Deployment failed: {e}")
    finally:
        # Release the engine whether inference succeeded or failed; ``llm``
        # stays None when the constructor itself raised.
        if llm is not None:
            llm.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse arguments and run the chosen command.

    Two sub-commands are offered: ``quantize`` (quantize and export a model)
    and ``deploy`` (run an already exported model). With no sub-command the
    help text is printed.
    """
    usage_examples = """
Examples:
# Quantize and export a model (recommended workflow)
python modelopt_quantize_and_export.py quantize \\
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\
--export-dir ./quantized_model \\
--quantization-method modelopt_fp8

# Deploy a pre-exported model
python modelopt_quantize_and_export.py deploy \\
--export-dir ./quantized_model
"""
    cli = argparse.ArgumentParser(
        description="ModelOpt Quantization and Export with SGLang",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    commands = cli.add_subparsers(dest="command", help="Available commands")

    # "quantize" sub-command
    quantize_cmd = commands.add_parser("quantize", help="Quantize and export a model")
    quantize_cmd.add_argument(
        "--model-path", required=True, help="Path to the model to quantize"
    )
    quantize_cmd.add_argument(
        "--export-dir", required=True, help="Directory to export the quantized model"
    )
    quantize_cmd.add_argument(
        "--quantization-method",
        choices=["modelopt_fp8", "modelopt_fp4"],
        default="modelopt_fp8",
        help="Quantization method to use",
    )
    quantize_cmd.add_argument(
        "--checkpoint-save-path", help="Optional path to save ModelOpt checkpoint"
    )
    quantize_cmd.add_argument(
        "--device", default="cuda", help="Device to use for quantization"
    )

    # TODO: Quantize-and-serve command removed due to compatibility issues
    # Use the separate quantize-then-deploy workflow instead

    # "deploy" sub-command
    deploy_cmd = commands.add_parser("deploy", help="Deploy an exported model")
    deploy_cmd.add_argument(
        "--export-dir", required=True, help="Directory containing the exported model"
    )
    deploy_cmd.add_argument(
        "--host", default="127.0.0.1", help="Host to bind the server to"
    )
    deploy_cmd.add_argument(
        "--port", type=int, default=30000, help="Port to bind the server to"
    )

    parsed = cli.parse_args()

    if parsed.command == "quantize":
        quantize_and_export_model(
            model_path=parsed.model_path,
            export_dir=parsed.export_dir,
            quantization_method=parsed.quantization_method,
            checkpoint_save_path=parsed.checkpoint_save_path,
            device=parsed.device,
        )
        return

    if parsed.command == "deploy":
        deploy_exported_model(
            export_dir=parsed.export_dir,
            host=parsed.host,
            port=parsed.port,
        )
        return

    # No sub-command given: show usage.
    cli.print_help()


if __name__ == "__main__":
    main()
|
||||||
@@ -75,12 +75,7 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
tracing = [
|
modelopt = ["nvidia-modelopt"]
|
||||||
"opentelemetry-api",
|
|
||||||
"opentelemetry-exporter-otlp",
|
|
||||||
"opentelemetry-exporter-otlp-proto-grpc",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
]
|
|
||||||
test = [
|
test = [
|
||||||
"accelerate",
|
"accelerate",
|
||||||
"expecttest",
|
"expecttest",
|
||||||
@@ -107,6 +102,12 @@ cu130_all = [
|
|||||||
"sglang[decord]",
|
"sglang[decord]",
|
||||||
"sglang[cu130]"
|
"sglang[cu130]"
|
||||||
]
|
]
|
||||||
|
tracing = [
|
||||||
|
"opentelemetry-api",
|
||||||
|
"opentelemetry-exporter-otlp",
|
||||||
|
"opentelemetry-exporter-otlp-proto-grpc",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
]
|
||||||
|
|
||||||
# To be deprecated in 2 weeks
|
# To be deprecated in 2 weeks
|
||||||
blackwell = ["sglang[dev]"]
|
blackwell = ["sglang[dev]"]
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
|
from sglang.srt.configs.modelopt_config import ModelOptConfig
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -51,6 +52,11 @@ class LoadConfig:
|
|||||||
decryption_key_file: If set, decrypts the output files with a password read
|
decryption_key_file: If set, decrypts the output files with a password read
|
||||||
from this file (after PBKDF2).
|
from this file (after PBKDF2).
|
||||||
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
|
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
|
||||||
|
|
||||||
|
# ModelOpt-specific loading options
|
||||||
|
modelopt_checkpoint_restore_path: Path to restore a ModelOpt checkpoint from.
|
||||||
|
modelopt_checkpoint_save_path: Path to save the ModelOpt checkpoint to.
|
||||||
|
modelopt_export_path: Path to export the quantized model to in HuggingFace format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||||
@@ -64,6 +70,14 @@ class LoadConfig:
|
|||||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
||||||
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
||||||
|
|
||||||
|
# ModelOpt-specific loading options
|
||||||
|
modelopt_checkpoint_restore_path: Optional[str] = None
|
||||||
|
modelopt_checkpoint_save_path: Optional[str] = None
|
||||||
|
modelopt_export_path: Optional[str] = None
|
||||||
|
|
||||||
|
# ModelOpt configuration object
|
||||||
|
modelopt_config: Optional[ModelOptConfig] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||||
if isinstance(model_loader_extra_config, str):
|
if isinstance(model_loader_extra_config, str):
|
||||||
@@ -78,6 +92,14 @@ class LoadConfig:
|
|||||||
else:
|
else:
|
||||||
self.ignore_patterns = ["original/**/*"]
|
self.ignore_patterns = ["original/**/*"]
|
||||||
|
|
||||||
|
# Create ModelOptConfig if not provided
|
||||||
|
if self.modelopt_config is None:
|
||||||
|
self.modelopt_config = ModelOptConfig(
|
||||||
|
checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
|
||||||
|
checkpoint_save_path=self.modelopt_checkpoint_save_path,
|
||||||
|
export_path=self.modelopt_export_path,
|
||||||
|
)
|
||||||
|
|
||||||
def _verify_load_format(self) -> None:
|
def _verify_load_format(self) -> None:
|
||||||
if not isinstance(self.load_format, str):
|
if not isinstance(self.load_format, str):
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from enum import Enum, IntEnum, auto
|
from enum import Enum, IntEnum, auto
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, List, Optional, Set, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
@@ -89,7 +89,6 @@ class ModelConfig:
|
|||||||
enable_multimodal: Optional[bool] = None,
|
enable_multimodal: Optional[bool] = None,
|
||||||
dtype: str = "auto",
|
dtype: str = "auto",
|
||||||
quantization: Optional[str] = None,
|
quantization: Optional[str] = None,
|
||||||
modelopt_quant: Optional[Union[str, Dict]] = None,
|
|
||||||
override_config_file: Optional[str] = None,
|
override_config_file: Optional[str] = None,
|
||||||
is_draft_model: bool = False,
|
is_draft_model: bool = False,
|
||||||
hybrid_kvcache_ratio: Optional[
|
hybrid_kvcache_ratio: Optional[
|
||||||
@@ -97,15 +96,19 @@ class ModelConfig:
|
|||||||
] = None, # TODO: remove this, it is not a model config
|
] = None, # TODO: remove this, it is not a model config
|
||||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||||
sampling_defaults: str = "openai",
|
sampling_defaults: str = "openai",
|
||||||
|
quantize_and_serve: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
self.quantization = quantization
|
self.quantization = quantization
|
||||||
self.modelopt_quant = modelopt_quant
|
|
||||||
self.is_draft_model = is_draft_model
|
self.is_draft_model = is_draft_model
|
||||||
self.model_impl = model_impl
|
self.model_impl = model_impl
|
||||||
self.sampling_defaults = sampling_defaults
|
self.sampling_defaults = sampling_defaults
|
||||||
|
self.quantize_and_serve = quantize_and_serve
|
||||||
|
|
||||||
|
# Validate quantize_and_serve configuration
|
||||||
|
self._validate_quantize_and_serve_config()
|
||||||
|
|
||||||
# Get hf config
|
# Get hf config
|
||||||
self._maybe_pull_model_tokenizer_from_remote()
|
self._maybe_pull_model_tokenizer_from_remote()
|
||||||
@@ -219,10 +222,10 @@ class ModelConfig:
|
|||||||
enable_multimodal=server_args.enable_multimodal,
|
enable_multimodal=server_args.enable_multimodal,
|
||||||
dtype=server_args.dtype,
|
dtype=server_args.dtype,
|
||||||
quantization=server_args.quantization,
|
quantization=server_args.quantization,
|
||||||
modelopt_quant=server_args.modelopt_quant,
|
|
||||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||||
model_impl=server_args.model_impl,
|
model_impl=server_args.model_impl,
|
||||||
sampling_defaults=server_args.sampling_defaults,
|
sampling_defaults=server_args.sampling_defaults,
|
||||||
|
quantize_and_serve=server_args.quantize_and_serve,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -547,6 +550,56 @@ class ModelConfig:
|
|||||||
# Default to FP8 for backward compatibility
|
# Default to FP8 for backward compatibility
|
||||||
return {"quant_method": "modelopt_fp8"}
|
return {"quant_method": "modelopt_fp8"}
|
||||||
|
|
||||||
|
def _is_already_quantized(self) -> bool:
    """Report whether ``self.model_path`` already has a HuggingFace quantization config.

    The decision is delegated entirely to
    ``sglang.srt.utils.has_hf_quant_config``.
    """
    # Function-scope import, as in the original implementation.
    from sglang.srt.utils import has_hf_quant_config as _has_quant_cfg

    already_quantized = _has_quant_cfg(self.model_path)
    return already_quantized
|
||||||
|
|
||||||
|
def _get_modelopt_quant_type(self) -> str:
    """Extract the ModelOpt quantization type from the unified quantization flag.

    Returns:
        ``"fp8"`` for ``"modelopt_fp8"`` and ``"nvfp4"`` for
        ``"modelopt_fp4"``. For the generic ``"modelopt"`` flag the type is
        detected from the ``quant_method`` entry of the parsed HF
        quantization config, using the same two names. Any other flag, or a
        config that cannot be classified, falls back to ``"fp8"``.
    """
    explicit_types = {"modelopt_fp8": "fp8", "modelopt_fp4": "nvfp4"}
    if self.quantization in explicit_types:
        return explicit_types[self.quantization]

    if self.quantization == "modelopt":
        # Auto-detect from model config
        quant_cfg = self._parse_quant_hf_config()
        if quant_cfg:
            # Tolerate a missing or null "quant_method" entry.
            quant_method = str(quant_cfg.get("quant_method") or "").lower()
            if "fp4" in quant_method:
                # Same type name as the explicit "modelopt_fp4" flag; the
                # previous "fp4" spelling disagreed with it.
                return "nvfp4"
            if "fp8" in quant_method:
                return "fp8"

    # Default to fp8 if the type can't be detected or the flag is not ModelOpt.
    return "fp8"
|
||||||
|
|
||||||
|
def _validate_quantize_and_serve_config(self):
    """Reject any request to quantize and serve in a single step.

    Does nothing when ``self.quantize_and_serve`` is false.

    Raises:
        ValueError: ``quantize_and_serve`` is set but ``self.quantization``
            is not one of the ModelOpt flags.
        NotImplementedError: ``quantize_and_serve`` is set with a ModelOpt
            flag; the feature is currently disabled.
    """
    if not self.quantize_and_serve:
        return

    # The feature only ever applied to ModelOpt quantization flags.
    modelopt_flags = ("modelopt", "modelopt_fp8", "modelopt_fp4")
    if self.quantization not in modelopt_flags:
        raise ValueError("quantize_and_serve requires ModelOpt quantization")

    # quantize_and_serve is disabled due to compatibility issues
    disabled_message = (
        "quantize_and_serve functionality is currently disabled due to compatibility issues. "
        "Please use the separate quantize-then-deploy workflow instead. "
        "Step 1: Quantize and export model. "
        "Step 2: Deploy the exported model."
    )
    raise NotImplementedError(disabled_message)
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||||
def _verify_quantization(self) -> None:
|
def _verify_quantization(self) -> None:
|
||||||
supported_quantization = [*QUANTIZATION_METHODS]
|
supported_quantization = [*QUANTIZATION_METHODS]
|
||||||
|
|||||||
30
python/sglang/srt/configs/modelopt_config.py
Normal file
30
python/sglang/srt/configs/modelopt_config.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Configuration for NVIDIA ModelOpt quantization integration
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelOptConfig:
    """Settings for NVIDIA ModelOpt quantization, checkpointing and export.

    Every field is optional; a default-constructed instance carries no
    paths and leaves quantize-and-serve switched off.

    Args:
        quant: Quantization method/type (e.g., "fp8", "fp4")
        checkpoint_restore_path: Path to restore ModelOpt checkpoint from
        checkpoint_save_path: Path to save ModelOpt checkpoint to
        export_path: Path to export quantized model in HuggingFace format
        quantize_and_serve: Whether to quantize and serve in one step
    """

    # Quantization selection
    quant: Optional[str] = None

    # Checkpoint and export locations
    checkpoint_restore_path: Optional[str] = None
    checkpoint_save_path: Optional[str] = None
    export_path: Optional[str] = None

    # Single-step quantize-and-serve switch
    quantize_and_serve: bool = False

    def __post_init__(self):
        """Post-construction validation hook; currently performs no checks."""
        return None
|
||||||
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
|
|||||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||||
"fp8": Fp8Config,
|
"fp8": Fp8Config,
|
||||||
"blockwise_int8": BlockInt8Config,
|
"blockwise_int8": BlockInt8Config,
|
||||||
|
"modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8
|
||||||
"modelopt_fp8": ModelOptFp8Config,
|
"modelopt_fp8": ModelOptFp8Config,
|
||||||
"modelopt_fp4": ModelOptFp4Config,
|
"modelopt_fp4": ModelOptFp4Config,
|
||||||
"w8a8_int8": W8A8Int8Config,
|
"w8a8_int8": W8A8Int8Config,
|
||||||
|
|||||||
@@ -161,6 +161,26 @@ class QuantizationConfig(ABC):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _modelopt_override_quantization_method(
|
||||||
|
cls, hf_quant_config, user_quant
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Shared ModelOpt quantization method override logic."""
|
||||||
|
if hf_quant_config is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if this is a ModelOpt config
|
||||||
|
quant_algo = hf_quant_config.get("quant_algo", "").upper()
|
||||||
|
|
||||||
|
# If user specified generic "modelopt", auto-detect the specific method
|
||||||
|
if user_quant == "modelopt":
|
||||||
|
if "FP8" in quant_algo:
|
||||||
|
return "modelopt_fp8"
|
||||||
|
elif "NVFP4" in quant_algo or "FP4" in quant_algo:
|
||||||
|
return "modelopt_fp4"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||||
"""Get a value from the model's quantization config."""
|
"""Get a value from the model's quantization config."""
|
||||||
|
|||||||
@@ -111,6 +111,11 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
|
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
    """Override quantization method based on the model's config.

    Thin wrapper that forwards both arguments unchanged to the shared
    ``_modelopt_override_quantization_method`` helper and returns its
    result as-is.

    Args:
        hf_quant_config: The checkpoint's quantization config (may be None).
        user_quant: The quantization method requested by the user.

    Returns:
        Whatever the shared helper returns: a method name to use instead of
        ``user_quant``, or ``None`` for no override.
    """
    return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(cls) -> str:
|
def get_name(cls) -> str:
|
||||||
return "modelopt_fp8"
|
return "modelopt_fp8"
|
||||||
@@ -527,6 +532,11 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
self.kv_cache_quant_algo = kv_cache_quant_algo
|
self.kv_cache_quant_algo = kv_cache_quant_algo
|
||||||
self.exclude_modules = exclude_modules
|
self.exclude_modules = exclude_modules
|
||||||
|
|
||||||
|
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
    """Override quantization method based on the model's config.

    Thin wrapper that forwards both arguments unchanged to the shared
    ``_modelopt_override_quantization_method`` helper and returns its
    result as-is.

    Args:
        hf_quant_config: The checkpoint's quantization config (may be None).
        user_quant: The quantization method requested by the user.

    Returns:
        Whatever the shared helper returns: a method name to use instead of
        ``user_quant``, or ``None`` for no override.
    """
    return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(cls) -> str:
|
def get_name(cls) -> str:
|
||||||
return "modelopt_fp4"
|
return "modelopt_fp4"
|
||||||
@@ -608,7 +618,16 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
else:
|
else:
|
||||||
kv_cache_quant_algo = "auto"
|
kv_cache_quant_algo = "auto"
|
||||||
|
|
||||||
group_size = ModelOptFp4Config.common_group_size(config)
|
group_size = config.get("group_size")
|
||||||
|
# If group_size is not at top level, try to extract from config_groups
|
||||||
|
if group_size is None:
|
||||||
|
config_groups = config.get("config_groups", {})
|
||||||
|
if config_groups:
|
||||||
|
# Get group_size from the first group's weights config
|
||||||
|
first_group = next(iter(config_groups.values()), {})
|
||||||
|
weights_config = first_group.get("weights", {})
|
||||||
|
group_size = weights_config.get("group_size")
|
||||||
|
|
||||||
exclude_modules = config.get("ignore", [])
|
exclude_modules = config.get("ignore", [])
|
||||||
else:
|
else:
|
||||||
# Fall back to nested format (hf_quant_config.json - legacy format)
|
# Fall back to nested format (hf_quant_config.json - legacy format)
|
||||||
@@ -634,15 +653,15 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
)
|
)
|
||||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||||
|
|
||||||
if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
|
if group_size is None or exclude_modules is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"group_size: {group_size},"
|
f"group_size: {group_size},"
|
||||||
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
|
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
|
||||||
f"exclude_modules: {exclude_modules}"
|
f"exclude_modules: {exclude_modules}"
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"NVFP4 quantization requires group size and "
|
"NVFP4 quantization requires group_size and exclude_modules "
|
||||||
"kv_cache_quant_algo specified in the quantization config"
|
"specified in the quantization config"
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
is_checkpoint_nvfp4_serialized,
|
is_checkpoint_nvfp4_serialized,
|
||||||
|
|||||||
@@ -828,6 +828,16 @@ class ModelRunner:
|
|||||||
set_cuda_arch()
|
set_cuda_arch()
|
||||||
|
|
||||||
# Prepare the model config
|
# Prepare the model config
|
||||||
|
from sglang.srt.configs.modelopt_config import ModelOptConfig
|
||||||
|
|
||||||
|
modelopt_config = ModelOptConfig(
|
||||||
|
quant=self.server_args.modelopt_quant,
|
||||||
|
checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
|
||||||
|
checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
|
||||||
|
export_path=self.server_args.modelopt_export_path,
|
||||||
|
quantize_and_serve=self.server_args.quantize_and_serve,
|
||||||
|
)
|
||||||
|
|
||||||
self.load_config = LoadConfig(
|
self.load_config = LoadConfig(
|
||||||
load_format=self.server_args.load_format,
|
load_format=self.server_args.load_format,
|
||||||
download_dir=self.server_args.download_dir,
|
download_dir=self.server_args.download_dir,
|
||||||
@@ -836,6 +846,7 @@ class ModelRunner:
|
|||||||
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
|
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||||
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
|
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
|
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
|
modelopt_config=modelopt_config,
|
||||||
)
|
)
|
||||||
if self.device == "cpu":
|
if self.device == "cpu":
|
||||||
self.model_config = adjust_config_with_unaligned_cpu_tp(
|
self.model_config = adjust_config_with_unaligned_cpu_tp(
|
||||||
|
|||||||
@@ -538,12 +538,21 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
|
# Handle both legacy modelopt_quant and unified quantization flags
|
||||||
|
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
|
||||||
|
# Legacy approach
|
||||||
|
quant_choice_str = model_config.modelopt_quant
|
||||||
|
rank0_log(f"ModelOpt quantization requested (legacy): {quant_choice_str}")
|
||||||
|
else:
|
||||||
|
# Unified approach - extract quantization type
|
||||||
|
quant_choice_str = model_config._get_modelopt_quant_type()
|
||||||
|
rank0_log(
|
||||||
|
f"ModelOpt quantization requested (unified): {model_config.quantization} -> {quant_choice_str}"
|
||||||
|
)
|
||||||
|
|
||||||
quant_choice_str = model_config.modelopt_quant
|
|
||||||
if not isinstance(quant_choice_str, str):
|
if not isinstance(quant_choice_str, str):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
|
f"Quantization type must be a string (e.g., 'fp8'), "
|
||||||
f"got {type(quant_choice_str)}"
|
f"got {type(quant_choice_str)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1764,6 +1773,7 @@ class ModelOptModelLoader(DefaultModelLoader):
|
|||||||
quant_cfg,
|
quant_cfg,
|
||||||
quantized_ckpt_restore_path: str | None = None,
|
quantized_ckpt_restore_path: str | None = None,
|
||||||
quantized_ckpt_save_path: str | None = None,
|
quantized_ckpt_save_path: str | None = None,
|
||||||
|
export_path: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Set up ModelOpt quantization for the given model.
|
Set up ModelOpt quantization for the given model.
|
||||||
@@ -1774,6 +1784,7 @@ class ModelOptModelLoader(DefaultModelLoader):
|
|||||||
quant_cfg: The quantization configuration
|
quant_cfg: The quantization configuration
|
||||||
quantized_ckpt_restore_path: Path to restore quantized checkpoint from
|
quantized_ckpt_restore_path: Path to restore quantized checkpoint from
|
||||||
quantized_ckpt_save_path: Path to save quantized checkpoint to
|
quantized_ckpt_save_path: Path to save quantized checkpoint to
|
||||||
|
export_path: Path to export the quantized model in HuggingFace format
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ImportError: If ModelOpt is not available
|
ImportError: If ModelOpt is not available
|
||||||
@@ -1798,6 +1809,9 @@ class ModelOptModelLoader(DefaultModelLoader):
|
|||||||
rank0_log(
|
rank0_log(
|
||||||
f"Restored quantized model from {quantized_ckpt_restore_path}"
|
f"Restored quantized model from {quantized_ckpt_restore_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Export model if path provided (even when restoring from checkpoint)
|
||||||
|
self._maybe_export_modelopt(model, export_path)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1844,9 +1858,75 @@ class ModelOptModelLoader(DefaultModelLoader):
|
|||||||
f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
|
f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Export model if path provided
|
||||||
|
self._maybe_export_modelopt(model, export_path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
|
raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
|
||||||
|
|
||||||
|
def _maybe_export_modelopt(self, model, export_path: str | None) -> None:
    """Export ``model`` to HuggingFace format when an export path is given.

    The export is best-effort: any exception raised while exporting is
    reported through ``rank0_log`` and then suppressed, so the caller's
    workflow continues.

    Args:
        model: The quantized model to export.
        export_path: Destination directory; a falsy value disables export.
    """
    # Nothing requested -> nothing to do.
    if not export_path:
        return

    try:
        # The attribute is only present once it has been stored on the
        # loader; fall back to None (no tokenizer export) otherwise.
        source_model_path = getattr(self, "_original_model_path", None)
        self._export_modelopt_checkpoint(model, export_path, source_model_path)
        rank0_log(f"Quantized model exported to HuggingFace format at {export_path}")
    except Exception as e:
        rank0_log(f"Warning: Failed to export quantized model to {export_path}: {e}")
|
||||||
|
|
||||||
|
def _export_modelopt_checkpoint(
    self,
    model,
    export_path: str,
    model_path: str | None = None,
    trust_remote_code: bool = True,
) -> None:
    """
    Export the quantized model to HuggingFace format using ModelOpt export API.

    Args:
        model: The quantized model to export
        export_path: Directory path to export the model to
        model_path: Path to the original model (for tokenizer export);
            when falsy, no tokenizer is exported
        trust_remote_code: Whether to trust remote code for tokenizer loading

    Raises:
        ImportError: If ModelOpt export functionality is not available
        Exception: If export fails
    """
    # Imported lazily so the loader module stays importable without modelopt.
    try:
        from modelopt.torch.export import export_hf_checkpoint
        from transformers import AutoTokenizer
    except ImportError as e:
        raise ImportError(
            "ModelOpt export functionality is not available. "
            "Please ensure you have the latest version of modelopt installed."
        ) from e

    # Create export directory if it doesn't exist
    os.makedirs(export_path, exist_ok=True)

    # Export the quantized model; failures here propagate to the caller.
    export_hf_checkpoint(model, export_dir=export_path)

    # Nothing more to do when no source model path was supplied.
    if not model_path:
        return

    # Tokenizer export is best-effort: a failure is logged, not raised,
    # because the model weights have already been written at this point.
    # NOTE(review): trust_remote_code defaults to True, which lets the
    # tokenizer loader run code shipped with the model repo - confirm this
    # default is intended for untrusted model paths.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )
        tokenizer.save_pretrained(export_path)
        rank0_log(f"Tokenizer exported to {export_path}")
    except Exception as e:
        rank0_log(f"Warning: Failed to export tokenizer: {e}")
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -1856,28 +1936,52 @@ class ModelOptModelLoader(DefaultModelLoader):
|
|||||||
|
|
||||||
logger.info("ModelOptModelLoader: Loading base model...")
|
logger.info("ModelOptModelLoader: Loading base model...")
|
||||||
|
|
||||||
# Use shared method from parent class to load base model
|
# Store the original model path for tokenizer export
|
||||||
|
self._original_model_path = model_config.model_path
|
||||||
|
|
||||||
|
# Check if model is already quantized
|
||||||
|
if model_config._is_already_quantized():
|
||||||
|
logger.info("Model is already quantized, loading directly...")
|
||||||
|
# Use default loading for pre-quantized models
|
||||||
|
return super().load_model(
|
||||||
|
model_config=model_config, device_config=device_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Quantize-and-serve mode has been disabled at the ModelConfig level
|
||||||
|
# All quantization now uses the standard workflow (quantize + export/save)
|
||||||
|
logger.info("Standard quantization mode: Will quantize and export/save")
|
||||||
|
return self._standard_quantization_workflow(model_config, device_config)
|
||||||
|
|
||||||
|
def _standard_quantization_workflow(
|
||||||
|
self, model_config: ModelConfig, device_config: DeviceConfig
|
||||||
|
) -> nn.Module:
|
||||||
|
"""Standard quantization workflow: quantize, save checkpoint, export, then return model."""
|
||||||
|
# Use shared method from parent class to load base model for quantization
|
||||||
model = self._load_modelopt_base_model(model_config)
|
model = self._load_modelopt_base_model(model_config)
|
||||||
|
|
||||||
# Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
|
# Import ModelOpt modules
|
||||||
try:
|
try:
|
||||||
import modelopt.torch.quantization as mtq
|
import modelopt.torch.quantization as mtq
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"NVIDIA Model Optimizer (modelopt) library not found. "
|
"NVIDIA Model Optimizer (modelopt) library not found. "
|
||||||
"Please install it to use 'modelopt_quant' feature."
|
"Please install it to use ModelOpt quantization."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
quant_choice_str = model_config.modelopt_quant
|
# Handle both old modelopt_quant and new unified quantization flags
|
||||||
|
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
|
||||||
|
# Legacy modelopt_quant flag
|
||||||
|
quant_choice_str = model_config.modelopt_quant
|
||||||
|
else:
|
||||||
|
# Unified quantization flag - extract the type (fp8/fp4)
|
||||||
|
quant_choice_str = model_config._get_modelopt_quant_type()
|
||||||
|
|
||||||
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
|
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
|
||||||
if not quant_cfg_name:
|
if not quant_cfg_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
|
f"Invalid quantization choice: '{quant_choice_str}'. "
|
||||||
f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
|
f"Available choices: {list(QUANT_CFG_CHOICES.keys())}"
|
||||||
"Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
|
|
||||||
"attribute names of config objects in modelopt.torch.quantization."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1885,20 +1989,27 @@ class ModelOptModelLoader(DefaultModelLoader):
|
|||||||
quant_cfg = getattr(mtq, quant_cfg_name)
|
quant_cfg = getattr(mtq, quant_cfg_name)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
f"ModelOpt quantization config attribute '{quant_cfg_name}' "
|
f"ModelOpt quantization config '{quant_cfg_name}' not found. "
|
||||||
f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
|
"Please verify the ModelOpt library installation."
|
||||||
"Please verify QUANT_CFG_CHOICES and the ModelOpt library."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
|
f"Quantizing model with ModelOpt using config: mtq.{quant_cfg_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
|
# Get ModelOpt configuration from LoadConfig
|
||||||
quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
|
modelopt_config = self.load_config.modelopt_config
|
||||||
|
quantized_ckpt_restore_path = (
|
||||||
|
modelopt_config.checkpoint_restore_path if modelopt_config else None
|
||||||
|
)
|
||||||
|
quantized_ckpt_save_path = (
|
||||||
|
modelopt_config.checkpoint_save_path if modelopt_config else None
|
||||||
|
)
|
||||||
|
export_path = modelopt_config.export_path if modelopt_config else None
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_config.model_path, use_fast=True
|
model_config.model_path, use_fast=True
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._setup_modelopt_quantization(
|
self._setup_modelopt_quantization(
|
||||||
model,
|
model,
|
||||||
@@ -1906,6 +2017,7 @@ class ModelOptModelLoader(DefaultModelLoader):
|
|||||||
quant_cfg,
|
quant_cfg,
|
||||||
quantized_ckpt_restore_path=quantized_ckpt_restore_path,
|
quantized_ckpt_restore_path=quantized_ckpt_restore_path,
|
||||||
quantized_ckpt_save_path=quantized_ckpt_save_path,
|
quantized_ckpt_save_path=quantized_ckpt_save_path,
|
||||||
|
export_path=export_path,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"ModelOpt quantization failed: {e}")
|
logger.warning(f"ModelOpt quantization failed: {e}")
|
||||||
@@ -1919,12 +2031,27 @@ def get_model_loader(
|
|||||||
) -> BaseModelLoader:
|
) -> BaseModelLoader:
|
||||||
"""Get a model loader based on the load format."""
|
"""Get a model loader based on the load format."""
|
||||||
|
|
||||||
|
if model_config and (
|
||||||
|
(hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant)
|
||||||
|
or model_config.quantization in ["modelopt_fp8", "modelopt_fp4", "modelopt"]
|
||||||
|
):
|
||||||
|
logger.info("Using ModelOptModelLoader due to ModelOpt quantization config.")
|
||||||
|
return ModelOptModelLoader(load_config)
|
||||||
|
|
||||||
|
# Use ModelOptModelLoader for unified quantization flags
|
||||||
if (
|
if (
|
||||||
model_config
|
model_config
|
||||||
and hasattr(model_config, "modelopt_quant")
|
and hasattr(model_config, "quantization")
|
||||||
and model_config.modelopt_quant
|
and model_config.quantization in ["modelopt_fp8", "modelopt_fp4"]
|
||||||
):
|
):
|
||||||
logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
|
if model_config._is_already_quantized():
|
||||||
|
logger.info(
|
||||||
|
f"Using ModelOptModelLoader for pre-quantized model: {model_config.quantization}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Using ModelOptModelLoader for quantization: {model_config.quantization}"
|
||||||
|
)
|
||||||
return ModelOptModelLoader(load_config)
|
return ModelOptModelLoader(load_config)
|
||||||
|
|
||||||
if isinstance(load_config.load_format, type):
|
if isinstance(load_config.load_format, type):
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ QUANTIZATION_CHOICES = [
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"gguf",
|
"gguf",
|
||||||
"modelopt",
|
"modelopt",
|
||||||
|
"modelopt_fp8",
|
||||||
"modelopt_fp4",
|
"modelopt_fp4",
|
||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
"w8a8_int8",
|
"w8a8_int8",
|
||||||
@@ -192,6 +193,8 @@ class ServerArgs:
|
|||||||
modelopt_quant: Optional[Union[str, Dict]] = None
|
modelopt_quant: Optional[Union[str, Dict]] = None
|
||||||
modelopt_checkpoint_restore_path: Optional[str] = None
|
modelopt_checkpoint_restore_path: Optional[str] = None
|
||||||
modelopt_checkpoint_save_path: Optional[str] = None
|
modelopt_checkpoint_save_path: Optional[str] = None
|
||||||
|
modelopt_export_path: Optional[str] = None
|
||||||
|
quantize_and_serve: bool = False
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
is_embedding: bool = False
|
is_embedding: bool = False
|
||||||
enable_multimodal: Optional[bool] = None
|
enable_multimodal: Optional[bool] = None
|
||||||
@@ -1743,6 +1746,22 @@ class ServerArgs:
|
|||||||
help="Path to save the ModelOpt quantized checkpoint after quantization. "
|
help="Path to save the ModelOpt quantized checkpoint after quantization. "
|
||||||
"This allows reusing the quantized model in future runs.",
|
"This allows reusing the quantized model in future runs.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--modelopt-export-path",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.modelopt_export_path,
|
||||||
|
help="Path to export the quantized model in HuggingFace format after ModelOpt quantization. "
|
||||||
|
"The exported model can then be used directly with SGLang for inference. "
|
||||||
|
"If not provided, the model will not be exported.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize-and-serve",
|
||||||
|
action="store_true",
|
||||||
|
default=ServerArgs.quantize_and_serve,
|
||||||
|
help="Quantize the model with ModelOpt and immediately serve it without exporting. "
|
||||||
|
"This is useful for development and prototyping. For production, it's recommended "
|
||||||
|
"to use separate quantization and deployment steps.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-dtype",
|
"--kv-cache-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -2411,6 +2411,29 @@ def retry(
|
|||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
|
|
||||||
|
|
||||||
|
def has_hf_quant_config(model_path: str) -> bool:
    """Report whether ``hf_quant_config.json`` exists under ``model_path``.

    Args:
        model_path: Path to the model, can be local path or remote URL.

    Returns:
        True if hf_quant_config.json exists, False otherwise. For remote
        paths, any error raised during the lookup is treated as "not found".
    """
    # Local path: a plain filesystem existence check is enough.
    if not is_remote_url(model_path):
        import os

        return os.path.exists(os.path.join(model_path, "hf_quant_config.json"))

    # Remote path: query the hub; every failure (missing package, network
    # error, unknown repo, ...) is deliberately reported as False.
    try:
        from huggingface_hub import HfApi

        hub_client = HfApi()
        return hub_client.file_exists(model_path, "hf_quant_config.json")
    except Exception:
        return False
|
||||||
|
|
||||||
|
|
||||||
def flatten_nested_list(nested_list):
|
def flatten_nested_list(nested_list):
|
||||||
if isinstance(nested_list, list):
|
if isinstance(nested_list, list):
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -135,6 +135,8 @@ suites = {
|
|||||||
TestFile("test_vision_chunked_prefill.py", 175),
|
TestFile("test_vision_chunked_prefill.py", 175),
|
||||||
TestFile("test_vision_openai_server_a.py", 918),
|
TestFile("test_vision_openai_server_a.py", 918),
|
||||||
TestFile("test_vlm_input_format.py", 300),
|
TestFile("test_vlm_input_format.py", 300),
|
||||||
|
TestFile("test_modelopt_loader.py", 30),
|
||||||
|
TestFile("test_modelopt_export.py", 30),
|
||||||
],
|
],
|
||||||
"per-commit-2-gpu": [
|
"per-commit-2-gpu": [
|
||||||
TestFile("ep/test_moe_ep.py", 140),
|
TestFile("ep/test_moe_ep.py", 140),
|
||||||
|
|||||||
353
test/srt/test_modelopt_export.py
Normal file
353
test/srt/test_modelopt_export.py
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for ModelOpt export functionality in SGLang.
|
||||||
|
|
||||||
|
These tests verify the integration of ModelOpt export API with SGLang's model loading
|
||||||
|
and quantization workflow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.model_loader.loader import ModelOptModelLoader
|
||||||
|
|
||||||
|
# Note: PYTHONPATH=python should be set when running tests
|
||||||
|
|
||||||
|
# Check if modelopt is available
|
||||||
|
try:
|
||||||
|
import modelopt
|
||||||
|
|
||||||
|
MODELOPT_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
MODELOPT_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelOptExport(unittest.TestCase):
|
||||||
|
"""Test suite for ModelOpt export functionality."""
|
||||||
|
|
||||||
|
def setUp(self):
    """Set up test fixtures.

    Starts the patches that stand in for distributed state and logging,
    creates a temporary working directory, and builds the mock model,
    tokenizer, quantization config and the loader under test.

    Every patcher and the temporary directory are registered with
    ``addCleanup`` as soon as they are created, so they are released even
    when a later step of ``setUp`` raises (in which case unittest does not
    call ``tearDown``).
    """
    import shutil

    def _start_patch(target, **kwargs):
        """Start a patcher for ``target`` and schedule its ``stop()``."""
        patcher = patch(target, **kwargs)
        patcher.start()
        # Stopping an already-stopped patcher is a no-op on Python 3.8+,
        # so this is safe alongside an explicit stop() in tearDown.
        self.addCleanup(patcher.stop)
        return patcher

    # Mock distributed functionality to avoid initialization errors
    self.mock_tp_rank = _start_patch(
        "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
        return_value=0,
    )

    self.mock_rank0_log = _start_patch("sglang.srt.model_loader.loader.rank0_log")

    # Mock logger to avoid issues
    self.mock_logger = _start_patch("sglang.srt.model_loader.loader.logger")

    # Mock all distributed functions that might be called
    self.mock_get_tp_group = _start_patch(
        "sglang.srt.distributed.parallel_state.get_tp_group"
    )

    # Mock model parallel initialization check
    self.mock_mp_is_initialized = _start_patch(
        "sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
        return_value=True,
    )

    self.temp_dir = tempfile.mkdtemp()
    # ignore_errors makes a second removal (e.g. from tearDown) harmless.
    self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
    self.export_dir = os.path.join(self.temp_dir, "exported_model")
    self.checkpoint_dir = os.path.join(self.temp_dir, "checkpoint")

    # Mock model
    self.mock_model = Mock(spec=torch.nn.Module)
    self.mock_model.device = torch.device("cuda:0")

    # Mock tokenizer
    self.mock_tokenizer = Mock()

    # Mock quantization config
    self.mock_quant_cfg = Mock()

    # Create ModelOptModelLoader instance
    self.load_config = LoadConfig()
    self.model_loader = ModelOptModelLoader(self.load_config)
|
||||||
|
|
||||||
|
def tearDown(self):
    """Release everything ``setUp`` created.

    Removes the temporary directory (ignoring removal errors) and then
    stops each patcher in the order in which it was started.
    """
    import shutil

    shutil.rmtree(self.temp_dir, ignore_errors=True)

    # Stop mocks
    started_patchers = (
        self.mock_tp_rank,
        self.mock_rank0_log,
        self.mock_logger,
        self.mock_get_tp_group,
        self.mock_mp_is_initialized,
    )
    for patcher in started_patchers:
        patcher.stop()
|
||||||
|
|
||||||
|
def _create_mock_export_files(self, export_dir: str):
|
||||||
|
"""Create mock export files for testing validation."""
|
||||||
|
os.makedirs(export_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Create config.json
|
||||||
|
config = {
|
||||||
|
"model_type": "test_model",
|
||||||
|
"architectures": ["TestModel"],
|
||||||
|
"quantization_config": {
|
||||||
|
"quant_method": "modelopt",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
with open(os.path.join(export_dir, "config.json"), "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
# Create tokenizer_config.json
|
||||||
|
tokenizer_config = {"tokenizer_class": "TestTokenizer"}
|
||||||
|
with open(os.path.join(export_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(tokenizer_config, f)
|
||||||
|
|
||||||
|
# Create model file
|
||||||
|
with open(os.path.join(export_dir, "model.safetensors"), "w") as f:
|
||||||
|
f.write("mock_model_data")
|
||||||
|
|
||||||
|
# Decorators apply bottom-up, so the innermost patch (export_hf_checkpoint)
# is injected as the first mock argument and os.makedirs as the second.
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
@patch("sglang.srt.model_loader.loader.os.makedirs")
@patch("modelopt.torch.export.export_hf_checkpoint")
def test_export_modelopt_checkpoint_success(self, mock_export, mock_makedirs):
    """Test successful model export.

    Calls ``_export_modelopt_checkpoint`` without a model path and checks
    that the export directory is created and that the ModelOpt export
    function is invoked exactly once with the model and that directory.
    """
    # Arrange
    mock_export.return_value = None
    mock_makedirs.return_value = None

    # Act
    self.model_loader._export_modelopt_checkpoint(self.mock_model, self.export_dir)

    # Assert
    mock_makedirs.assert_called_once_with(self.export_dir, exist_ok=True)
    mock_export.assert_called_once_with(self.mock_model, export_dir=self.export_dir)
|
||||||
|
|
||||||
|
# Decorators apply bottom-up: is_quantized is the first mock argument,
# restore the second.
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
@patch("modelopt.torch.opt.restore")
@patch("modelopt.torch.quantization.utils.is_quantized")
def test_setup_quantization_with_export_from_checkpoint(
    self, mock_is_quantized, mock_restore
):
    """Test export functionality when restoring from checkpoint.

    With a restore path and an export path supplied, the loader is expected
    to restore the checkpoint into the model and then trigger exactly one
    export to the export directory.
    """
    # Arrange
    mock_is_quantized.return_value = False
    mock_restore.return_value = None

    # Replace the real export so no ModelOpt export work is performed.
    with patch.object(
        self.model_loader, "_export_modelopt_checkpoint"
    ) as mock_export:
        # Act
        self.model_loader._setup_modelopt_quantization(
            self.mock_model,
            self.mock_tokenizer,
            self.mock_quant_cfg,
            quantized_ckpt_restore_path=self.checkpoint_dir,
            export_path=self.export_dir,
        )

        # Assert
        mock_restore.assert_called_once_with(self.mock_model, self.checkpoint_dir)
        # Third positional argument is the original model path, which is
        # expected to be None here.
        mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)
|
||||||
|
|
||||||
|
# Decorators apply bottom-up, so the mock arguments arrive in the reverse
# of the order in which the @patch lines are written.
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
@patch("modelopt.torch.quantization.quantize")
@patch("modelopt.torch.quantization.print_quant_summary")
@patch("modelopt.torch.quantization.utils.is_quantized")
@patch("modelopt.torch.utils.dataset_utils.get_dataset_dataloader")
@patch("modelopt.torch.utils.dataset_utils.create_forward_loop")
def test_setup_quantization_with_export_after_calibration(
    self,
    mock_create_loop,
    mock_get_dataloader,
    mock_is_quantized,
    mock_print_summary,
    mock_quantize,
):
    """Test export functionality after calibration-based quantization.

    With no restore path, the loader is expected to quantize the model
    using the calibration forward loop and then trigger exactly one export
    to the export directory.
    """
    # Arrange
    mock_is_quantized.return_value = False
    mock_dataloader = Mock()
    mock_get_dataloader.return_value = mock_dataloader
    mock_calibrate_loop = Mock()
    mock_create_loop.return_value = mock_calibrate_loop
    mock_quantize.return_value = None
    mock_print_summary.return_value = None

    # Replace the real export so no ModelOpt export work is performed.
    with patch.object(
        self.model_loader, "_export_modelopt_checkpoint"
    ) as mock_export:
        # Act
        self.model_loader._setup_modelopt_quantization(
            self.mock_model,
            self.mock_tokenizer,
            self.mock_quant_cfg,
            export_path=self.export_dir,
        )

        # Assert
        mock_quantize.assert_called_once_with(
            self.mock_model, self.mock_quant_cfg, forward_loop=mock_calibrate_loop
        )
        # Third positional argument is the original model path, which is
        # expected to be None here.
        mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)
|
||||||
|
|
||||||
|
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
def test_setup_quantization_without_export(self):
    """Setup must not call the export hook when ``export_path`` is ``None``."""
    loader = self.model_loader

    # ``is_quantized`` is patched to report True for the duration of the call;
    # the export hook is replaced by a spy so any call would be recorded.
    with patch(
        "modelopt.torch.quantization.utils.is_quantized", return_value=True
    ), patch.object(loader, "_export_modelopt_checkpoint") as export_spy:
        loader._setup_modelopt_quantization(
            self.mock_model,
            self.mock_tokenizer,
            self.mock_quant_cfg,
            export_path=None,  # No export path
        )

        # Assert
        export_spy.assert_not_called()
|
||||||
|
|
||||||
|
def test_quantize_and_serve_config_validation(self):
    """``ModelConfig`` must reject ``quantize_and_serve=True``.

    Two cases are checked: with a ModelOpt quantization selected the
    constructor is expected to raise ``NotImplementedError`` with guidance in
    the message; with no quantization it is expected to raise ``ValueError``.
    """
    model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    # Case 1: quantize-and-serve combined with a ModelOpt quantization.
    with self.assertRaises(NotImplementedError) as disabled_ctx:
        ModelConfig(
            model_path=model_path,
            quantization="modelopt_fp8",
            quantize_and_serve=True,
        )

    # The message should tell the user what to do instead.
    message = str(disabled_ctx.exception)
    for fragment in (
        "disabled due to compatibility issues",
        "separate quantize-then-deploy workflow",
    ):
        self.assertIn(fragment, message)

    # Case 2: quantize-and-serve without any quantization configured.
    with self.assertRaises(ValueError) as missing_quant_ctx:
        ModelConfig(
            model_path=model_path,
            quantize_and_serve=True,
        )
    self.assertIn(
        "requires ModelOpt quantization", str(missing_quant_ctx.exception)
    )
|
||||||
|
|
||||||
|
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
def test_standard_workflow_selection(self):
    """``load_model`` must delegate to the standard quantization workflow.

    With ``quantize_and_serve=False`` the loader is expected to call
    ``_standard_quantization_workflow(model_config, device_config)`` once.
    """
    loader = self.model_loader

    with patch(
        "modelopt.torch.quantization.utils.is_quantized", return_value=False
    ), patch.object(
        loader, "_standard_quantization_workflow", return_value=Mock()
    ) as workflow_spy, patch.object(
        loader, "_load_modelopt_base_model"
    ):
        # Create model config without quantize_and_serve
        model_config = ModelConfig(
            model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            quantization="modelopt_fp8",
            quantize_and_serve=False,
        )
        device_config = DeviceConfig()

        # Act
        loader.load_model(
            model_config=model_config,
            device_config=device_config,
        )

        # Assert
        workflow_spy.assert_called_once_with(model_config, device_config)
|
||||||
|
|
||||||
|
def _get_export_info(self, export_dir: str) -> "dict | None":
    """Summarize an exported model directory from its ``config.json``.

    Args:
        export_dir: Directory that is expected to hold the exported model.

    Returns:
        A dict with the keys ``model_type``, ``architectures``,
        ``quantization_config`` and ``export_dir``. Fields missing from the
        config fall back to ``"unknown"``, ``[]`` and ``{}`` respectively.
        ``None`` is returned when ``self._validate_export(export_dir)`` is
        falsy, when ``config.json`` cannot be read or parsed, or when its
        top-level value is not a JSON object.
    """
    if not self._validate_export(export_dir):
        return None

    try:
        config_path = os.path.join(export_dir, "config.json")
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, TypeError, ValueError):
        # OSError: missing/unreadable file. ValueError: invalid JSON or bad
        # encoding (JSONDecodeError and UnicodeDecodeError are subclasses).
        # TypeError: a non-path ``export_dir``.
        return None

    # Valid JSON that is not an object (e.g. a list) carries no model info.
    if not isinstance(config, dict):
        return None

    return {
        "model_type": config.get("model_type", "unknown"),
        "architectures": config.get("architectures", []),
        "quantization_config": config.get("quantization_config", {}),
        "export_dir": export_dir,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
class TestModelOptExportIntegration(unittest.TestCase):
    """Integration-style checks that the export path reaches quantization setup."""

    def setUp(self):
        """Create a scratch directory and derive the export location inside it."""
        scratch = tempfile.mkdtemp()
        self.temp_dir = scratch
        self.export_dir = os.path.join(scratch, "exported_model")

    def tearDown(self):
        """Remove the scratch directory, ignoring any removal errors."""
        from shutil import rmtree

        rmtree(self.temp_dir, ignore_errors=True)

    @patch("sglang.srt.model_loader.loader.get_model_architecture")
    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    def test_full_workflow_with_export(self, mock_model, mock_tokenizer, mock_arch):
        """``load_model`` must forward ``modelopt_export_path`` as ``export_path``.

        Model loading, tokenizer loading and quantization setup are all
        stubbed; only the keyword passed to ``_setup_modelopt_quantization``
        is verified.
        """
        # Arrange: stub the HF/architecture lookups.
        fake_model = Mock(spec=torch.nn.Module)
        mock_model.return_value = fake_model
        mock_tokenizer.return_value = Mock()
        mock_arch.return_value = ("TestModel", "TestConfig")

        model_config = ModelConfig(
            model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            modelopt_quant="fp8",
            modelopt_export_path=self.export_dir,
        )
        load_config = LoadConfig()
        device_config = DeviceConfig()

        # Replace the quantization/export step and the base-model load on the
        # class so the freshly constructed loader picks them up.
        with patch.object(
            ModelOptModelLoader, "_setup_modelopt_quantization"
        ) as setup_spy, patch.object(
            ModelOptModelLoader, "_load_modelopt_base_model"
        ) as load_base_stub:
            load_base_stub.return_value = fake_model

            # Act
            loader = ModelOptModelLoader(load_config)
            loaded = loader.load_model(
                model_config=model_config,
                device_config=device_config,
            )

            # Assert
            self.assertIsNotNone(loaded)
            setup_spy.assert_called_once()
            # Verify export_path was passed to setup
            _, setup_kwargs = setup_spy.call_args
            self.assertEqual(setup_kwargs.get("export_path"), self.export_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -12,8 +12,17 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# Add the sglang path for testing
|
# Note: PYTHONPATH=python should be set when running tests
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
|
|
||||||
|
# Shared constants for the calibration tests, defined once so the values are
# not hard-coded (or redefined) in several places.
CALIBRATION_BATCH_SIZE = 36  # batch size constant for calibration tests
CALIBRATION_NUM_SAMPLES = 512  # sample count constant for calibration tests
DEFAULT_DEVICE = "cuda:0"  # device string assigned to the mocked base model
|
||||||
|
|
||||||
from sglang.srt.configs.device_config import DeviceConfig
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
@@ -28,18 +37,63 @@ class TestModelOptModelLoader(CustomTestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Set up test fixtures."""
|
"""Set up test fixtures."""
|
||||||
|
# Mock distributed functionality to avoid initialization errors
|
||||||
|
self.mock_tp_rank = patch(
|
||||||
|
"sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
|
||||||
|
return_value=0,
|
||||||
|
)
|
||||||
|
self.mock_tp_rank.start()
|
||||||
|
|
||||||
|
self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
|
||||||
|
self.mock_rank0_log.start()
|
||||||
|
|
||||||
|
# Mock logger to avoid issues
|
||||||
|
self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
|
||||||
|
self.mock_logger.start()
|
||||||
|
|
||||||
|
# Mock all distributed functions that might be called
|
||||||
|
self.mock_get_tp_group = patch(
|
||||||
|
"sglang.srt.distributed.parallel_state.get_tp_group"
|
||||||
|
)
|
||||||
|
self.mock_get_tp_group.start()
|
||||||
|
|
||||||
|
# Mock model parallel initialization check
|
||||||
|
self.mock_mp_is_initialized = patch(
|
||||||
|
"sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
|
||||||
|
return_value=True,
|
||||||
|
)
|
||||||
|
self.mock_mp_is_initialized.start()
|
||||||
|
|
||||||
self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||||
self.load_config = LoadConfig()
|
self.load_config = LoadConfig()
|
||||||
self.device_config = DeviceConfig(device="cuda")
|
self.device_config = DeviceConfig(device="cuda")
|
||||||
|
|
||||||
# Create a basic model config with modelopt_quant
|
# Create a basic model config with unified quantization flag
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig(
|
||||||
model_path=self.model_path, modelopt_quant="fp8"
|
model_path=self.model_path,
|
||||||
|
quantization="modelopt_fp8", # Use unified quantization approach
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also create a unified quantization config for new tests
|
||||||
|
self.unified_model_config = ModelConfig(
|
||||||
|
model_path=self.model_path, quantization="modelopt_fp8"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock base model
|
# Mock base model
|
||||||
self.mock_base_model = MagicMock(spec=nn.Module)
|
self.mock_base_model = MagicMock(spec=nn.Module)
|
||||||
self.mock_base_model.eval.return_value = self.mock_base_model
|
self.mock_base_model.eval.return_value = self.mock_base_model
|
||||||
|
self.mock_base_model.device = (
|
||||||
|
DEFAULT_DEVICE # Add device attribute for calibration tests
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
    """Stop the five patchers held on this test case, in the order listed."""
    active_patchers = (
        self.mock_tp_rank,
        self.mock_rank0_log,
        self.mock_logger,
        self.mock_get_tp_group,
        self.mock_mp_is_initialized,
    )
    for patcher in active_patchers:
        patcher.stop()
|
||||||
|
|
||||||
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
|
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
|
||||||
@patch("sglang.srt.model_loader.loader.logger")
|
@patch("sglang.srt.model_loader.loader.logger")
|
||||||
@@ -66,7 +120,7 @@ class TestModelOptModelLoader(CustomTestCase):
|
|||||||
model = self.mock_base_model
|
model = self.mock_base_model
|
||||||
|
|
||||||
# Simulate the quantization config lookup
|
# Simulate the quantization config lookup
|
||||||
quant_choice_str = model_config.modelopt_quant
|
quant_choice_str = model_config._get_modelopt_quant_type()
|
||||||
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
|
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
|
||||||
|
|
||||||
if not quant_cfg_name:
|
if not quant_cfg_name:
|
||||||
@@ -123,6 +177,305 @@ class TestModelOptModelLoader(CustomTestCase):
|
|||||||
# Verify we get back the expected model
|
# Verify we get back the expected model
|
||||||
self.assertEqual(result_model, self.mock_base_model)
|
self.assertEqual(result_model, self.mock_base_model)
|
||||||
|
|
||||||
|
@patch("sglang.srt.model_loader.loader.logger")
def test_missing_modelopt_import(self, mock_logger):
    """``load_model`` must raise and log when ``modelopt`` cannot be imported."""
    loader = ModelOptModelLoader(self.load_config)

    # Capture the real import hook before it is patched, so every module
    # other than ``modelopt*`` still imports normally.
    real_import = __import__

    def failing_modelopt_import(name, *args, **kwargs):
        if not name.startswith("modelopt"):
            return real_import(name, *args, **kwargs)
        raise ImportError("No module named 'modelopt'")

    # The base-model load is stubbed so the test exercises only the import
    # failure path.
    with patch.object(
        loader, "_load_modelopt_base_model", return_value=self.mock_base_model
    ), patch("builtins.__import__", side_effect=failing_modelopt_import):
        # Expect ImportError to be raised and logged
        with self.assertRaises(ImportError):
            loader.load_model(
                model_config=self.model_config, device_config=self.device_config
            )

    # Verify error logging
    mock_logger.error.assert_called_with(
        "NVIDIA Model Optimizer (modelopt) library not found. "
        "Please install it to use ModelOpt quantization."
    )
|
||||||
|
|
||||||
|
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_calibration_workflow_integration(self, mock_logger, mock_auto_tokenizer):
    """Run ``load_model`` end to end against stubbed ``modelopt`` modules.

    The ``modelopt`` package tree is injected into ``sys.modules`` as mocks,
    so the test only confirms that the workflow completes and returns the
    base model; individual calibration calls are not asserted.
    """
    loader = ModelOptModelLoader(self.load_config)

    # Tokenizer returned by the patched AutoTokenizer.
    mock_auto_tokenizer.from_pretrained.return_value = MagicMock(
        padding_side="right"
    )

    # Stand-ins for the modelopt submodules.
    fake_mto = MagicMock()
    fake_mtq = MagicMock()
    fake_mtq.FP8_DEFAULT_CFG = MagicMock()

    fake_dataset_utils = MagicMock()
    fake_dataset_utils.get_dataset_dataloader.return_value = MagicMock()
    fake_dataset_utils.create_forward_loop.return_value = MagicMock()

    # The model is reported as not quantized initially.
    fake_modules = {
        "modelopt": MagicMock(),
        "modelopt.torch": MagicMock(),
        "modelopt.torch.opt": fake_mto,
        "modelopt.torch.quantization": fake_mtq,
        "modelopt.torch.quantization.utils": MagicMock(
            is_quantized=MagicMock(return_value=False)
        ),
        "modelopt.torch.utils": MagicMock(),
        "modelopt.torch.utils.dataset_utils": fake_dataset_utils,
    }

    with patch.object(
        loader, "_load_modelopt_base_model", return_value=self.mock_base_model
    ), patch.dict("sys.modules", fake_modules):
        # Execute the load_model method to test the full workflow
        loaded = loader.load_model(
            model_config=self.model_config, device_config=self.device_config
        )

        # Verify the model loading was successful. Exact calibration calls
        # are not checked because the imports happen dynamically inside the
        # loader; successful completion is the signal here.
        self.assertEqual(loaded, self.mock_base_model)
|
||||||
|
|
||||||
|
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_quantized_checkpoint_restore(self, mock_logger, mock_auto_tokenizer):
    """The restore path from ``LoadConfig`` must reach quantization setup.

    ``_setup_modelopt_quantization`` is replaced by a fake that mimics a
    checkpoint restore, so the test checks (a) the keyword the loader passes
    and (b) that the fake's restore call sees the base model and path.
    """
    restore_path = "/path/to/quantized/checkpoint"

    # Model config plus a load config carrying the restore path.
    config_with_restore = ModelConfig(
        model_path=self.model_path,
        quantization="modelopt_fp8",
    )
    loader = ModelOptModelLoader(
        LoadConfig(modelopt_checkpoint_restore_path=restore_path)
    )

    mock_auto_tokenizer.from_pretrained.return_value = MagicMock()

    # Stand-ins for the modelopt submodules.
    mock_mto = MagicMock()
    mock_mtq = MagicMock()
    mock_mtq.FP8_DEFAULT_CFG = MagicMock()

    fake_modules = {
        "modelopt": MagicMock(),
        "modelopt.torch": MagicMock(),
        "modelopt.torch.opt": mock_mto,
        "modelopt.torch.quantization": mock_mtq,
        # The model is reported as not quantized initially.
        "modelopt.torch.quantization.utils": MagicMock(
            is_quantized=MagicMock(return_value=False)
        ),
    }

    def mock_setup_quantization(
        model,
        tokenizer,
        quant_cfg,
        quantized_ckpt_restore_path=None,
        **kwargs,
    ):
        # Simulates the checkpoint-restore branch of the real setup method.
        if not quantized_ckpt_restore_path:
            return
        mock_mto.restore(model, quantized_ckpt_restore_path)
        print(f"Restored quantized model from {quantized_ckpt_restore_path}")

    with patch.object(
        loader, "_load_modelopt_base_model", return_value=self.mock_base_model
    ), patch.dict("sys.modules", fake_modules), patch.object(
        loader, "_setup_modelopt_quantization", side_effect=mock_setup_quantization
    ) as mock_setup:
        # Execute the load_model method
        result_model = loader.load_model(
            model_config=config_with_restore,
            device_config=self.device_config,
        )

        # Verify the setup was called with restore path
        mock_setup.assert_called_once()
        _, setup_kwargs = mock_setup.call_args
        self.assertIn("quantized_ckpt_restore_path", setup_kwargs)
        self.assertEqual(
            setup_kwargs["quantized_ckpt_restore_path"],
            restore_path,
        )

        # Verify restore was called
        mock_mto.restore.assert_called_once_with(
            self.mock_base_model, restore_path
        )

        # Verify we get the expected model back
        self.assertEqual(result_model, self.mock_base_model)
|
||||||
|
|
||||||
|
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_quantized_checkpoint_save(self, mock_logger, mock_auto_tokenizer):
    """The save path from ``LoadConfig`` must reach quantization setup.

    ``_setup_modelopt_quantization`` is replaced by a fake that mimics
    calibration followed by a checkpoint save, so the test checks the
    keyword the loader passes and that the fake's save call sees the base
    model and path.
    """
    save_path = "/path/to/save/checkpoint"

    # Model config plus a load config carrying the save path.
    config_with_save = ModelConfig(
        model_path=self.model_path,
        quantization="modelopt_fp8",
    )
    loader = ModelOptModelLoader(LoadConfig(modelopt_checkpoint_save_path=save_path))

    mock_auto_tokenizer.from_pretrained.return_value = MagicMock()

    # Stand-ins for the modelopt submodules.
    mock_mto = MagicMock()
    mock_mtq = MagicMock()
    mock_mtq.FP8_DEFAULT_CFG = MagicMock()
    mock_dataset_utils = MagicMock()

    fake_modules = {
        "modelopt": MagicMock(),
        "modelopt.torch": MagicMock(),
        "modelopt.torch.opt": mock_mto,
        "modelopt.torch.quantization": mock_mtq,
        # The model is reported as not quantized initially.
        "modelopt.torch.quantization.utils": MagicMock(
            is_quantized=MagicMock(return_value=False)
        ),
        "modelopt.torch.utils": MagicMock(),
        "modelopt.torch.utils.dataset_utils": mock_dataset_utils,
    }

    def mock_setup_quantization(
        model,
        tokenizer,
        quant_cfg,
        quantized_ckpt_save_path=None,
        **kwargs,
    ):
        # Simulate calibration and quantization
        mock_mtq.quantize(model, quant_cfg, forward_loop=MagicMock())
        mock_mtq.print_quant_summary(model)

        # Save checkpoint if path provided
        if quantized_ckpt_save_path:
            mock_mto.save(model, quantized_ckpt_save_path)
            print(f"Quantized model saved to {quantized_ckpt_save_path}")

    with patch.object(
        loader, "_load_modelopt_base_model", return_value=self.mock_base_model
    ), patch.dict("sys.modules", fake_modules), patch.object(
        loader, "_setup_modelopt_quantization", side_effect=mock_setup_quantization
    ) as mock_setup:
        # Execute the load_model method
        result_model = loader.load_model(
            model_config=config_with_save, device_config=self.device_config
        )

        # Verify the setup was called with save path
        mock_setup.assert_called_once()
        _, setup_kwargs = mock_setup.call_args
        self.assertIn("quantized_ckpt_save_path", setup_kwargs)
        self.assertEqual(
            setup_kwargs["quantized_ckpt_save_path"],
            save_path,
        )

        # Verify save was called
        mock_mto.save.assert_called_once_with(self.mock_base_model, save_path)

        # Verify we get the expected model back
        self.assertEqual(result_model, self.mock_base_model)
|
||||||
|
|
||||||
|
def test_unified_quantization_flag_support(self):
    """Each unified ``quantization`` flag must map to its ModelOpt quant type."""
    # (quantization flag, expected result of _get_modelopt_quant_type()).
    # The bare "modelopt" flag exercises auto-detection, which is expected to
    # fall back to fp8 when no config is detected.
    expectations = (
        ("modelopt_fp8", "fp8"),
        ("modelopt_fp4", "nvfp4"),
        ("modelopt", "fp8"),
    )
    for flag, expected_type in expectations:
        config = ModelConfig(model_path=self.model_path, quantization=flag)
        self.assertEqual(config._get_modelopt_quant_type(), expected_type)
|
||||||
|
|
||||||
|
|
||||||
class TestModelOptLoaderIntegration(CustomTestCase):
|
class TestModelOptLoaderIntegration(CustomTestCase):
|
||||||
"""Integration tests for ModelOptModelLoader with Engine API."""
|
"""Integration tests for ModelOptModelLoader with Engine API."""
|
||||||
|
|||||||
Reference in New Issue
Block a user