From 83123f481ebc3bcbdff8b72a5ca293f2ee9a707a Mon Sep 17 00:00:00 2001 From: ichernob Date: Tue, 12 Aug 2025 23:31:18 +0300 Subject: [PATCH] [Quantization] Supported w8a8 int8 quantized Gemma3 and Qwen-VL models (#8619) Co-authored-by: ronnie_zheng --- .../srt/layers/quantization/w8a8_int8.py | 12 +- python/sglang/srt/model_loader/loader.py | 24 +++- test/srt/test_ascend_w8a8_quantization.py | 104 ++++++++++++++++++ 3 files changed, 131 insertions(+), 9 deletions(-) create mode 100644 test/srt/test_ascend_w8a8_quantization.py diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 4e33d4be8..843fffe7b 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -255,17 +255,23 @@ class W8A8Int8Config(QuantizationConfig): if _is_npu: if isinstance(layer, LinearBase): + key = "model" + if "vision_model" in prefix: + key = "vision_model" + elif "visual" in prefix: + key = "visual" + packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {}) prefix_in_quant_config = prefix proj_name = prefix.split(".")[-1] - if proj_name in self.packed_modules_mapping: + if proj_name in packed_modules_mapping_subset: prefix_in_quant_config = prefix.replace( - proj_name, self.packed_modules_mapping[proj_name][0] + proj_name, packed_modules_mapping_subset[proj_name][0] ) self.is_dynamic = ( self.quant_description[prefix_in_quant_config + ".weight"] == "W8A8_DYNAMIC" ) - if self.is_layer_skipped(prefix, self.packed_modules_mapping): + if self.is_layer_skipped(prefix, packed_modules_mapping_subset): return UnquantizedLinearMethod() return ( NPU_W8A8DynamicLinearMethod(self) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 2e2f71078..95d41a050 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -162,12 +162,24 @@ def _initialize_model( model_class, _ = get_model_architecture(model_config) packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {}) if _is_npu: - packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [ - "q_a_proj", - "kv_a_proj_with_mqa", - ] - packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] - packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] + packed_modules_mapping.update( + { + "visual": {"qkv_proj": ["qkv"]}, + "vision_model": { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "proj": ["out_proj"], + }, + "model": { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "fused_qkv_a_proj_with_mqa": [ + "q_a_proj", + "kv_a_proj_with_mqa", + ], + }, + } + ) + quant_config = _get_quantization_config( model_config, load_config, packed_modules_mapping ) diff --git a/test/srt/test_ascend_w8a8_quantization.py b/test/srt/test_ascend_w8a8_quantization.py new file mode 100644 index 000000000..bf139f46a --- /dev/null +++ b/test/srt/test_ascend_w8a8_quantization.py @@ -0,0 +1,104 @@ +""" +Usage: +python3 -m unittest test_ascend_w8a8_quantization.TestAscendW8A8.test_gsm8k +""" + +import os +import time +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + +if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" +DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( + 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 +) +DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" + + +class TestAscendW8A8(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--disable-cuda-graph", + "--device", + "npu", + "--attention-backend", + "ascend", + "--quantization", + "w8a8_int8", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + base_url = DEFAULT_URL_FOR_TEST + url = urlparse(base_url) + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + metrics = run_eval(args) + print(metrics) + + self.assertGreaterEqual(metrics["accuracy"], 0.25) + self.assertGreaterEqual(metrics["output_throughput"], 1000) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.perf_counter() + res = self.run_decode(max_tokens) + tok = time.perf_counter() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + + if is_in_ci(): + self.assertGreaterEqual(throughput, 25) + + +if __name__ == "__main__": + unittest.main()