From 2a5f0100e0bfd4d33de25020590e46d4c1e95eb5 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 10 Jun 2025 21:07:20 -0700 Subject: [PATCH] Fix GGuf and add back test_gguf.py (#7067) --- python/sglang/srt/layers/linear.py | 4 ---- python/sglang/srt/model_loader/loader.py | 9 ++++++++- test/srt/run_suite.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 212105621..f664d6bb6 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -546,8 +546,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) - if len(param.data_container) == 2: - self.qweight = param.materialize_nested() return param_data = param.data @@ -961,8 +959,6 @@ class QKVParallelLinear(ColumnParallelLinear): param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) - if len(param.data_container) == 3: - self.qweight = param.materialize_nested() return param_data = param.data diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index e2c6a3767..6ba31b515 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -1259,12 +1259,19 @@ class GGUFModelLoader(BaseModelLoader): ): model_config.hf_config.update({"tie_word_embeddings": True}) + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with target_device: model = _initialize_model(model_config, self.load_config) model.load_weights( self._get_weights_iterator(local_model_path, gguf_weights_map) ) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) return model diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b1883e1a9..83e1b92c1 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -186,7 +186,7 @@ suites = { "vllm_dependency_test": [ TestFile("test_awq.py"), TestFile("test_bnb.py"), - # TestFile("test_gguf.py", 78), # TODO: Fix GGuf after updating to torch 2.7 and vllm 0.9 + TestFile("test_gguf.py", 78), TestFile("test_gptqmodel_dynamic.py", 72), TestFile("test_vllm_dependency.py"), ],