From 30b404ce72b52e02076fa46ff5ee16f3e1a68a98 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 13 Sep 2024 23:46:55 -0700
Subject: [PATCH] Add torchao quant for mixtral and qwen_moe (#1418)

---
 python/sglang/srt/layers/torchao_utils.py | 39 ++++++++++++++++++++++-
 python/sglang/srt/models/llama.py         | 21 ++----------
 python/sglang/srt/models/mixtral.py       |  5 +++
 python/sglang/srt/models/qwen2_moe.py     |  5 +++
 4 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index bc7bde86e..51a307e5c 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -2,10 +2,20 @@
 Common utilities for torchao.
 """
 
+from typing import Dict, Set
+
 import torch
 
 
-def torchao_quantize_param_data(param, torchao_config):
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with the torchao quantization specified by torchao_config.
+
+    Args:
+        `param`: weight parameter of the linear module
+        `torchao_config`: the type of quantization and its arguments used to
+            quantize the Tensor, e.g. int4wo-128 means int4 weight-only
+            quantization with group_size 128
+    """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         int4_weight_only,
@@ -36,3 +46,30 @@ def torchao_quantize_param_data(param, torchao_config):
     #     [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
     quantize_(dummy_linear, float8_weight_only())
     return dummy_linear.weight
+
+
+def apply_torchao_config_(
+    self: torch.nn.Module,
+    params_dict: Dict[str, torch.Tensor],
+    param_suffixes: Set[str],
+) -> None:
+    """Quantize the weight parameters of `self` after they are loaded, if
+    self.torchao_config is specified.
+
+    Args:
+        `self`: the model we want to quantize
+        `params_dict`: dictionary mapping from param_name to the parameter Tensor
+        `param_suffixes`: a set of suffixes; Tensors whose names match one of them are quantized
+
+    Returns:
+        None; `params_dict` is modified in place and the weights of the `self` model are quantized
+    """
+    if self.torchao_config:
+        for param_suffix in param_suffixes:
+            for name in params_dict:
+                param = params_dict[name]
+                if param_suffix in name and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+        self.load_state_dict(params_dict, assign=True)
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index b7842f192..57c88f226 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -41,7 +41,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -405,24 +405,7 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-                if self.torchao_config:
-                    if name.endswith("proj.weight") and param.ndim == 2:
-                        params_dict[name] = torchao_quantize_param_data(
-                            param, self.torchao_config
-                        )
-
-        if self.torchao_config:
-            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
-            stacked_params = set(entry[0] for entry in stacked_params_mapping)
-            for param_suffix in stacked_params:
-                for name in params_dict:
-                    if param_suffix in name:
-                        param = params_dict[name]
-                        params_dict[name] = torchao_quantize_param_data(
-                            param, self.torchao_config
-                        )
-
-            self.load_state_dict(params_dict, assign=True)
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
 
 class Phi3ForCausalLM(LlamaForCausalLM):
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index 87e3bb030..d7c232ec1 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -296,6 +298,7 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -376,5 +379,7 @@ class MixtralForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
 
 EntryClass = MixtralForCausalLM
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 1ff2190ed..d47589ddc 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -47,6 +47,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -359,6 +361,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
@@ -451,5 +454,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
 
 EntryClass = Qwen2MoeForCausalLM
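
A minimal usage sketch, separate from the patch itself: a model that adopts the new
helper only needs a `torchao_config` attribute and a single call at the end of
`load_weights`. `ToyModel` below is hypothetical; real sglang models read the config
string from `global_server_args_dict["torchao_config"]` as the diffs above do, and a
string such as "int4wo-128" selects int4 weight-only quantization with group_size 128.

import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Hard-coded here for illustration; real models read this from
        # global_server_args_dict["torchao_config"]. An empty string
        # leaves the weights unquantized.
        self.torchao_config = "int4wo-128"
        self.q_proj = torch.nn.Linear(64, 64, bias=False)

    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        # ... per-weight loading into params_dict would happen here ...
        # Quantize every 2-D parameter whose name contains "proj.weight";
        # the helper then re-binds the quantized tensors onto the module
        # via load_state_dict(assign=True).
        apply_torchao_config_(self, params_dict, set(["proj.weight"]))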