Enable native ModelOpt quantization support (3/3) (#10154)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@@ -110,6 +110,157 @@ python3 -m sglang.launch_server \
    --port 30000 --host 0.0.0.0
```

#### Using [NVIDIA ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer)

NVIDIA Model Optimizer (ModelOpt) provides advanced quantization techniques optimized for NVIDIA hardware. SGLang includes a streamlined workflow for quantizing models with ModelOpt and automatically exporting them for deployment.

##### Installation

First, install ModelOpt. You can either install it directly or as an optional SGLang dependency:

```bash
# Option 1: Install ModelOpt directly
pip install nvidia-modelopt

# Option 2: Install SGLang with ModelOpt support (recommended)
pip install "sglang[modelopt]"
```

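To quickly confirm the installation, import the package and print its version (a minimal check; the `__version__` attribute is assumed to follow the usual Python packaging convention):

```python
# Minimal sanity check that ModelOpt is importable
import modelopt

print(modelopt.__version__)
```
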
##### Quantization and Export Workflow

SGLang provides an example script that demonstrates the complete ModelOpt quantization and export workflow:

```bash
# Quantize and export a model using ModelOpt FP8 quantization
python examples/usage/modelopt_quantize_and_export.py quantize \
    --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --export-dir ./quantized_tinyllama_fp8 \
    --quantization-method modelopt_fp8

# For FP4 quantization
python examples/usage/modelopt_quantize_and_export.py quantize \
    --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --export-dir ./quantized_tinyllama_fp4 \
    --quantization-method modelopt_fp4
```

##### Available Quantization Methods

- `modelopt_fp8`: FP8 quantization with optimal performance on NVIDIA Hopper and Blackwell GPUs
- `modelopt_fp4`: FP4 quantization with optimal performance on NVIDIA Blackwell GPUs

##### Python API Usage

You can also use ModelOpt quantization programmatically:

```python
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader

# Configure model with ModelOpt quantization and export
model_config = ModelConfig(
    model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    quantization="modelopt_fp8",  # or "modelopt_fp4"
    trust_remote_code=True,
)

load_config = LoadConfig(
    modelopt_export_path="./exported_model",
    modelopt_checkpoint_save_path="./checkpoint.pth",  # optional: save the fake-quantized checkpoint
)
device_config = DeviceConfig(device="cuda")

# Load and quantize the model (export happens automatically)
model_loader = get_model_loader(load_config, model_config)
quantized_model = model_loader.load_model(
    model_config=model_config,
    device_config=device_config,
)
```

##### Deploying Quantized Models

After quantization and export, you can deploy the model with SGLang:

```bash
# Deploy the exported quantized model
python -m sglang.launch_server \
    --model-path ./quantized_tinyllama_fp8 \
    --quantization modelopt \
    --port 30000 --host 0.0.0.0
```
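
Once the server is up, you can send a request to SGLang's native `/generate` endpoint, for example:

```bash
# Example request against the running server
curl -s http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "What is the capital of France?", "sampling_params": {"temperature": 0.8, "max_new_tokens": 64}}'
```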

Or using the Python API:

```python
import sglang as sgl

# Deploy exported ModelOpt quantized model
llm = sgl.Engine(
    model_path="./quantized_tinyllama_fp8",
    quantization="modelopt",
)

# Run inference
prompts = ["Hello, how are you?", "What is the capital of France?"]
sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 100}
outputs = llm.generate(prompts, sampling_params)

for i, output in enumerate(outputs):
    print(f"Prompt: {prompts[i]}")
    print(f"Output: {output['text']}")
```
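
When you are done, call `llm.shutdown()` to release the engine's GPU resources.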

##### Advanced Features

**Checkpoint Management**: Save and restore fake-quantized checkpoints for reuse:

```bash
# Save the fake-quantized checkpoint during quantization
python examples/usage/modelopt_quantize_and_export.py quantize \
    --model-path meta-llama/Llama-3.2-1B-Instruct \
    --export-dir ./quantized_model \
    --quantization-method modelopt_fp8 \
    --checkpoint-save-path ./my_checkpoint.pth

# The checkpoint can be reused in future quantization runs to skip calibration
```

**Export-only Workflow**: If you have a pre-existing fake-quantized ModelOpt checkpoint, you can export it directly:

```python
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader

model_config = ModelConfig(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    quantization="modelopt_fp8",
    trust_remote_code=True,
)

load_config = LoadConfig(
    modelopt_checkpoint_restore_path="./my_checkpoint.pth",
    modelopt_export_path="./exported_model",
)

# Load and export the model
model_loader = get_model_loader(load_config, model_config)
model_loader.load_model(model_config=model_config, device_config=DeviceConfig())
```
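
The same paths can also be supplied when launching the server. A launch might look like the following sketch; the export flag is shown in the server arguments added by this change, while the exact spelling of the restore flag is assumed from the `modelopt_checkpoint_restore_path` field:

```bash
# Hypothetical launch reusing a saved ModelOpt checkpoint
python -m sglang.launch_server \
    --model-path meta-llama/Llama-3.2-1B-Instruct \
    --quantization modelopt_fp8 \
    --modelopt-checkpoint-restore-path ./my_checkpoint.pth \
    --modelopt-export-path ./exported_model
```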

##### Benefits of ModelOpt

- **Hardware Optimization**: Specifically optimized for NVIDIA GPU architectures
- **Advanced Quantization**: Supports cutting-edge FP8 and FP4 quantization techniques
- **Seamless Integration**: Automatic export to HuggingFace format for easy deployment
- **Calibration-based**: Uses calibration datasets for optimal quantization quality
- **Production Ready**: Enterprise-grade quantization with NVIDIA support

## Online Quantization

To enable online quantization, simply specify `--quantization` on the command line. For example, you can launch the server with the following command to enable `FP8` quantization for the model `meta-llama/Meta-Llama-3.1-8B-Instruct`:
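
The command itself is unchanged context elided from this diff; it presumably resembles:

```bash
# Presumed shape of the elided launch command
python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --quantization fp8 \
    --port 30000 --host 0.0.0.0
```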
@@ -148,5 +299,6 @@ python3 -m sglang.launch_server \

- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
- [NVIDIA Model Optimizer (ModelOpt)](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)

examples/usage/modelopt_quantize_and_export.py (new executable file, 303 lines)
@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
Example: ModelOpt Quantization and Export with SGLang

This example demonstrates the streamlined workflow for quantizing a model with
ModelOpt and automatically exporting it for deployment with SGLang.
"""

import argparse
import os
from typing import Optional

import torch

import sglang as sgl
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.model_loader.loader import get_model_loader


def _validate_export(export_dir: str) -> bool:
    """Validate that an exported model directory contains the expected files."""
    import glob

    required_files = ["config.json", "tokenizer_config.json"]

    if not os.path.exists(export_dir):
        return False

    # Check required files
    for file in required_files:
        if not os.path.exists(os.path.join(export_dir, file)):
            return False

    # Check for model files using pattern matching to handle sharded models
    model_patterns = [
        "model*.safetensors",
        "pytorch_model*.bin",
    ]

    has_model_file = False
    for pattern in model_patterns:
        matching_files = glob.glob(os.path.join(export_dir, pattern))
        if matching_files:
            has_model_file = True
            break

    return has_model_file


def _get_export_info(export_dir: str) -> Optional[dict]:
    """Get information about an exported model."""
    import json

    if not _validate_export(export_dir):
        return None

    try:
        config_path = os.path.join(export_dir, "config.json")
        with open(config_path, "r") as f:
            config = json.load(f)

        return {
            "model_type": config.get("model_type", "unknown"),
            "architectures": config.get("architectures", []),
            "quantization_config": config.get("quantization_config", {}),
            "export_dir": export_dir,
        }
    except Exception:
        return None


def quantize_and_export_model(
    model_path: str,
    export_dir: str,
    quantization_method: str = "modelopt_fp8",
    checkpoint_save_path: Optional[str] = None,
    device: str = "cuda",
) -> None:
    """
    Quantize a model with ModelOpt and export it for SGLang deployment.

    Args:
        model_path: Path to the original model
        export_dir: Directory to export the quantized model
        quantization_method: Quantization method ("modelopt_fp8" or "modelopt_fp4")
        checkpoint_save_path: Optional path to save ModelOpt checkpoint
        device: Device to use for quantization
    """
    print("🚀 Starting ModelOpt quantization and export workflow")
    print(f"📥 Input model: {model_path}")
    print(f"📤 Export directory: {export_dir}")
    print(f"⚙️ Quantization method: {quantization_method}")

    # Initialize minimal distributed environment for single GPU quantization
    if not torch.distributed.is_initialized():
        print("🔧 Initializing distributed environment...")
        # Set up environment variables for single-process distributed
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"  # Use a different port than tests
        os.environ["LOCAL_RANK"] = "0"

        init_distributed_environment(
            world_size=1,
            rank=0,
            local_rank=0,
            backend="nccl" if device == "cuda" else "gloo",
        )
        initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
        )

    # Configure model loading with ModelOpt quantization and export
    model_config = ModelConfig(
        model_path=model_path,
        quantization=quantization_method,  # Use unified quantization flag
        trust_remote_code=True,
    )

    load_config = LoadConfig(
        modelopt_checkpoint_save_path=checkpoint_save_path,
        modelopt_export_path=export_dir,
    )
    device_config = DeviceConfig(device=device)

    # Load and quantize the model (export happens automatically)
    print("🔄 Loading and quantizing model...")
    model_loader = get_model_loader(load_config, model_config)

    try:
        model_loader.load_model(
            model_config=model_config,
            device_config=device_config,
        )
        print("✅ Model quantized successfully!")

        # Validate the export
        if _validate_export(export_dir):
            print("✅ Export validation passed!")

            info = _get_export_info(export_dir)
            if info:
                print("📋 Model info:")
                print(f"   - Type: {info['model_type']}")
                print(f"   - Architecture: {info['architectures']}")
                print(f"   - Quantization: {info['quantization_config']}")
        else:
            print("❌ Export validation failed!")
            return

    except Exception as e:
        print(f"❌ Quantization failed: {e}")
        return

    print("\n🎉 Workflow completed successfully!")
    print(f"📁 Quantized model exported to: {export_dir}")
    print("\n🚀 To use the exported model:")
    print(
        f"   python -m sglang.launch_server --model-path {export_dir} --quantization modelopt"
    )
    print("\n # Or in Python:")
    print("   import sglang as sgl")
    print(f"   llm = sgl.Engine(model_path='{export_dir}', quantization='modelopt')")
    print("   # Note: 'modelopt' auto-detects FP4/FP8 from model config")


def deploy_exported_model(
    export_dir: str,
    host: str = "127.0.0.1",
    port: int = 30000,
) -> None:
    """
    Deploy an exported ModelOpt quantized model with SGLang.

    Args:
        export_dir: Directory containing the exported model
        host: Host to bind the server to
        port: Port to bind the server to
    """
    print(f"🚀 Deploying exported model from: {export_dir}")

    # Validate export first
    if not _validate_export(export_dir):
        print("❌ Invalid export directory!")
        return

    try:
        # Launch SGLang engine with the exported model
        # Using generic "modelopt" for auto-detection of FP4/FP8
        llm = sgl.Engine(
            model_path=export_dir,
            quantization="modelopt",
            host=host,
            port=port,
        )

        print("✅ Model deployed successfully!")
        print(f"🌐 Server running at http://{host}:{port}")

        # Example inference
        prompts = ["Hello, how are you?", "What is the capital of France?"]
        sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 100}

        print("\n🧪 Running example inference...")
        outputs = llm.generate(prompts, sampling_params)

        for i, output in enumerate(outputs):
            print(f"Prompt {i+1}: {prompts[i]}")
            print(f"Output: {output['text']}")
            print()

    except Exception as e:
        print(f"❌ Deployment failed: {e}")


def main():
    parser = argparse.ArgumentParser(
        description="ModelOpt Quantization and Export with SGLang",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Quantize and export a model (recommended workflow)
    python modelopt_quantize_and_export.py quantize \\
        --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\
        --export-dir ./quantized_model \\
        --quantization-method modelopt_fp8

    # Deploy a pre-exported model
    python modelopt_quantize_and_export.py deploy \\
        --export-dir ./quantized_model
        """,
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Quantize command
    quantize_parser = subparsers.add_parser(
        "quantize", help="Quantize and export a model"
    )
    quantize_parser.add_argument(
        "--model-path", required=True, help="Path to the model to quantize"
    )
    quantize_parser.add_argument(
        "--export-dir", required=True, help="Directory to export the quantized model"
    )
    quantize_parser.add_argument(
        "--quantization-method",
        choices=["modelopt_fp8", "modelopt_fp4"],
        default="modelopt_fp8",
        help="Quantization method to use",
    )
    quantize_parser.add_argument(
        "--checkpoint-save-path", help="Optional path to save ModelOpt checkpoint"
    )
    quantize_parser.add_argument(
        "--device", default="cuda", help="Device to use for quantization"
    )

    # TODO: Quantize-and-serve command removed due to compatibility issues
    # Use the separate quantize-then-deploy workflow instead

    # Deploy command
    deploy_parser = subparsers.add_parser("deploy", help="Deploy an exported model")
    deploy_parser.add_argument(
        "--export-dir", required=True, help="Directory containing the exported model"
    )
    deploy_parser.add_argument(
        "--host", default="127.0.0.1", help="Host to bind the server to"
    )
    deploy_parser.add_argument(
        "--port", type=int, default=30000, help="Port to bind the server to"
    )

    args = parser.parse_args()

    if args.command == "quantize":
        quantize_and_export_model(
            model_path=args.model_path,
            export_dir=args.export_dir,
            quantization_method=args.quantization_method,
            checkpoint_save_path=args.checkpoint_save_path,
            device=args.device,
        )
    elif args.command == "deploy":
        deploy_exported_model(
            export_dir=args.export_dir,
            host=args.host,
            port=args.port,
        )
    else:
        parser.print_help()


if __name__ == "__main__":
    main()

@@ -75,12 +75,7 @@ dependencies = [
]

[project.optional-dependencies]
tracing = [
    "opentelemetry-api",
    "opentelemetry-exporter-otlp",
    "opentelemetry-exporter-otlp-proto-grpc",
    "opentelemetry-sdk",
]
modelopt = ["nvidia-modelopt"]
test = [
    "accelerate",
    "expecttest",
@@ -107,6 +102,12 @@ cu130_all = [
    "sglang[decord]",
    "sglang[cu130]"
]
tracing = [
    "opentelemetry-api",
    "opentelemetry-exporter-otlp",
    "opentelemetry-exporter-otlp-proto-grpc",
    "opentelemetry-sdk",
]

# To be deprecated in 2 weeks
blackwell = ["sglang[dev]"]

@@ -6,6 +6,7 @@ from typing import List, Optional, Union

import orjson

from sglang.srt.configs.modelopt_config import ModelOptConfig
from sglang.srt.utils import is_hip

logger = logging.getLogger(__name__)
@@ -51,6 +52,11 @@ class LoadConfig:
        decryption_key_file: If set, decrypts the output files with a password read
            from this file (after PBKDF2).
        decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.

        # ModelOpt-specific loading options
        modelopt_checkpoint_restore_path: Optional[str] = None
        modelopt_checkpoint_save_path: Optional[str] = None
        modelopt_export_path: Optional[str] = None
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -64,6 +70,14 @@ class LoadConfig:
    remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
    remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None

    # ModelOpt-specific loading options
    modelopt_checkpoint_restore_path: Optional[str] = None
    modelopt_checkpoint_save_path: Optional[str] = None
    modelopt_export_path: Optional[str] = None

    # ModelOpt configuration object
    modelopt_config: Optional[ModelOptConfig] = None

    def __post_init__(self):
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
@@ -78,6 +92,14 @@ class LoadConfig:
        else:
            self.ignore_patterns = ["original/**/*"]

        # Create ModelOptConfig if not provided
        if self.modelopt_config is None:
            self.modelopt_config = ModelOptConfig(
                checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
                checkpoint_save_path=self.modelopt_checkpoint_save_path,
                export_path=self.modelopt_export_path,
            )

    def _verify_load_format(self) -> None:
        if not isinstance(self.load_format, str):
            return

@@ -17,7 +17,7 @@ import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, List, Optional, Set, Union

import torch
from transformers import PretrainedConfig
@@ -89,7 +89,6 @@ class ModelConfig:
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        modelopt_quant: Optional[Union[str, Dict]] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[
@@ -97,15 +96,19 @@ class ModelConfig:
        ] = None,  # TODO: remove this, it is not a model config
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
        sampling_defaults: str = "openai",
        quantize_and_serve: bool = False,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.modelopt_quant = modelopt_quant
        self.is_draft_model = is_draft_model
        self.model_impl = model_impl
        self.sampling_defaults = sampling_defaults
        self.quantize_and_serve = quantize_and_serve

        # Validate quantize_and_serve configuration
        self._validate_quantize_and_serve_config()

        # Get hf config
        self._maybe_pull_model_tokenizer_from_remote()
@@ -219,10 +222,10 @@ class ModelConfig:
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            modelopt_quant=server_args.modelopt_quant,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            sampling_defaults=server_args.sampling_defaults,
            quantize_and_serve=server_args.quantize_and_serve,
            **kwargs,
        )

@@ -547,6 +550,56 @@ class ModelConfig:
        # Default to FP8 for backward compatibility
        return {"quant_method": "modelopt_fp8"}

    def _is_already_quantized(self) -> bool:
        """Check if the model is already quantized based on config files."""
        # Check for HuggingFace quantization config
        from sglang.srt.utils import has_hf_quant_config

        return has_hf_quant_config(self.model_path)

    def _get_modelopt_quant_type(self) -> str:
        """Extract ModelOpt quantization type from unified quantization flag."""
        if self.quantization == "modelopt_fp8":
            return "fp8"
        elif self.quantization == "modelopt_fp4":
            return "nvfp4"
        elif self.quantization == "modelopt":
            # Auto-detect from model config
            quant_cfg = self._parse_quant_hf_config()
            if quant_cfg:
                quant_method = quant_cfg.get("quant_method", "").lower()
                if "fp4" in quant_method:
                    return "fp4"
                elif "fp8" in quant_method:
                    return "fp8"
            # Default to fp8 if can't detect
            return "fp8"
        else:
            return "fp8"  # Default fallback

    def _validate_quantize_and_serve_config(self):
        """Validate quantize_and_serve configuration."""
        if not self.quantize_and_serve:
            return

        # Check if ModelOpt quantization is specified
        modelopt_quantization_specified = self.quantization in [
            "modelopt",
            "modelopt_fp8",
            "modelopt_fp4",
        ]

        if not modelopt_quantization_specified:
            raise ValueError("quantize_and_serve requires ModelOpt quantization")

        # quantize_and_serve is disabled due to compatibility issues
        raise NotImplementedError(
            "quantize_and_serve functionality is currently disabled due to compatibility issues. "
            "Please use the separate quantize-then-deploy workflow instead. "
            "Step 1: Quantize and export model. "
            "Step 2: Deploy the exported model."
        )

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]

python/sglang/srt/configs/modelopt_config.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Configuration for NVIDIA ModelOpt quantization integration
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelOptConfig:
    """Configuration for NVIDIA ModelOpt quantization operations.

    This configuration class holds parameters for ModelOpt quantization,
    checkpoint management, and model export operations.

    Args:
        quant: Quantization method/type (e.g., "fp8", "fp4")
        checkpoint_restore_path: Path to restore ModelOpt checkpoint from
        checkpoint_save_path: Path to save ModelOpt checkpoint to
        export_path: Path to export quantized model in HuggingFace format
        quantize_and_serve: Whether to quantize and serve in one step
    """

    quant: Optional[str] = None
    checkpoint_restore_path: Optional[str] = None
    checkpoint_save_path: Optional[str] = None
    export_path: Optional[str] = None
    quantize_and_serve: bool = False

    def __post_init__(self):
        """Validate configuration after initialization."""
        # Add any validation logic if needed
        pass

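Taken together with the `LoadConfig` changes above, the new dataclass can also be constructed directly; a minimal sketch based on the fields shown in this diff:

```python
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.modelopt_config import ModelOptConfig

# Bundle the ModelOpt options explicitly instead of via the individual
# modelopt_* fields on LoadConfig (both routes end up as a ModelOptConfig)
modelopt_config = ModelOptConfig(
    quant="fp8",
    checkpoint_save_path="./checkpoint.pth",
    export_path="./exported_model",
)
load_config = LoadConfig(modelopt_config=modelopt_config)
```
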
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,  # Auto-detect, defaults to FP8
    "modelopt_fp8": ModelOptFp8Config,
    "modelopt_fp4": ModelOptFp4Config,
    "w8a8_int8": W8A8Int8Config,

@@ -161,6 +161,26 @@ class QuantizationConfig(ABC):
        """
        return None

    @classmethod
    def _modelopt_override_quantization_method(
        cls, hf_quant_config, user_quant
    ) -> Optional[str]:
        """Shared ModelOpt quantization method override logic."""
        if hf_quant_config is None:
            return None

        # Check if this is a ModelOpt config
        quant_algo = hf_quant_config.get("quant_algo", "").upper()

        # If user specified generic "modelopt", auto-detect the specific method
        if user_quant == "modelopt":
            if "FP8" in quant_algo:
                return "modelopt_fp8"
            elif "NVFP4" in quant_algo or "FP4" in quant_algo:
                return "modelopt_fp4"

        return None

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config."""

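The shared override resolves the generic `modelopt` method to a concrete one from the checkpoint's `quant_algo`. A standalone sketch of that mapping (mirroring the logic above, not the actual API):

```python
# Re-statement of the auto-detection mapping for exposition
def resolve_modelopt_method(quant_algo: str, user_quant: str):
    quant_algo = quant_algo.upper()
    if user_quant == "modelopt":
        if "FP8" in quant_algo:
            return "modelopt_fp8"
        if "NVFP4" in quant_algo or "FP4" in quant_algo:
            return "modelopt_fp4"
    return None

assert resolve_modelopt_method("FP8", "modelopt") == "modelopt_fp8"
assert resolve_modelopt_method("NVFP4", "modelopt") == "modelopt_fp4"
```
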
@@ -111,6 +111,11 @@ class ModelOptFp8Config(QuantizationConfig):
            "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_config, user_quant):
        """Override quantization method based on the model's config."""
        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)

    @classmethod
    def get_name(cls) -> str:
        return "modelopt_fp8"
@@ -527,6 +532,11 @@ class ModelOptFp4Config(QuantizationConfig):
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def override_quantization_method(cls, hf_quant_config, user_quant):
        """Override quantization method based on the model's config."""
        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)

    @classmethod
    def get_name(cls) -> str:
        return "modelopt_fp4"
@@ -608,7 +618,16 @@ class ModelOptFp4Config(QuantizationConfig):
            else:
                kv_cache_quant_algo = "auto"

            group_size = ModelOptFp4Config.common_group_size(config)
            group_size = config.get("group_size")
            # If group_size is not at top level, try to extract from config_groups
            if group_size is None:
                config_groups = config.get("config_groups", {})
                if config_groups:
                    # Get group_size from the first group's weights config
                    first_group = next(iter(config_groups.values()), {})
                    weights_config = first_group.get("weights", {})
                    group_size = weights_config.get("group_size")

            exclude_modules = config.get("ignore", [])
        else:
            # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -634,15 +653,15 @@ class ModelOptFp4Config(QuantizationConfig):
            )
            is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

            if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
            if group_size is None or exclude_modules is None:
                logger.warning(
                    f"group_size: {group_size},"
                    f"kv_cache_quant_algo: {kv_cache_quant_algo},"
                    f"exclude_modules: {exclude_modules}"
                )
                raise ValueError(
                    "NVFP4 quantization requires group size and "
                    "kv_cache_quant_algo specified in the quantization config"
                    "NVFP4 quantization requires group_size and exclude_modules "
                    "specified in the quantization config"
                )
            return cls(
                is_checkpoint_nvfp4_serialized,

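For reference, the flat-format config that this parser now handles carries `group_size` either at the top level or nested under `config_groups`; an illustrative sketch (key names taken from the code, values invented):

```python
# Hypothetical flat quantization config as read by the parser above
flat_config = {
    "config_groups": {
        "group_0": {"weights": {"group_size": 16}},  # fallback location for group_size
    },
    "ignore": ["lm_head"],  # becomes exclude_modules
}
```
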
@@ -828,6 +828,16 @@ class ModelRunner:
        set_cuda_arch()

        # Prepare the model config
        from sglang.srt.configs.modelopt_config import ModelOptConfig

        modelopt_config = ModelOptConfig(
            quant=self.server_args.modelopt_quant,
            checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
            checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
            export_path=self.server_args.modelopt_export_path,
            quantize_and_serve=self.server_args.quantize_and_serve,
        )

        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
@@ -836,6 +846,7 @@ class ModelRunner:
            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
            modelopt_config=modelopt_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(

@@ -538,12 +538,21 @@ class DefaultModelLoader(BaseModelLoader):
            **model_kwargs,
            trust_remote_code=True,
        )
        rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
        # Handle both legacy modelopt_quant and unified quantization flags
        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
            # Legacy approach
            quant_choice_str = model_config.modelopt_quant
            rank0_log(f"ModelOpt quantization requested (legacy): {quant_choice_str}")
        else:
            # Unified approach - extract quantization type
            quant_choice_str = model_config._get_modelopt_quant_type()
            rank0_log(
                f"ModelOpt quantization requested (unified): {model_config.quantization} -> {quant_choice_str}"
            )

        quant_choice_str = model_config.modelopt_quant
        if not isinstance(quant_choice_str, str):
            raise TypeError(
                f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
                f"Quantization type must be a string (e.g., 'fp8'), "
                f"got {type(quant_choice_str)}"
            )

@@ -1764,6 +1773,7 @@ class ModelOptModelLoader(DefaultModelLoader):
        quant_cfg,
        quantized_ckpt_restore_path: str | None = None,
        quantized_ckpt_save_path: str | None = None,
        export_path: str | None = None,
    ) -> None:
        """
        Set up ModelOpt quantization for the given model.
@@ -1774,6 +1784,7 @@ class ModelOptModelLoader(DefaultModelLoader):
            quant_cfg: The quantization configuration
            quantized_ckpt_restore_path: Path to restore quantized checkpoint from
            quantized_ckpt_save_path: Path to save quantized checkpoint to
            export_path: Path to export the quantized model in HuggingFace format

        Raises:
            ImportError: If ModelOpt is not available
@@ -1798,6 +1809,9 @@ class ModelOptModelLoader(DefaultModelLoader):
                    rank0_log(
                        f"Restored quantized model from {quantized_ckpt_restore_path}"
                    )

                    # Export model if path provided (even when restoring from checkpoint)
                    self._maybe_export_modelopt(model, export_path)
                    return
                except Exception as e:
                    logger.warning(
@@ -1844,9 +1858,75 @@ class ModelOptModelLoader(DefaultModelLoader):
                    f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
                )

            # Export model if path provided
            self._maybe_export_modelopt(model, export_path)

        except Exception as e:
            raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e

    def _maybe_export_modelopt(self, model, export_path: str | None) -> None:
        """Export model to HuggingFace format if export_path is provided."""
        if export_path:
            try:
                # Get the original model path from the model config
                original_model_path = getattr(self, "_original_model_path", None)
                self._export_modelopt_checkpoint(
                    model, export_path, original_model_path
                )
                rank0_log(
                    f"Quantized model exported to HuggingFace format at {export_path}"
                )
            except Exception as e:
                rank0_log(
                    f"Warning: Failed to export quantized model to {export_path}: {e}"
                )

    def _export_modelopt_checkpoint(
        self,
        model,
        export_path: str,
        model_path: str = None,
        trust_remote_code: bool = True,
    ) -> None:
        """
        Export the quantized model to HuggingFace format using ModelOpt export API.

        Args:
            model: The quantized model to export
            export_path: Directory path to export the model to
            model_path: Path to the original model (for tokenizer export)
            trust_remote_code: Whether to trust remote code for tokenizer loading

        Raises:
            ImportError: If ModelOpt export functionality is not available
            Exception: If export fails
        """
        try:
            from modelopt.torch.export import export_hf_checkpoint
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "ModelOpt export functionality is not available. "
                "Please ensure you have the latest version of modelopt installed."
            ) from e

        # Create export directory if it doesn't exist
        os.makedirs(export_path, exist_ok=True)

        # Export the quantized model
        export_hf_checkpoint(model, export_dir=export_path)

        # Export the tokenizer if model_path is provided
        if model_path:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, trust_remote_code=trust_remote_code
                )
                tokenizer.save_pretrained(export_path)
                rank0_log(f"Tokenizer exported to {export_path}")
            except Exception as e:
                rank0_log(f"Warning: Failed to export tokenizer: {e}")

    def load_model(
        self,
        *,
@@ -1856,28 +1936,52 @@ class ModelOptModelLoader(DefaultModelLoader):

        logger.info("ModelOptModelLoader: Loading base model...")

        # Use shared method from parent class to load base model
        # Store the original model path for tokenizer export
        self._original_model_path = model_config.model_path

        # Check if model is already quantized
        if model_config._is_already_quantized():
            logger.info("Model is already quantized, loading directly...")
            # Use default loading for pre-quantized models
            return super().load_model(
                model_config=model_config, device_config=device_config
            )

        # TODO: Quantize-and-serve mode has been disabled at the ModelConfig level
        # All quantization now uses the standard workflow (quantize + export/save)
        logger.info("Standard quantization mode: Will quantize and export/save")
        return self._standard_quantization_workflow(model_config, device_config)

    def _standard_quantization_workflow(
        self, model_config: ModelConfig, device_config: DeviceConfig
    ) -> nn.Module:
        """Standard quantization workflow: quantize, save checkpoint, export, then return model."""
        # Use shared method from parent class to load base model for quantization
        model = self._load_modelopt_base_model(model_config)

        # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
        # Import ModelOpt modules
        try:
            import modelopt.torch.quantization as mtq
        except ImportError:
            logger.error(
                "NVIDIA Model Optimizer (modelopt) library not found. "
                "Please install it to use 'modelopt_quant' feature."
                "Please install it to use ModelOpt quantization."
            )
            raise

        quant_choice_str = model_config.modelopt_quant
        # Handle both old modelopt_quant and new unified quantization flags
        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
            # Legacy modelopt_quant flag
            quant_choice_str = model_config.modelopt_quant
        else:
            # Unified quantization flag - extract the type (fp8/fp4)
            quant_choice_str = model_config._get_modelopt_quant_type()

        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
        if not quant_cfg_name:
            raise ValueError(
                f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
                f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
                "Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
                "attribute names of config objects in modelopt.torch.quantization."
                f"Invalid quantization choice: '{quant_choice_str}'. "
                f"Available choices: {list(QUANT_CFG_CHOICES.keys())}"
            )

        try:
@@ -1885,20 +1989,27 @@ class ModelOptModelLoader(DefaultModelLoader):
            quant_cfg = getattr(mtq, quant_cfg_name)
        except AttributeError:
            raise AttributeError(
                f"ModelOpt quantization config attribute '{quant_cfg_name}' "
                f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
                "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
                f"ModelOpt quantization config '{quant_cfg_name}' not found. "
                "Please verify the ModelOpt library installation."
            )

        logger.info(
            f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
            f"Quantizing model with ModelOpt using config: mtq.{quant_cfg_name}"
        )

        quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
        quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
        # Get ModelOpt configuration from LoadConfig
        modelopt_config = self.load_config.modelopt_config
        quantized_ckpt_restore_path = (
            modelopt_config.checkpoint_restore_path if modelopt_config else None
        )
        quantized_ckpt_save_path = (
            modelopt_config.checkpoint_save_path if modelopt_config else None
        )
        export_path = modelopt_config.export_path if modelopt_config else None
        tokenizer = AutoTokenizer.from_pretrained(
            model_config.model_path, use_fast=True
        )

        try:
            self._setup_modelopt_quantization(
                model,
@@ -1906,6 +2017,7 @@ class ModelOptModelLoader(DefaultModelLoader):
                quant_cfg,
                quantized_ckpt_restore_path=quantized_ckpt_restore_path,
                quantized_ckpt_save_path=quantized_ckpt_save_path,
                export_path=export_path,
            )
        except Exception as e:
            logger.warning(f"ModelOpt quantization failed: {e}")
@@ -1919,12 +2031,27 @@ def get_model_loader(
) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if model_config and (
        (hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant)
        or model_config.quantization in ["modelopt_fp8", "modelopt_fp4", "modelopt"]
    ):
        logger.info("Using ModelOptModelLoader due to ModelOpt quantization config.")
        return ModelOptModelLoader(load_config)

    # Use ModelOptModelLoader for unified quantization flags
    if (
        model_config
        and hasattr(model_config, "modelopt_quant")
        and model_config.modelopt_quant
        and hasattr(model_config, "quantization")
        and model_config.quantization in ["modelopt_fp8", "modelopt_fp4"]
    ):
        logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
        if model_config._is_already_quantized():
            logger.info(
                f"Using ModelOptModelLoader for pre-quantized model: {model_config.quantization}"
            )
        else:
            logger.info(
                f"Using ModelOptModelLoader for quantization: {model_config.quantization}"
            )
        return ModelOptModelLoader(load_config)

    if isinstance(load_config.load_format, type):

@@ -83,6 +83,7 @@ QUANTIZATION_CHOICES = [
    "bitsandbytes",
    "gguf",
    "modelopt",
    "modelopt_fp8",
    "modelopt_fp4",
    "petit_nvfp4",
    "w8a8_int8",
@@ -192,6 +193,8 @@ class ServerArgs:
    modelopt_quant: Optional[Union[str, Dict]] = None
    modelopt_checkpoint_restore_path: Optional[str] = None
    modelopt_checkpoint_save_path: Optional[str] = None
    modelopt_export_path: Optional[str] = None
    quantize_and_serve: bool = False
    context_length: Optional[int] = None
    is_embedding: bool = False
    enable_multimodal: Optional[bool] = None
@@ -1743,6 +1746,22 @@ class ServerArgs:
            help="Path to save the ModelOpt quantized checkpoint after quantization. "
            "This allows reusing the quantized model in future runs.",
        )
        parser.add_argument(
            "--modelopt-export-path",
            type=str,
            default=ServerArgs.modelopt_export_path,
            help="Path to export the quantized model in HuggingFace format after ModelOpt quantization. "
            "The exported model can then be used directly with SGLang for inference. "
            "If not provided, the model will not be exported.",
        )
        parser.add_argument(
            "--quantize-and-serve",
            action="store_true",
            default=ServerArgs.quantize_and_serve,
            help="Quantize the model with ModelOpt and immediately serve it without exporting. "
            "This is useful for development and prototyping. For production, it's recommended "
            "to use separate quantization and deployment steps.",
        )
        parser.add_argument(
            "--kv-cache-dtype",
            type=str,

@@ -2411,6 +2411,29 @@ def retry(
            time.sleep(delay)


def has_hf_quant_config(model_path: str) -> bool:
    """Check if the model path contains hf_quant_config.json file.

    Args:
        model_path: Path to the model, can be local path or remote URL.

    Returns:
        True if hf_quant_config.json exists, False otherwise.
    """
    if is_remote_url(model_path):
        try:
            from huggingface_hub import HfApi

            hf_api = HfApi()
            return hf_api.file_exists(model_path, "hf_quant_config.json")
        except Exception:
            return False
    else:
        import os

        return os.path.exists(os.path.join(model_path, "hf_quant_config.json"))


def flatten_nested_list(nested_list):
    if isinstance(nested_list, list):
        return [

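The helper gives loaders a cheap way to detect pre-quantized checkpoints. An illustrative call (the path is hypothetical):

```python
from sglang.srt.utils import has_hf_quant_config

# Pre-quantized ModelOpt exports carry an hf_quant_config.json next to the weights
if has_hf_quant_config("./quantized_tinyllama_fp8"):
    print("Already quantized; the ModelOpt loader will load it directly")
```
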
@@ -135,6 +135,8 @@ suites = {
        TestFile("test_vision_chunked_prefill.py", 175),
        TestFile("test_vision_openai_server_a.py", 918),
        TestFile("test_vlm_input_format.py", 300),
        TestFile("test_modelopt_loader.py", 30),
        TestFile("test_modelopt_export.py", 30),
    ],
    "per-commit-2-gpu": [
        TestFile("ep/test_moe_ep.py", 140),

test/srt/test_modelopt_export.py (new file, 353 lines)
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Unit tests for ModelOpt export functionality in SGLang.
|
||||
|
||||
These tests verify the integration of ModelOpt export API with SGLang's model loading
|
||||
and quantization workflow.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.model_loader.loader import ModelOptModelLoader
|
||||
|
||||
# Note: PYTHONPATH=python should be set when running tests
|
||||
|
||||
# Check if modelopt is available
|
||||
try:
|
||||
import modelopt
|
||||
|
||||
MODELOPT_AVAILABLE = True
|
||||
except ImportError:
|
||||
MODELOPT_AVAILABLE = False
|
||||
|
||||
|
||||
class TestModelOptExport(unittest.TestCase):
|
||||
"""Test suite for ModelOpt export functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
# Mock distributed functionality to avoid initialization errors
|
||||
self.mock_tp_rank = patch(
|
||||
"sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
|
||||
return_value=0,
|
||||
)
|
||||
self.mock_tp_rank.start()
|
||||
|
||||
self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
|
||||
self.mock_rank0_log.start()
|
||||
|
||||
# Mock logger to avoid issues
|
||||
self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
|
||||
self.mock_logger.start()
|
||||
|
||||
# Mock all distributed functions that might be called
|
||||
self.mock_get_tp_group = patch(
|
||||
"sglang.srt.distributed.parallel_state.get_tp_group"
|
||||
)
|
||||
self.mock_get_tp_group.start()
|
||||
|
||||
# Mock model parallel initialization check
|
||||
self.mock_mp_is_initialized = patch(
|
||||
"sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
|
||||
return_value=True,
|
||||
)
|
||||
self.mock_mp_is_initialized.start()
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.export_dir = os.path.join(self.temp_dir, "exported_model")
|
||||
self.checkpoint_dir = os.path.join(self.temp_dir, "checkpoint")
|
||||
|
||||
# Mock model
|
||||
self.mock_model = Mock(spec=torch.nn.Module)
|
||||
self.mock_model.device = torch.device("cuda:0")
|
||||
|
||||
# Mock tokenizer
|
||||
self.mock_tokenizer = Mock()
|
||||
|
||||
# Mock quantization config
|
||||
self.mock_quant_cfg = Mock()
|
||||
|
||||
# Create ModelOptModelLoader instance
|
||||
self.load_config = LoadConfig()
|
||||
self.model_loader = ModelOptModelLoader(self.load_config)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
# Stop mocks
|
||||
self.mock_tp_rank.stop()
|
||||
self.mock_rank0_log.stop()
|
||||
self.mock_logger.stop()
|
||||
self.mock_get_tp_group.stop()
|
||||
self.mock_mp_is_initialized.stop()
|
||||
|
||||
def _create_mock_export_files(self, export_dir: str):
|
||||
"""Create mock export files for testing validation."""
|
||||
os.makedirs(export_dir, exist_ok=True)
|
||||
|
||||
# Create config.json
|
||||
config = {
|
||||
"model_type": "test_model",
|
||||
"architectures": ["TestModel"],
|
||||
"quantization_config": {
|
||||
"quant_method": "modelopt",
|
||||
"bits": 8,
|
||||
},
|
||||
}
|
||||
with open(os.path.join(export_dir, "config.json"), "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
# Create tokenizer_config.json
|
||||
tokenizer_config = {"tokenizer_class": "TestTokenizer"}
|
||||
with open(os.path.join(export_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(tokenizer_config, f)
|
||||
|
||||
# Create model file
|
||||
with open(os.path.join(export_dir, "model.safetensors"), "w") as f:
|
||||
f.write("mock_model_data")
|
||||
|
||||
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
|
||||
@patch("sglang.srt.model_loader.loader.os.makedirs")
|
||||
@patch("modelopt.torch.export.export_hf_checkpoint")
|
||||
def test_export_modelopt_checkpoint_success(self, mock_export, mock_makedirs):
|
||||
"""Test successful model export."""
|
||||
# Arrange
|
||||
mock_export.return_value = None
|
||||
mock_makedirs.return_value = None
|
||||
|
||||
# Act
|
||||
self.model_loader._export_modelopt_checkpoint(self.mock_model, self.export_dir)
|
||||
|
||||
# Assert
|
||||
mock_makedirs.assert_called_once_with(self.export_dir, exist_ok=True)
|
||||
mock_export.assert_called_once_with(self.mock_model, export_dir=self.export_dir)
|
||||
|
||||
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
|
||||
@patch("modelopt.torch.opt.restore")
|
||||
@patch("modelopt.torch.quantization.utils.is_quantized")
|
||||
def test_setup_quantization_with_export_from_checkpoint(
|
||||
self, mock_is_quantized, mock_restore
|
||||
):
|
||||
"""Test export functionality when restoring from checkpoint."""
|
||||
# Arrange
|
||||
mock_is_quantized.return_value = False
|
||||
mock_restore.return_value = None
|
||||
|
||||
with patch.object(
|
||||
self.model_loader, "_export_modelopt_checkpoint"
|
||||
) as mock_export:
|
||||
# Act
|
||||
self.model_loader._setup_modelopt_quantization(
|
||||
self.mock_model,
|
||||
self.mock_tokenizer,
|
||||
self.mock_quant_cfg,
|
||||
quantized_ckpt_restore_path=self.checkpoint_dir,
|
||||
export_path=self.export_dir,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_restore.assert_called_once_with(self.mock_model, self.checkpoint_dir)
|
||||
mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)
|
||||
|
||||
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
|
||||
@patch("modelopt.torch.quantization.quantize")
|
||||
@patch("modelopt.torch.quantization.print_quant_summary")
|
||||
@patch("modelopt.torch.quantization.utils.is_quantized")
|
||||
@patch("modelopt.torch.utils.dataset_utils.get_dataset_dataloader")
|
||||
@patch("modelopt.torch.utils.dataset_utils.create_forward_loop")
|
||||
def test_setup_quantization_with_export_after_calibration(
|
||||
self,
|
||||
mock_create_loop,
|
||||
mock_get_dataloader,
|
||||
mock_is_quantized,
|
||||
mock_print_summary,
|
||||
mock_quantize,
|
||||
):
|
||||
"""Test export functionality after calibration-based quantization."""
|
||||
# Arrange
|
||||
mock_is_quantized.return_value = False
|
||||
mock_dataloader = Mock()
|
||||
mock_get_dataloader.return_value = mock_dataloader
|
||||
mock_calibrate_loop = Mock()
|
||||
mock_create_loop.return_value = mock_calibrate_loop
|
||||
mock_quantize.return_value = None
|
||||
mock_print_summary.return_value = None
|
||||
|
||||
with patch.object(
|
||||
self.model_loader, "_export_modelopt_checkpoint"
|
||||
) as mock_export:
|
||||
# Act
|
||||
self.model_loader._setup_modelopt_quantization(
|
||||
self.mock_model,
|
||||
self.mock_tokenizer,
|
||||
self.mock_quant_cfg,
|
||||
export_path=self.export_dir,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_quantize.assert_called_once_with(
|
||||
self.mock_model, self.mock_quant_cfg, forward_loop=mock_calibrate_loop
|
||||
)
|
||||
mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)
|
||||
|
||||
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
|
||||
def test_setup_quantization_without_export(self):
|
||||
"""Test quantization setup without export path specified."""
|
||||
with patch("modelopt.torch.quantization.utils.is_quantized", return_value=True):
|
||||
# Act
|
||||
with patch.object(
|
||||
self.model_loader, "_export_modelopt_checkpoint"
|
||||
) as mock_export:
|
||||
self.model_loader._setup_modelopt_quantization(
|
||||
self.mock_model,
|
||||
self.mock_tokenizer,
|
||||
self.mock_quant_cfg,
|
||||
export_path=None, # No export path
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_export.assert_not_called()
|
||||
|
||||
def test_quantize_and_serve_config_validation(self):
|
||||
"""Test that quantize_and_serve is properly disabled."""
|
||||
# Test that quantize-and-serve mode raises NotImplementedError
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
ModelConfig(
|
||||
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
quantization="modelopt_fp8",
|
||||
quantize_and_serve=True,
|
||||
)
|
||||
|
||||
# Verify the error message contains helpful instructions
|
||||
error_msg = str(context.exception)
|
||||
self.assertIn("disabled due to compatibility issues", error_msg)
|
||||
self.assertIn("separate quantize-then-deploy workflow", error_msg)
|
||||
|
||||
# Test invalid configuration - no quantization
|
||||
with self.assertRaises(ValueError) as context:
|
||||
ModelConfig(
|
||||
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
quantize_and_serve=True,
|
||||
)
|
||||
self.assertIn("requires ModelOpt quantization", str(context.exception))
|
||||
|
||||
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
|
||||
def test_standard_workflow_selection(self):
|
||||
"""Test that standard workflow is selected by default."""
|
||||
with patch(
|
||||
"modelopt.torch.quantization.utils.is_quantized", return_value=False
|
||||
):
|
||||
with patch.object(
|
||||
self.model_loader, "_standard_quantization_workflow"
|
||||
) as mock_standard:
|
||||
with patch.object(self.model_loader, "_load_modelopt_base_model"):
|
||||
mock_standard.return_value = Mock()
|
||||
|
||||
# Create model config without quantize_and_serve
|
||||
model_config = ModelConfig(
|
||||
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
quantization="modelopt_fp8",
|
||||
quantize_and_serve=False,
|
||||
)
|
||||
device_config = DeviceConfig()
|
||||
|
||||
# Act
|
||||
self.model_loader.load_model(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_standard.assert_called_once_with(model_config, device_config)
|
||||
|
||||
def _get_export_info(self, export_dir: str) -> dict:
|
||||
"""Get information about an exported model."""
|
||||
if not self._validate_export(export_dir):
|
||||
return None
|
||||
|
||||
try:
|
||||
config_path = os.path.join(export_dir, "config.json")
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
return {
|
||||
"model_type": config.get("model_type", "unknown"),
|
||||
"architectures": config.get("architectures", []),
|
||||
"quantization_config": config.get("quantization_config", {}),
|
||||
"export_dir": export_dir,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
|
||||
class TestModelOptExportIntegration(unittest.TestCase):
|
||||
"""Integration tests for ModelOpt export with full model loading workflow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up integration test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.export_dir = os.path.join(self.temp_dir, "exported_model")
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up integration test fixtures."""
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
@patch("sglang.srt.model_loader.loader.get_model_architecture")
|
||||
@patch("transformers.AutoTokenizer.from_pretrained")
|
||||
@patch("transformers.AutoModelForCausalLM.from_pretrained")
|
||||
def test_full_workflow_with_export(self, mock_model, mock_tokenizer, mock_arch):
|
||||
"""Test the complete workflow from model config to export."""
|
||||
        # Arrange
        mock_arch.return_value = ("TestModel", "TestConfig")
        mock_tokenizer.return_value = Mock()
        mock_model.return_value = Mock(spec=torch.nn.Module)

        model_config = ModelConfig(
            model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            modelopt_quant="fp8",
            modelopt_export_path=self.export_dir,
        )

        load_config = LoadConfig()
        device_config = DeviceConfig()

        # Mock the quantization and export process
        with patch.object(
            ModelOptModelLoader, "_setup_modelopt_quantization"
        ) as mock_setup:
            with patch.object(
                ModelOptModelLoader, "_load_modelopt_base_model"
            ) as mock_load_base:
                mock_load_base.return_value = mock_model.return_value

                # Act
                model_loader = ModelOptModelLoader(load_config)
                result = model_loader.load_model(
                    model_config=model_config,
                    device_config=device_config,
                )

                # Assert
                self.assertIsNotNone(result)
                mock_setup.assert_called_once()
                # Verify export_path was passed to setup
                args, kwargs = mock_setup.call_args
                self.assertEqual(kwargs.get("export_path"), self.export_dir)


if __name__ == "__main__":
    unittest.main()

@@ -12,8 +12,17 @@ from unittest.mock import MagicMock, patch

import torch.nn as nn

# Add the sglang path for testing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
# Note: PYTHONPATH=python should be set when running tests

# Constants for calibration parameters to avoid hard-coded values
CALIBRATION_BATCH_SIZE = 36
CALIBRATION_NUM_SAMPLES = 512
DEFAULT_DEVICE = "cuda:0"

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig

@@ -28,18 +37,63 @@ class TestModelOptModelLoader(CustomTestCase):

    def setUp(self):
        """Set up test fixtures."""
        # Mock distributed functionality to avoid initialization errors
        self.mock_tp_rank = patch(
            "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
            return_value=0,
        )
        self.mock_tp_rank.start()

        self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
        self.mock_rank0_log.start()

        # Mock logger to avoid issues
        self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
        self.mock_logger.start()

        # Mock all distributed functions that might be called
        self.mock_get_tp_group = patch(
            "sglang.srt.distributed.parallel_state.get_tp_group"
        )
        self.mock_get_tp_group.start()

        # Mock model parallel initialization check
        self.mock_mp_is_initialized = patch(
            "sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
            return_value=True,
        )
        self.mock_mp_is_initialized.start()

        self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.load_config = LoadConfig()
        self.device_config = DeviceConfig(device="cuda")

-       # Create a basic model config with modelopt_quant
+       # Create a basic model config with unified quantization flag
        self.model_config = ModelConfig(
-           model_path=self.model_path, modelopt_quant="fp8"
+           model_path=self.model_path,
+           quantization="modelopt_fp8",  # Use unified quantization approach
        )

        # Also create a unified quantization config for new tests
        self.unified_model_config = ModelConfig(
            model_path=self.model_path, quantization="modelopt_fp8"
        )

        # Mock base model
        self.mock_base_model = MagicMock(spec=nn.Module)
        self.mock_base_model.eval.return_value = self.mock_base_model
        self.mock_base_model.device = (
            DEFAULT_DEVICE  # Add device attribute for calibration tests
        )

    def tearDown(self):
        """Clean up test fixtures."""
        # Stop mocks
        self.mock_tp_rank.stop()
        self.mock_rank0_log.stop()
        self.mock_logger.stop()
        self.mock_get_tp_group.stop()
        self.mock_mp_is_initialized.stop()

@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
|
||||
@patch("sglang.srt.model_loader.loader.logger")
|
||||

@@ -66,7 +120,7 @@ class TestModelOptModelLoader(CustomTestCase):
        model = self.mock_base_model

        # Simulate the quantization config lookup
-       quant_choice_str = model_config.modelopt_quant
+       quant_choice_str = model_config._get_modelopt_quant_type()
        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)

        if not quant_cfg_name:

@@ -123,6 +177,305 @@
        # Verify we get back the expected model
        self.assertEqual(result_model, self.mock_base_model)

@patch("sglang.srt.model_loader.loader.logger")
|
||||
def test_missing_modelopt_import(self, mock_logger):
|
||||
"""Test error handling when modelopt library is not available."""
|
||||
|
||||
loader = ModelOptModelLoader(self.load_config)
|
||||
|
||||
# Mock the base model loader method
|
||||
with patch.object(
|
||||
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
|
||||
):
|
||||
# Simulate missing modelopt by making import fail
|
||||
original_import = __import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name.startswith("modelopt"):
|
||||
raise ImportError("No module named 'modelopt'")
|
||||
# Return default import behavior for other modules
|
||||
return original_import(name, *args, **kwargs)
|
||||
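
            # Patching builtins.__import__ below intercepts every import made
            # while the patch is active; only names starting with "modelopt"
            # raise, so the loader's own imports keep working and the test
            # drives load_model down its missing-dependency error path.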

            with patch("builtins.__import__", side_effect=mock_import):
                # Expect ImportError to be raised and logged
                with self.assertRaises(ImportError):
                    loader.load_model(
                        model_config=self.model_config, device_config=self.device_config
                    )

                # Verify error logging
                mock_logger.error.assert_called_with(
                    "NVIDIA Model Optimizer (modelopt) library not found. "
                    "Please install it to use ModelOpt quantization."
                )

@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
|
||||
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
|
||||
@patch("sglang.srt.model_loader.loader.logger")
|
||||
def test_calibration_workflow_integration(self, mock_logger, mock_auto_tokenizer):
|
||||
"""Test end-to-end calibration workflow integration."""
|
||||
|
||||
loader = ModelOptModelLoader(self.load_config)
|
||||
|
||||
# Mock tokenizer
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.padding_side = "right"
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
# Mock modelopt modules
|
||||
mock_mtq = MagicMock()
|
||||
mock_mto = MagicMock()
|
||||
mock_dataset_utils = MagicMock()
|
||||
|
||||
# Configure quantization config
|
||||
mock_fp8_cfg = MagicMock()
|
||||
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
|
||||
|
||||
# Configure dataset utilities
|
||||
mock_calib_dataloader = MagicMock()
|
||||
mock_calibrate_loop = MagicMock()
|
||||
mock_dataset_utils.get_dataset_dataloader.return_value = mock_calib_dataloader
|
||||
mock_dataset_utils.create_forward_loop.return_value = mock_calibrate_loop
|
||||
|
||||
# Configure model as not quantized initially
|
||||
mock_is_quantized = MagicMock(return_value=False)
|
||||
|
||||
with patch.object(
|
||||
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
|
||||
):
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"modelopt": MagicMock(),
|
||||
"modelopt.torch": MagicMock(),
|
||||
"modelopt.torch.opt": mock_mto,
|
||||
"modelopt.torch.quantization": mock_mtq,
|
||||
"modelopt.torch.quantization.utils": MagicMock(
|
||||
is_quantized=mock_is_quantized
|
||||
),
|
||||
"modelopt.torch.utils": MagicMock(),
|
||||
"modelopt.torch.utils.dataset_utils": mock_dataset_utils,
|
||||
},
|
||||
):
|
||||
# Execute the load_model method to test the full workflow
|
||||
result_model = loader.load_model(
|
||||
model_config=self.model_config, device_config=self.device_config
|
||||
)
|
||||
|
||||
# Verify the model loading was successful
|
||||
self.assertEqual(result_model, self.mock_base_model)
|
||||
|
||||
# Verify key calibration components were used
|
||||
# Note: We can't easily verify the exact calls due to dynamic imports,
|
||||
# but we can verify the workflow completed successfully
|
||||
|
||||
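    # The sequence exercised above (as suggested by the mocks): the loader
    # builds a calibration dataloader via dataset_utils.get_dataset_dataloader,
    # wraps it with dataset_utils.create_forward_loop, and feeds the resulting
    # loop to mtq.quantize -- the same call sequence the checkpoint-save test
    # below simulates explicitly.
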
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
|
||||
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
|
||||
@patch("sglang.srt.model_loader.loader.logger")
|
||||
def test_quantized_checkpoint_restore(self, mock_logger, mock_auto_tokenizer):
|
||||
"""Test restoring from a quantized checkpoint."""
|
||||
|
||||
# Create model config with checkpoint restore path
|
||||
config_with_restore = ModelConfig(
|
||||
model_path=self.model_path,
|
||||
quantization="modelopt_fp8",
|
||||
)
|
||||
|
||||
# Create load config with checkpoint restore path
|
||||
load_config_with_restore = LoadConfig(
|
||||
modelopt_checkpoint_restore_path="/path/to/quantized/checkpoint"
|
||||
)
|
||||
|
||||
loader = ModelOptModelLoader(load_config_with_restore)
|
||||
|
||||
# Mock tokenizer
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
# Mock modelopt modules
|
||||
mock_mtq = MagicMock()
|
||||
mock_mto = MagicMock()
|
||||
|
||||
# Configure quantization config
|
||||
mock_fp8_cfg = MagicMock()
|
||||
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
|
||||
|
||||
# Configure model as not quantized initially
|
||||
mock_is_quantized = MagicMock(return_value=False)
|
||||
|
||||
with patch.object(
|
||||
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
|
||||
):
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"modelopt": MagicMock(),
|
||||
"modelopt.torch": MagicMock(),
|
||||
"modelopt.torch.opt": mock_mto,
|
||||
"modelopt.torch.quantization": mock_mtq,
|
||||
"modelopt.torch.quantization.utils": MagicMock(
|
||||
is_quantized=mock_is_quantized
|
||||
),
|
||||
},
|
||||
):
|
||||
with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
|
||||
# Mock the _setup_modelopt_quantization to simulate checkpoint restore
|
||||
def mock_setup_quantization(
|
||||
model,
|
||||
tokenizer,
|
||||
quant_cfg,
|
||||
quantized_ckpt_restore_path=None,
|
||||
**kwargs,
|
||||
):
|
||||
if quantized_ckpt_restore_path:
|
||||
mock_mto.restore(model, quantized_ckpt_restore_path)
|
||||
print(
|
||||
f"Restored quantized model from {quantized_ckpt_restore_path}"
|
||||
)
|
||||
return
|
||||
|
||||
mock_setup.side_effect = mock_setup_quantization
|
||||
|
||||
# Execute the load_model method
|
||||
result_model = loader.load_model(
|
||||
model_config=config_with_restore,
|
||||
device_config=self.device_config,
|
||||
)
|
||||
|
||||
# Verify the setup was called with restore path
|
||||
mock_setup.assert_called_once()
|
||||
call_args = mock_setup.call_args
|
||||
# Check that the restore path was passed correctly
|
||||
self.assertIn("quantized_ckpt_restore_path", call_args[1])
|
||||
self.assertEqual(
|
||||
call_args[1]["quantized_ckpt_restore_path"],
|
||||
"/path/to/quantized/checkpoint",
|
||||
)
|
||||
|
||||
# Verify restore was called
|
||||
mock_mto.restore.assert_called_once_with(
|
||||
self.mock_base_model, "/path/to/quantized/checkpoint"
|
||||
)
|
||||
|
||||
# Verify we get the expected model back
|
||||
self.assertEqual(result_model, self.mock_base_model)
|
||||
|
||||
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
|
||||
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
|
||||
@patch("sglang.srt.model_loader.loader.logger")
|
||||
def test_quantized_checkpoint_save(self, mock_logger, mock_auto_tokenizer):
|
||||
"""Test saving quantized checkpoint after calibration."""
|
||||
|
||||
# Create model config with checkpoint save path
|
||||
config_with_save = ModelConfig(
|
||||
model_path=self.model_path,
|
||||
quantization="modelopt_fp8",
|
||||
)
|
||||
|
||||
# Create load config with checkpoint save path
|
||||
load_config_with_save = LoadConfig(
|
||||
modelopt_checkpoint_save_path="/path/to/save/checkpoint"
|
||||
)
|
||||
|
||||
loader = ModelOptModelLoader(load_config_with_save)
|
||||
|
||||
# Mock tokenizer
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
# Mock modelopt modules
|
||||
mock_mtq = MagicMock()
|
||||
mock_mto = MagicMock()
|
||||
mock_dataset_utils = MagicMock()
|
||||
|
||||
# Configure quantization config
|
||||
mock_fp8_cfg = MagicMock()
|
||||
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
|
||||
|
||||
# Configure model as not quantized initially
|
||||
mock_is_quantized = MagicMock(return_value=False)
|
||||
|
||||
with patch.object(
|
||||
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
|
||||
):
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"modelopt": MagicMock(),
|
||||
"modelopt.torch": MagicMock(),
|
||||
"modelopt.torch.opt": mock_mto,
|
||||
"modelopt.torch.quantization": mock_mtq,
|
||||
"modelopt.torch.quantization.utils": MagicMock(
|
||||
is_quantized=mock_is_quantized
|
||||
),
|
||||
"modelopt.torch.utils": MagicMock(),
|
||||
"modelopt.torch.utils.dataset_utils": mock_dataset_utils,
|
||||
},
|
||||
):
|
||||
with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
|
||||
# Mock the _setup_modelopt_quantization to simulate checkpoint save
|
||||
def mock_setup_quantization(
|
||||
model,
|
||||
tokenizer,
|
||||
quant_cfg,
|
||||
quantized_ckpt_save_path=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Simulate calibration and quantization
|
||||
mock_mtq.quantize(model, quant_cfg, forward_loop=MagicMock())
|
||||
mock_mtq.print_quant_summary(model)
|
||||
|
||||
# Save checkpoint if path provided
|
||||
if quantized_ckpt_save_path:
|
||||
mock_mto.save(model, quantized_ckpt_save_path)
|
||||
print(
|
||||
f"Quantized model saved to {quantized_ckpt_save_path}"
|
||||
)
|
||||
|
||||
mock_setup.side_effect = mock_setup_quantization
|
||||
|
||||
# Execute the load_model method
|
||||
result_model = loader.load_model(
|
||||
model_config=config_with_save, device_config=self.device_config
|
||||
)
|
||||
|
||||
# Verify the setup was called with save path
|
||||
mock_setup.assert_called_once()
|
||||
call_args = mock_setup.call_args
|
||||
# Check that the save path was passed correctly
|
||||
self.assertIn("quantized_ckpt_save_path", call_args[1])
|
||||
self.assertEqual(
|
||||
call_args[1]["quantized_ckpt_save_path"],
|
||||
"/path/to/save/checkpoint",
|
||||
)
|
||||
|
||||
# Verify save was called
|
||||
mock_mto.save.assert_called_once_with(
|
||||
self.mock_base_model, "/path/to/save/checkpoint"
|
||||
)
|
||||
|
||||
# Verify we get the expected model back
|
||||
self.assertEqual(result_model, self.mock_base_model)
|
||||
|
||||
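    # Sketch of the flow simulated above (mirroring the mocked calls, not a
    # verbatim copy of the loader's implementation): mtq.quantize(model,
    # quant_cfg, forward_loop=...) calibrates and quantizes the model in
    # place, and mto.save(model, path) persists the quantized state so a
    # later run can mto.restore(model, path) instead of re-calibrating.
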
    def test_unified_quantization_flag_support(self):
        """Test that ModelOptModelLoader supports unified quantization flags."""
        # Test modelopt_fp8
        config_fp8 = ModelConfig(
            model_path=self.model_path, quantization="modelopt_fp8"
        )
        self.assertEqual(config_fp8._get_modelopt_quant_type(), "fp8")

        # Test modelopt_fp4
        config_fp4 = ModelConfig(
            model_path=self.model_path, quantization="modelopt_fp4"
        )
        self.assertEqual(config_fp4._get_modelopt_quant_type(), "nvfp4")

        # Test auto-detection
        config_auto = ModelConfig(model_path=self.model_path, quantization="modelopt")
        # Should default to fp8 when no config is detected
        self.assertEqual(config_auto._get_modelopt_quant_type(), "fp8")
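
    # Summary of the flag mapping exercised above:
    #   "modelopt_fp8" -> "fp8"
    #   "modelopt_fp4" -> "nvfp4"
    #   "modelopt"     -> auto-detected; defaults to "fp8" when the checkpoint
    #                     carries no quantization config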


class TestModelOptLoaderIntegration(CustomTestCase):
    """Integration tests for ModelOptModelLoader with Engine API."""