Enable native ModelOpt quantization support (3/3) (#10154)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
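For context, the new tests below exercise SGLang's unified ModelOpt quantization flag (`quantization="modelopt_fp8"`) plus checkpoint save/restore and Hugging Face checkpoint export through the model loader. A minimal end-to-end sketch of the flag in use, assuming a build that includes this change and that the flag is exposed through the offline Engine API (the model path and sampling parameters are illustrative only):

    import sglang as sgl

    # Launch an offline engine with native ModelOpt FP8 quantization.
    llm = sgl.Engine(
        model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        quantization="modelopt_fp8",
    )
    print(llm.generate("Hello, my name is", {"max_new_tokens": 8}))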
@@ -135,6 +135,8 @@ suites = {
        TestFile("test_vision_chunked_prefill.py", 175),
        TestFile("test_vision_openai_server_a.py", 918),
        TestFile("test_vlm_input_format.py", 300),
+       TestFile("test_modelopt_loader.py", 30),
+       TestFile("test_modelopt_export.py", 30),
    ],
    "per-commit-2-gpu": [
        TestFile("ep/test_moe_ep.py", 140),

test/srt/test_modelopt_export.py (new file, 353 lines)
@@ -0,0 +1,353 @@
"""
|
||||
Unit tests for ModelOpt export functionality in SGLang.
|
||||
|
||||
These tests verify the integration of ModelOpt export API with SGLang's model loading
|
||||
and quantization workflow.
|
||||
"""

import json
import os
import sys
import tempfile
import unittest
from unittest.mock import Mock, patch

import torch

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import ModelOptModelLoader

# Note: PYTHONPATH=python should be set when running tests

# Check if modelopt is available
try:
    import modelopt

    MODELOPT_AVAILABLE = True
except ImportError:
    MODELOPT_AVAILABLE = False


class TestModelOptExport(unittest.TestCase):
    """Test suite for ModelOpt export functionality."""

    def setUp(self):
        """Set up test fixtures."""
        # Mock distributed functionality to avoid initialization errors
        self.mock_tp_rank = patch(
            "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
            return_value=0,
        )
        self.mock_tp_rank.start()

        self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
        self.mock_rank0_log.start()

        # Mock logger to avoid issues
        self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
        self.mock_logger.start()

        # Mock all distributed functions that might be called
        self.mock_get_tp_group = patch(
            "sglang.srt.distributed.parallel_state.get_tp_group"
        )
        self.mock_get_tp_group.start()

        # Mock model parallel initialization check
        self.mock_mp_is_initialized = patch(
            "sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
            return_value=True,
        )
        self.mock_mp_is_initialized.start()

        self.temp_dir = tempfile.mkdtemp()
        self.export_dir = os.path.join(self.temp_dir, "exported_model")
        self.checkpoint_dir = os.path.join(self.temp_dir, "checkpoint")

        # Mock model
        self.mock_model = Mock(spec=torch.nn.Module)
        self.mock_model.device = torch.device("cuda:0")

        # Mock tokenizer
        self.mock_tokenizer = Mock()

        # Mock quantization config
        self.mock_quant_cfg = Mock()

        # Create ModelOptModelLoader instance
        self.load_config = LoadConfig()
        self.model_loader = ModelOptModelLoader(self.load_config)

    def tearDown(self):
        """Clean up test fixtures."""
        import shutil

        shutil.rmtree(self.temp_dir, ignore_errors=True)

        # Stop mocks
        self.mock_tp_rank.stop()
        self.mock_rank0_log.stop()
        self.mock_logger.stop()
        self.mock_get_tp_group.stop()
        self.mock_mp_is_initialized.stop()

    def _create_mock_export_files(self, export_dir: str):
        """Create mock export files for testing validation."""
        os.makedirs(export_dir, exist_ok=True)

        # Create config.json
        config = {
            "model_type": "test_model",
            "architectures": ["TestModel"],
            "quantization_config": {
                "quant_method": "modelopt",
                "bits": 8,
            },
        }
        with open(os.path.join(export_dir, "config.json"), "w") as f:
            json.dump(config, f)

        # Create tokenizer_config.json
        tokenizer_config = {"tokenizer_class": "TestTokenizer"}
        with open(os.path.join(export_dir, "tokenizer_config.json"), "w") as f:
            json.dump(tokenizer_config, f)

        # Create model file
        with open(os.path.join(export_dir, "model.safetensors"), "w") as f:
            f.write("mock_model_data")

    @unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
    @patch("sglang.srt.model_loader.loader.os.makedirs")
    @patch("modelopt.torch.export.export_hf_checkpoint")
    def test_export_modelopt_checkpoint_success(self, mock_export, mock_makedirs):
        """Test successful model export."""
        # Arrange
        mock_export.return_value = None
        mock_makedirs.return_value = None

        # Act
        self.model_loader._export_modelopt_checkpoint(self.mock_model, self.export_dir)

        # Assert
        mock_makedirs.assert_called_once_with(self.export_dir, exist_ok=True)
        mock_export.assert_called_once_with(self.mock_model, export_dir=self.export_dir)

    @unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
    @patch("modelopt.torch.opt.restore")
    @patch("modelopt.torch.quantization.utils.is_quantized")
    def test_setup_quantization_with_export_from_checkpoint(
        self, mock_is_quantized, mock_restore
    ):
        """Test export functionality when restoring from checkpoint."""
        # Arrange
        mock_is_quantized.return_value = False
        mock_restore.return_value = None

        with patch.object(
            self.model_loader, "_export_modelopt_checkpoint"
        ) as mock_export:
            # Act
            self.model_loader._setup_modelopt_quantization(
                self.mock_model,
                self.mock_tokenizer,
                self.mock_quant_cfg,
                quantized_ckpt_restore_path=self.checkpoint_dir,
                export_path=self.export_dir,
            )

            # Assert
            mock_restore.assert_called_once_with(self.mock_model, self.checkpoint_dir)
            mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)

    @unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
    @patch("modelopt.torch.quantization.quantize")
    @patch("modelopt.torch.quantization.print_quant_summary")
    @patch("modelopt.torch.quantization.utils.is_quantized")
    @patch("modelopt.torch.utils.dataset_utils.get_dataset_dataloader")
    @patch("modelopt.torch.utils.dataset_utils.create_forward_loop")
    def test_setup_quantization_with_export_after_calibration(
        self,
        mock_create_loop,
        mock_get_dataloader,
        mock_is_quantized,
        mock_print_summary,
        mock_quantize,
    ):
        """Test export functionality after calibration-based quantization."""
        # Arrange
        mock_is_quantized.return_value = False
        mock_dataloader = Mock()
        mock_get_dataloader.return_value = mock_dataloader
        mock_calibrate_loop = Mock()
        mock_create_loop.return_value = mock_calibrate_loop
        mock_quantize.return_value = None
        mock_print_summary.return_value = None

        with patch.object(
            self.model_loader, "_export_modelopt_checkpoint"
        ) as mock_export:
            # Act
            self.model_loader._setup_modelopt_quantization(
                self.mock_model,
                self.mock_tokenizer,
                self.mock_quant_cfg,
                export_path=self.export_dir,
            )

            # Assert
            mock_quantize.assert_called_once_with(
                self.mock_model, self.mock_quant_cfg, forward_loop=mock_calibrate_loop
            )
            mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)

    @unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
    def test_setup_quantization_without_export(self):
        """Test quantization setup without export path specified."""
        with patch("modelopt.torch.quantization.utils.is_quantized", return_value=True):
            # Act
            with patch.object(
                self.model_loader, "_export_modelopt_checkpoint"
            ) as mock_export:
                self.model_loader._setup_modelopt_quantization(
                    self.mock_model,
                    self.mock_tokenizer,
                    self.mock_quant_cfg,
                    export_path=None,  # No export path
                )

                # Assert
                mock_export.assert_not_called()

    def test_quantize_and_serve_config_validation(self):
        """Test that quantize_and_serve is properly disabled."""
        # Test that quantize-and-serve mode raises NotImplementedError
        with self.assertRaises(NotImplementedError) as context:
            ModelConfig(
                model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                quantization="modelopt_fp8",
                quantize_and_serve=True,
            )

        # Verify the error message contains helpful instructions
        error_msg = str(context.exception)
        self.assertIn("disabled due to compatibility issues", error_msg)
        self.assertIn("separate quantize-then-deploy workflow", error_msg)

        # Test invalid configuration - no quantization
        with self.assertRaises(ValueError) as context:
            ModelConfig(
                model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                quantize_and_serve=True,
            )
        self.assertIn("requires ModelOpt quantization", str(context.exception))

    @unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
    def test_standard_workflow_selection(self):
        """Test that the standard workflow is selected by default."""
        with patch(
            "modelopt.torch.quantization.utils.is_quantized", return_value=False
        ):
            with patch.object(
                self.model_loader, "_standard_quantization_workflow"
            ) as mock_standard:
                with patch.object(self.model_loader, "_load_modelopt_base_model"):
                    mock_standard.return_value = Mock()

                    # Create model config without quantize_and_serve
                    model_config = ModelConfig(
                        model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                        quantization="modelopt_fp8",
                        quantize_and_serve=False,
                    )
                    device_config = DeviceConfig()

                    # Act
                    self.model_loader.load_model(
                        model_config=model_config,
                        device_config=device_config,
                    )

                    # Assert
                    mock_standard.assert_called_once_with(model_config, device_config)

    def _get_export_info(self, export_dir: str) -> dict:
        """Get information about an exported model."""
        if not self._validate_export(export_dir):
            return None

        try:
            config_path = os.path.join(export_dir, "config.json")
            with open(config_path, "r") as f:
                config = json.load(f)

            return {
                "model_type": config.get("model_type", "unknown"),
                "architectures": config.get("architectures", []),
                "quantization_config": config.get("quantization_config", {}),
                "export_dir": export_dir,
            }
        except Exception:
            return None


@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
class TestModelOptExportIntegration(unittest.TestCase):
    """Integration tests for ModelOpt export with the full model loading workflow."""

    def setUp(self):
        """Set up integration test fixtures."""
        self.temp_dir = tempfile.mkdtemp()
        self.export_dir = os.path.join(self.temp_dir, "exported_model")

    def tearDown(self):
        """Clean up integration test fixtures."""
        import shutil

        shutil.rmtree(self.temp_dir, ignore_errors=True)

    @patch("sglang.srt.model_loader.loader.get_model_architecture")
    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    def test_full_workflow_with_export(self, mock_model, mock_tokenizer, mock_arch):
        """Test the complete workflow from model config to export."""
        # Arrange
        mock_arch.return_value = ("TestModel", "TestConfig")
        mock_tokenizer.return_value = Mock()
        mock_model.return_value = Mock(spec=torch.nn.Module)

        model_config = ModelConfig(
            model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            modelopt_quant="fp8",
            modelopt_export_path=self.export_dir,
        )

        load_config = LoadConfig()
        device_config = DeviceConfig()

        # Mock the quantization and export process
        with patch.object(
            ModelOptModelLoader, "_setup_modelopt_quantization"
        ) as mock_setup:
            with patch.object(
                ModelOptModelLoader, "_load_modelopt_base_model"
            ) as mock_load_base:
                mock_load_base.return_value = mock_model.return_value

                # Act
                model_loader = ModelOptModelLoader(load_config)
                result = model_loader.load_model(
                    model_config=model_config,
                    device_config=device_config,
                )

                # Assert
                self.assertIsNotNone(result)
                mock_setup.assert_called_once()
                # Verify export_path was passed to setup
                args, kwargs = mock_setup.call_args
                self.assertEqual(kwargs.get("export_path"), self.export_dir)


if __name__ == "__main__":
    unittest.main()
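
The NotImplementedError asserted in test_quantize_and_serve_config_validation points users to a separate quantize-then-deploy workflow. A minimal sketch of that flow, using only the ModelOpt entry points these tests patch (mtq.quantize, mto.save, and export_hf_checkpoint); the model object, the calibration loop body, and the paths are placeholders:

    import modelopt.torch.opt as mto
    import modelopt.torch.quantization as mtq
    from modelopt.torch.export import export_hf_checkpoint

    def calibrate_loop(model):
        # Placeholder: run a few calibration batches through the model here.
        pass

    model = ...  # a loaded Hugging Face model (placeholder)
    mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)
    mto.save(model, "/path/to/save/checkpoint")  # optional, mirrors the save test
    export_hf_checkpoint(model, export_dir="/path/to/exported_model")

The exported directory can then be deployed as a regular quantized checkpoint.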

test/srt/test_modelopt_loader.py
@@ -12,8 +12,12 @@ from unittest.mock import MagicMock, patch

import torch.nn as nn

-# Add the sglang path for testing
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
+# Note: PYTHONPATH=python should be set when running tests
+
+# Constants for calibration parameters to avoid hard-coded values
+CALIBRATION_BATCH_SIZE = 36
+CALIBRATION_NUM_SAMPLES = 512
+DEFAULT_DEVICE = "cuda:0"

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig

@@ -28,18 +37,63 @@ class TestModelOptModelLoader(CustomTestCase):

    def setUp(self):
        """Set up test fixtures."""
        # Mock distributed functionality to avoid initialization errors
        self.mock_tp_rank = patch(
            "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
            return_value=0,
        )
        self.mock_tp_rank.start()

        self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
        self.mock_rank0_log.start()

        # Mock logger to avoid issues
        self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
        self.mock_logger.start()

        # Mock all distributed functions that might be called
        self.mock_get_tp_group = patch(
            "sglang.srt.distributed.parallel_state.get_tp_group"
        )
        self.mock_get_tp_group.start()

        # Mock model parallel initialization check
        self.mock_mp_is_initialized = patch(
            "sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
            return_value=True,
        )
        self.mock_mp_is_initialized.start()

        self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.load_config = LoadConfig()
        self.device_config = DeviceConfig(device="cuda")

-       # Create a basic model config with modelopt_quant
+       # Create a basic model config with unified quantization flag
        self.model_config = ModelConfig(
-           model_path=self.model_path, modelopt_quant="fp8"
+           model_path=self.model_path,
+           quantization="modelopt_fp8",  # Use unified quantization approach
        )

        # Also create a unified quantization config for new tests
        self.unified_model_config = ModelConfig(
            model_path=self.model_path, quantization="modelopt_fp8"
        )

        # Mock base model
        self.mock_base_model = MagicMock(spec=nn.Module)
        self.mock_base_model.eval.return_value = self.mock_base_model
        self.mock_base_model.device = (
            DEFAULT_DEVICE  # Add device attribute for calibration tests
        )

    def tearDown(self):
        """Clean up test fixtures."""
        # Stop mocks
        self.mock_tp_rank.stop()
        self.mock_rank0_log.stop()
        self.mock_logger.stop()
        self.mock_get_tp_group.stop()
        self.mock_mp_is_initialized.stop()

    @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
    @patch("sglang.srt.model_loader.loader.logger")

@@ -66,7 +120,7 @@ class TestModelOptModelLoader(CustomTestCase):
        model = self.mock_base_model

        # Simulate the quantization config lookup
-       quant_choice_str = model_config.modelopt_quant
+       quant_choice_str = model_config._get_modelopt_quant_type()
        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)

        if not quant_cfg_name:

@@ -123,6 +177,305 @@ class TestModelOptModelLoader(CustomTestCase):
        # Verify we get back the expected model
        self.assertEqual(result_model, self.mock_base_model)

    @patch("sglang.srt.model_loader.loader.logger")
    def test_missing_modelopt_import(self, mock_logger):
        """Test error handling when the modelopt library is not available."""
        loader = ModelOptModelLoader(self.load_config)

        # Mock the base model loader method
        with patch.object(
            loader, "_load_modelopt_base_model", return_value=self.mock_base_model
        ):
            # Simulate missing modelopt by making import fail
            original_import = __import__

            def mock_import(name, *args, **kwargs):
                if name.startswith("modelopt"):
                    raise ImportError("No module named 'modelopt'")
                # Return default import behavior for other modules
                return original_import(name, *args, **kwargs)

            with patch("builtins.__import__", side_effect=mock_import):
                # Expect ImportError to be raised and logged
                with self.assertRaises(ImportError):
                    loader.load_model(
                        model_config=self.model_config, device_config=self.device_config
                    )

                # Verify error logging
                mock_logger.error.assert_called_with(
                    "NVIDIA Model Optimizer (modelopt) library not found. "
                    "Please install it to use ModelOpt quantization."
                )

    @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
    @patch("sglang.srt.model_loader.loader.AutoTokenizer")
    @patch("sglang.srt.model_loader.loader.logger")
    def test_calibration_workflow_integration(self, mock_logger, mock_auto_tokenizer):
        """Test end-to-end calibration workflow integration."""
        loader = ModelOptModelLoader(self.load_config)

        # Mock tokenizer
        mock_tokenizer = MagicMock()
        mock_tokenizer.padding_side = "right"
        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

        # Mock modelopt modules
        mock_mtq = MagicMock()
        mock_mto = MagicMock()
        mock_dataset_utils = MagicMock()

        # Configure quantization config
        mock_fp8_cfg = MagicMock()
        mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg

        # Configure dataset utilities
        mock_calib_dataloader = MagicMock()
        mock_calibrate_loop = MagicMock()
        mock_dataset_utils.get_dataset_dataloader.return_value = mock_calib_dataloader
        mock_dataset_utils.create_forward_loop.return_value = mock_calibrate_loop

        # Configure model as not quantized initially
        mock_is_quantized = MagicMock(return_value=False)

        with patch.object(
            loader, "_load_modelopt_base_model", return_value=self.mock_base_model
        ):
            with patch.dict(
                "sys.modules",
                {
                    "modelopt": MagicMock(),
                    "modelopt.torch": MagicMock(),
                    "modelopt.torch.opt": mock_mto,
                    "modelopt.torch.quantization": mock_mtq,
                    "modelopt.torch.quantization.utils": MagicMock(
                        is_quantized=mock_is_quantized
                    ),
                    "modelopt.torch.utils": MagicMock(),
                    "modelopt.torch.utils.dataset_utils": mock_dataset_utils,
                },
            ):
                # Execute the load_model method to test the full workflow
                result_model = loader.load_model(
                    model_config=self.model_config, device_config=self.device_config
                )

                # Verify the model loading was successful
                self.assertEqual(result_model, self.mock_base_model)

                # Verify key calibration components were used
                # Note: We can't easily verify the exact calls due to dynamic imports,
                # but we can verify the workflow completed successfully

    @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
    @patch("sglang.srt.model_loader.loader.AutoTokenizer")
    @patch("sglang.srt.model_loader.loader.logger")
    def test_quantized_checkpoint_restore(self, mock_logger, mock_auto_tokenizer):
        """Test restoring from a quantized checkpoint."""
        # Create model config with checkpoint restore path
        config_with_restore = ModelConfig(
            model_path=self.model_path,
            quantization="modelopt_fp8",
        )

        # Create load config with checkpoint restore path
        load_config_with_restore = LoadConfig(
            modelopt_checkpoint_restore_path="/path/to/quantized/checkpoint"
        )

        loader = ModelOptModelLoader(load_config_with_restore)

        # Mock tokenizer
        mock_tokenizer = MagicMock()
        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

        # Mock modelopt modules
        mock_mtq = MagicMock()
        mock_mto = MagicMock()

        # Configure quantization config
        mock_fp8_cfg = MagicMock()
        mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg

        # Configure model as not quantized initially
        mock_is_quantized = MagicMock(return_value=False)

        with patch.object(
            loader, "_load_modelopt_base_model", return_value=self.mock_base_model
        ):
            with patch.dict(
                "sys.modules",
                {
                    "modelopt": MagicMock(),
                    "modelopt.torch": MagicMock(),
                    "modelopt.torch.opt": mock_mto,
                    "modelopt.torch.quantization": mock_mtq,
                    "modelopt.torch.quantization.utils": MagicMock(
                        is_quantized=mock_is_quantized
                    ),
                },
            ):
                with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
                    # Mock the _setup_modelopt_quantization to simulate checkpoint restore
                    def mock_setup_quantization(
                        model,
                        tokenizer,
                        quant_cfg,
                        quantized_ckpt_restore_path=None,
                        **kwargs,
                    ):
                        if quantized_ckpt_restore_path:
                            mock_mto.restore(model, quantized_ckpt_restore_path)
                            print(
                                f"Restored quantized model from {quantized_ckpt_restore_path}"
                            )
                            return

                    mock_setup.side_effect = mock_setup_quantization

                    # Execute the load_model method
                    result_model = loader.load_model(
                        model_config=config_with_restore,
                        device_config=self.device_config,
                    )

                    # Verify the setup was called with restore path
                    mock_setup.assert_called_once()
                    call_args = mock_setup.call_args
                    # Check that the restore path was passed correctly
                    self.assertIn("quantized_ckpt_restore_path", call_args[1])
                    self.assertEqual(
                        call_args[1]["quantized_ckpt_restore_path"],
                        "/path/to/quantized/checkpoint",
                    )

                    # Verify restore was called
                    mock_mto.restore.assert_called_once_with(
                        self.mock_base_model, "/path/to/quantized/checkpoint"
                    )

                    # Verify we get the expected model back
                    self.assertEqual(result_model, self.mock_base_model)

    @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
    @patch("sglang.srt.model_loader.loader.AutoTokenizer")
    @patch("sglang.srt.model_loader.loader.logger")
    def test_quantized_checkpoint_save(self, mock_logger, mock_auto_tokenizer):
        """Test saving a quantized checkpoint after calibration."""
        # Create model config with checkpoint save path
        config_with_save = ModelConfig(
            model_path=self.model_path,
            quantization="modelopt_fp8",
        )

        # Create load config with checkpoint save path
        load_config_with_save = LoadConfig(
            modelopt_checkpoint_save_path="/path/to/save/checkpoint"
        )

        loader = ModelOptModelLoader(load_config_with_save)

        # Mock tokenizer
        mock_tokenizer = MagicMock()
        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

        # Mock modelopt modules
        mock_mtq = MagicMock()
        mock_mto = MagicMock()
        mock_dataset_utils = MagicMock()

        # Configure quantization config
        mock_fp8_cfg = MagicMock()
        mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg

        # Configure model as not quantized initially
        mock_is_quantized = MagicMock(return_value=False)

        with patch.object(
            loader, "_load_modelopt_base_model", return_value=self.mock_base_model
        ):
            with patch.dict(
                "sys.modules",
                {
                    "modelopt": MagicMock(),
                    "modelopt.torch": MagicMock(),
                    "modelopt.torch.opt": mock_mto,
                    "modelopt.torch.quantization": mock_mtq,
                    "modelopt.torch.quantization.utils": MagicMock(
                        is_quantized=mock_is_quantized
                    ),
                    "modelopt.torch.utils": MagicMock(),
                    "modelopt.torch.utils.dataset_utils": mock_dataset_utils,
                },
            ):
                with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
                    # Mock the _setup_modelopt_quantization to simulate checkpoint save
                    def mock_setup_quantization(
                        model,
                        tokenizer,
                        quant_cfg,
                        quantized_ckpt_save_path=None,
                        **kwargs,
                    ):
                        # Simulate calibration and quantization
                        mock_mtq.quantize(model, quant_cfg, forward_loop=MagicMock())
                        mock_mtq.print_quant_summary(model)

                        # Save checkpoint if path provided
                        if quantized_ckpt_save_path:
                            mock_mto.save(model, quantized_ckpt_save_path)
                            print(
                                f"Quantized model saved to {quantized_ckpt_save_path}"
                            )

                    mock_setup.side_effect = mock_setup_quantization

                    # Execute the load_model method
                    result_model = loader.load_model(
                        model_config=config_with_save, device_config=self.device_config
                    )

                    # Verify the setup was called with save path
                    mock_setup.assert_called_once()
                    call_args = mock_setup.call_args
                    # Check that the save path was passed correctly
                    self.assertIn("quantized_ckpt_save_path", call_args[1])
                    self.assertEqual(
                        call_args[1]["quantized_ckpt_save_path"],
                        "/path/to/save/checkpoint",
                    )

                    # Verify save was called
                    mock_mto.save.assert_called_once_with(
                        self.mock_base_model, "/path/to/save/checkpoint"
                    )

                    # Verify we get the expected model back
                    self.assertEqual(result_model, self.mock_base_model)

    def test_unified_quantization_flag_support(self):
        """Test that ModelOptModelLoader supports unified quantization flags."""
        # Test modelopt_fp8
        config_fp8 = ModelConfig(
            model_path=self.model_path, quantization="modelopt_fp8"
        )
        self.assertEqual(config_fp8._get_modelopt_quant_type(), "fp8")

        # Test modelopt_fp4
        config_fp4 = ModelConfig(
            model_path=self.model_path, quantization="modelopt_fp4"
        )
        self.assertEqual(config_fp4._get_modelopt_quant_type(), "nvfp4")

        # Test auto-detection
        config_auto = ModelConfig(model_path=self.model_path, quantization="modelopt")
        # Should default to fp8 when no config is detected
        self.assertEqual(config_auto._get_modelopt_quant_type(), "fp8")


class TestModelOptLoaderIntegration(CustomTestCase):
    """Integration tests for ModelOptModelLoader with the Engine API."""
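
A note on running these suites locally: both test files follow the PYTHONPATH=python convention noted in the diffs above (the old sys.path hack was removed), so from the repository root they can be invoked with, for example, `PYTHONPATH=python python3 test/srt/test_modelopt_loader.py`, and likewise for `test_modelopt_export.py`.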