diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 1fdda4fad..910309da9 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -26,12 +26,11 @@ def apply_torchao_config_to_model(
         quantize_,
     )
     from torchao.quantization.observer import PerRow, PerTensor
-    from torchao.quantization.quant_api import _is_linear
 
     if filter_fn is None:
 
         def filter_fn(module, fqn):
-            return _is_linear(module) and "proj" in fqn
+            return "proj" in fqn
 
     if torchao_config == "" or torchao_config is None:
         return model
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a3f62f250..db024c5c7 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -157,10 +157,6 @@ class ModelRunner:
         self.sampler = Sampler()
 
         self.load_model()
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
-
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -169,6 +165,10 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()