[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
This commit is contained in:
committed by
GitHub
parent
583d6af71b
commit
56a724eba3
@@ -15,7 +15,7 @@
|
||||
# Adapted from llama2.py
|
||||
# Modify details for the adaptation of Qwen2 model.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
|
||||
from readline import add_history
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
kv_cache_scales_loader,
|
||||
)
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
Qwen2Config = None
|
||||
|
||||
@@ -58,6 +58,7 @@ class Qwen2MLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -65,12 +66,14 @@ class Qwen2MLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -97,6 +100,7 @@ class Qwen2Attention(nn.Module):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 32768,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -128,12 +132,14 @@ class Qwen2Attention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -149,6 +155,7 @@ class Qwen2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -171,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
config: Qwen2Config,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -186,12 +194,14 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -228,6 +238,7 @@ class Qwen2Model(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -237,6 +248,7 @@ class Qwen2Model(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -244,7 +256,9 @@ class Qwen2Model(nn.Module):
|
||||
layer_id=idx,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@@ -325,16 +339,22 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||
self.model = Qwen2Model(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
Reference in New Issue
Block a user