Enable native ModelOpt quantization support (3/3) (#10154)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
Zhiyu
2025-10-21 21:44:29 -07:00
committed by GitHub
parent 4b65ed42cc
commit 80b2b3207a
16 changed files with 1528 additions and 39 deletions

View File

@@ -12,8 +12,17 @@ from unittest.mock import MagicMock, patch
import torch.nn as nn
# Add the sglang path for testing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
# Note: PYTHONPATH=python should be set when running tests
# Constants for calibration parameters to avoid hard-coded values
CALIBRATION_BATCH_SIZE = 36
CALIBRATION_NUM_SAMPLES = 512
DEFAULT_DEVICE = "cuda:0"
# Constants for calibration parameters to avoid hard-coded values
CALIBRATION_BATCH_SIZE = 36
CALIBRATION_NUM_SAMPLES = 512
DEFAULT_DEVICE = "cuda:0"
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
@@ -28,18 +37,63 @@ class TestModelOptModelLoader(CustomTestCase):
def setUp(self):
"""Set up test fixtures."""
# Mock distributed functionality to avoid initialization errors
self.mock_tp_rank = patch(
"sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
return_value=0,
)
self.mock_tp_rank.start()
self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
self.mock_rank0_log.start()
# Mock logger to avoid issues
self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
self.mock_logger.start()
# Mock all distributed functions that might be called
self.mock_get_tp_group = patch(
"sglang.srt.distributed.parallel_state.get_tp_group"
)
self.mock_get_tp_group.start()
# Mock model parallel initialization check
self.mock_mp_is_initialized = patch(
"sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
return_value=True,
)
self.mock_mp_is_initialized.start()
self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self.load_config = LoadConfig()
self.device_config = DeviceConfig(device="cuda")
# Create a basic model config with modelopt_quant
# Create a basic model config with unified quantization flag
self.model_config = ModelConfig(
model_path=self.model_path, modelopt_quant="fp8"
model_path=self.model_path,
quantization="modelopt_fp8", # Use unified quantization approach
)
# Also create a unified quantization config for new tests
self.unified_model_config = ModelConfig(
model_path=self.model_path, quantization="modelopt_fp8"
)
# Mock base model
self.mock_base_model = MagicMock(spec=nn.Module)
self.mock_base_model.eval.return_value = self.mock_base_model
self.mock_base_model.device = (
DEFAULT_DEVICE # Add device attribute for calibration tests
)
def tearDown(self):
"""Clean up test fixtures."""
# Stop mocks
self.mock_tp_rank.stop()
self.mock_rank0_log.stop()
self.mock_logger.stop()
self.mock_get_tp_group.stop()
self.mock_mp_is_initialized.stop()
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.logger")
@@ -66,7 +120,7 @@ class TestModelOptModelLoader(CustomTestCase):
model = self.mock_base_model
# Simulate the quantization config lookup
quant_choice_str = model_config.modelopt_quant
quant_choice_str = model_config._get_modelopt_quant_type()
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
if not quant_cfg_name:
@@ -123,6 +177,305 @@ class TestModelOptModelLoader(CustomTestCase):
# Verify we get back the expected model
self.assertEqual(result_model, self.mock_base_model)
@patch("sglang.srt.model_loader.loader.logger")
def test_missing_modelopt_import(self, mock_logger):
"""Test error handling when modelopt library is not available."""
loader = ModelOptModelLoader(self.load_config)
# Mock the base model loader method
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
# Simulate missing modelopt by making import fail
original_import = __import__
def mock_import(name, *args, **kwargs):
if name.startswith("modelopt"):
raise ImportError("No module named 'modelopt'")
# Return default import behavior for other modules
return original_import(name, *args, **kwargs)
with patch("builtins.__import__", side_effect=mock_import):
# Expect ImportError to be raised and logged
with self.assertRaises(ImportError):
loader.load_model(
model_config=self.model_config, device_config=self.device_config
)
# Verify error logging
mock_logger.error.assert_called_with(
"NVIDIA Model Optimizer (modelopt) library not found. "
"Please install it to use ModelOpt quantization."
)
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_calibration_workflow_integration(self, mock_logger, mock_auto_tokenizer):
"""Test end-to-end calibration workflow integration."""
loader = ModelOptModelLoader(self.load_config)
# Mock tokenizer
mock_tokenizer = MagicMock()
mock_tokenizer.padding_side = "right"
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
# Mock modelopt modules
mock_mtq = MagicMock()
mock_mto = MagicMock()
mock_dataset_utils = MagicMock()
# Configure quantization config
mock_fp8_cfg = MagicMock()
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
# Configure dataset utilities
mock_calib_dataloader = MagicMock()
mock_calibrate_loop = MagicMock()
mock_dataset_utils.get_dataset_dataloader.return_value = mock_calib_dataloader
mock_dataset_utils.create_forward_loop.return_value = mock_calibrate_loop
# Configure model as not quantized initially
mock_is_quantized = MagicMock(return_value=False)
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
with patch.dict(
"sys.modules",
{
"modelopt": MagicMock(),
"modelopt.torch": MagicMock(),
"modelopt.torch.opt": mock_mto,
"modelopt.torch.quantization": mock_mtq,
"modelopt.torch.quantization.utils": MagicMock(
is_quantized=mock_is_quantized
),
"modelopt.torch.utils": MagicMock(),
"modelopt.torch.utils.dataset_utils": mock_dataset_utils,
},
):
# Execute the load_model method to test the full workflow
result_model = loader.load_model(
model_config=self.model_config, device_config=self.device_config
)
# Verify the model loading was successful
self.assertEqual(result_model, self.mock_base_model)
# Verify key calibration components were used
# Note: We can't easily verify the exact calls due to dynamic imports,
# but we can verify the workflow completed successfully
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_quantized_checkpoint_restore(self, mock_logger, mock_auto_tokenizer):
"""Test restoring from a quantized checkpoint."""
# Create model config with checkpoint restore path
config_with_restore = ModelConfig(
model_path=self.model_path,
quantization="modelopt_fp8",
)
# Create load config with checkpoint restore path
load_config_with_restore = LoadConfig(
modelopt_checkpoint_restore_path="/path/to/quantized/checkpoint"
)
loader = ModelOptModelLoader(load_config_with_restore)
# Mock tokenizer
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
# Mock modelopt modules
mock_mtq = MagicMock()
mock_mto = MagicMock()
# Configure quantization config
mock_fp8_cfg = MagicMock()
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
# Configure model as not quantized initially
mock_is_quantized = MagicMock(return_value=False)
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
with patch.dict(
"sys.modules",
{
"modelopt": MagicMock(),
"modelopt.torch": MagicMock(),
"modelopt.torch.opt": mock_mto,
"modelopt.torch.quantization": mock_mtq,
"modelopt.torch.quantization.utils": MagicMock(
is_quantized=mock_is_quantized
),
},
):
with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
# Mock the _setup_modelopt_quantization to simulate checkpoint restore
def mock_setup_quantization(
model,
tokenizer,
quant_cfg,
quantized_ckpt_restore_path=None,
**kwargs,
):
if quantized_ckpt_restore_path:
mock_mto.restore(model, quantized_ckpt_restore_path)
print(
f"Restored quantized model from {quantized_ckpt_restore_path}"
)
return
mock_setup.side_effect = mock_setup_quantization
# Execute the load_model method
result_model = loader.load_model(
model_config=config_with_restore,
device_config=self.device_config,
)
# Verify the setup was called with restore path
mock_setup.assert_called_once()
call_args = mock_setup.call_args
# Check that the restore path was passed correctly
self.assertIn("quantized_ckpt_restore_path", call_args[1])
self.assertEqual(
call_args[1]["quantized_ckpt_restore_path"],
"/path/to/quantized/checkpoint",
)
# Verify restore was called
mock_mto.restore.assert_called_once_with(
self.mock_base_model, "/path/to/quantized/checkpoint"
)
# Verify we get the expected model back
self.assertEqual(result_model, self.mock_base_model)
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_quantized_checkpoint_save(self, mock_logger, mock_auto_tokenizer):
"""Test saving quantized checkpoint after calibration."""
# Create model config with checkpoint save path
config_with_save = ModelConfig(
model_path=self.model_path,
quantization="modelopt_fp8",
)
# Create load config with checkpoint save path
load_config_with_save = LoadConfig(
modelopt_checkpoint_save_path="/path/to/save/checkpoint"
)
loader = ModelOptModelLoader(load_config_with_save)
# Mock tokenizer
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
# Mock modelopt modules
mock_mtq = MagicMock()
mock_mto = MagicMock()
mock_dataset_utils = MagicMock()
# Configure quantization config
mock_fp8_cfg = MagicMock()
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
# Configure model as not quantized initially
mock_is_quantized = MagicMock(return_value=False)
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
with patch.dict(
"sys.modules",
{
"modelopt": MagicMock(),
"modelopt.torch": MagicMock(),
"modelopt.torch.opt": mock_mto,
"modelopt.torch.quantization": mock_mtq,
"modelopt.torch.quantization.utils": MagicMock(
is_quantized=mock_is_quantized
),
"modelopt.torch.utils": MagicMock(),
"modelopt.torch.utils.dataset_utils": mock_dataset_utils,
},
):
with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
# Mock the _setup_modelopt_quantization to simulate checkpoint save
def mock_setup_quantization(
model,
tokenizer,
quant_cfg,
quantized_ckpt_save_path=None,
**kwargs,
):
# Simulate calibration and quantization
mock_mtq.quantize(model, quant_cfg, forward_loop=MagicMock())
mock_mtq.print_quant_summary(model)
# Save checkpoint if path provided
if quantized_ckpt_save_path:
mock_mto.save(model, quantized_ckpt_save_path)
print(
f"Quantized model saved to {quantized_ckpt_save_path}"
)
mock_setup.side_effect = mock_setup_quantization
# Execute the load_model method
result_model = loader.load_model(
model_config=config_with_save, device_config=self.device_config
)
# Verify the setup was called with save path
mock_setup.assert_called_once()
call_args = mock_setup.call_args
# Check that the save path was passed correctly
self.assertIn("quantized_ckpt_save_path", call_args[1])
self.assertEqual(
call_args[1]["quantized_ckpt_save_path"],
"/path/to/save/checkpoint",
)
# Verify save was called
mock_mto.save.assert_called_once_with(
self.mock_base_model, "/path/to/save/checkpoint"
)
# Verify we get the expected model back
self.assertEqual(result_model, self.mock_base_model)
def test_unified_quantization_flag_support(self):
"""Test that ModelOptModelLoader supports unified quantization flags."""
# Test modelopt_fp8
config_fp8 = ModelConfig(
model_path=self.model_path, quantization="modelopt_fp8"
)
self.assertEqual(config_fp8._get_modelopt_quant_type(), "fp8")
# Test modelopt_fp4
config_fp4 = ModelConfig(
model_path=self.model_path, quantization="modelopt_fp4"
)
self.assertEqual(config_fp4._get_modelopt_quant_type(), "nvfp4")
# Test auto-detection
config_auto = ModelConfig(model_path=self.model_path, quantization="modelopt")
# Should default to fp8 when no config is detected
self.assertEqual(config_auto._get_modelopt_quant_type(), "fp8")
class TestModelOptLoaderIntegration(CustomTestCase):
"""Integration tests for ModelOptModelLoader with Engine API."""