diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index fb569890e..dabb515ed 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -19,10 +19,9 @@ class RadixAttention(nn.Module):
         head_dim,
         scaling,
         num_kv_heads,
-        layer_id,
+        layer_id
     ):
         super().__init__()
-        self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 7d72c6c70..93e99fe23 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -12,10 +12,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
 from sglang.utils import get_available_gpu_memory
 from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
 import sglang
 
+QUANTIZATION_CONFIG_MAPPING = {'awq': AWQConfig,
+                               'gptq': GPTQConfig}
 logger = logging.getLogger("model_runner")
 
@@ -280,8 +283,10 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-            # TODO: config quantization awq etc
-            quant_config = AWQConfig.from_config(hf_quant_config)
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(hf_quant_config['quant_method'])
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
+            quant_config = quant_config_class.from_config(hf_quant_config)
             logger.info(f"quant_config: {quant_config}")
             linear_method = quant_config.get_linear_method()
             model = model_class(
diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py
index c651ea908..3089c7ac0 100644
--- a/python/sglang/srt/models/qwen.py
+++ b/python/sglang/srt/models/qwen.py
@@ -34,6 +34,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -41,12 +42,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
+            linear_method=linear_method
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
+            linear_method=linear_method
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -71,6 +74,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        linear_method: Optional[LinearMethodBase] = None
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -82,13 +86,18 @@ class QWenAttention(nn.Module):
 
         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
-            hidden_size, self.head_dim, self.total_num_heads, bias=True
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
+            linear_method=linear_method
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -121,7 +130,7 @@ class QWenBlock(nn.Module):
-    def __init__(self, config: QWenConfig, layer_id):
+    def __init__(self, config: QWenConfig, layer_id, linear_method=None):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -134,11 +143,12 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
+            linear_method=linear_method
         )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
 
     def forward(
         self,
@@ -165,7 +175,7 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(self, config: QWenConfig):
+    def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -176,7 +186,7 @@
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
+            [QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
         )
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -202,7 +212,7 @@ class QWenLMHeadModel(nn.Module):
     def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config)
+        self.transformer = QWenModel(config, linear_method=linear_method)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -219,9 +229,6 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens
 
-    _column_parallel_weights = []
-    _row_parallel_weights = ["c_proj.weight"]
-
     def load_weights(
         self,
         model_name_or_path: str,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 54274f366..42c534abe 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -259,4 +259,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
+    return image
\ No newline at end of file
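
Note: a minimal, self-contained sketch of the quant-config dispatch added in model_runner.py above. The stub classes below are placeholders standing in for vllm's AWQConfig/GPTQConfig (both of which expose a from_config() classmethod); only the mapping/lookup logic mirrors the diff, and the example quantization_config dict is illustrative, not taken from a real model.

# Stubs standing in for vllm's real quantization config classes.
class StubAWQConfig:
    @classmethod
    def from_config(cls, config: dict) -> "StubAWQConfig":
        return cls()

class StubGPTQConfig:
    @classmethod
    def from_config(cls, config: dict) -> "StubGPTQConfig":
        return cls()

# Same shape as the mapping introduced in model_runner.py.
QUANTIZATION_CONFIG_MAPPING = {"awq": StubAWQConfig, "gptq": StubGPTQConfig}

def resolve_quant_config(hf_quant_config: dict):
    # hf_quant_config is the `quantization_config` dict from a model's
    # HuggingFace config; `quant_method` names the algorithm, e.g. "gptq".
    method = hf_quant_config["quant_method"]
    quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(method)
    if quant_config_class is None:
        raise ValueError(f"Unsupported quantization method: {method}")
    return quant_config_class.from_config(hf_quant_config)

# Example: a typical GPTQ quantization_config block (field values illustrative).
print(resolve_quant_config({"quant_method": "gptq", "bits": 4, "group_size": 128}))

Compared with the old hard-coded AWQConfig.from_config call, the mapping makes supporting a new quantization method a one-entry change and turns an unrecognized quant_method into an explicit ValueError instead of silently constructing the wrong config.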