[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)

Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Qubitium-ModelCloud
2025-03-05 17:11:00 +08:00
committed by GitHub
parent 583d6af71b
commit 56a724eba3
56 changed files with 1988 additions and 282 deletions
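The per-file changes below are representative of the whole commit: each model threads a prefix string through its submodules so that GPTQModel-style dynamic (per-module) quantization rules and lm_head quantization can be resolved against fully qualified module names. As a rough, hedged illustration of the feature this enables on the quantization side (field names follow GPTQModel's documented QuantizeConfig and should be verified against the installed version; the regex rules here are invented for the example):

from gptqmodel import QuantizeConfig

# Sketch only: "dynamic" maps regex patterns over fully qualified module names
# (e.g. "model.layers.0.mlp.gate_up_proj", "lm_head") to per-module overrides.
quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    lm_head=True,  # also quantize the output head (assumed flag name)
    dynamic={
        # assumed rule format: give later layers 8-bit weights
        r".*layers\.(2[0-9]|3[0-1])\..*": {"bits": 8, "group_size": 64},
    },
)

On the serving side, SGLang then needs each quantized linear layer's fully qualified name in order to look up which rule, if any, applies; that is what the prefix plumbing in diffs like the one below provides.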


@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class XverseMLP(nn.Module):
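The newly imported add_prefix helper is what builds those qualified names. A minimal sketch of its assumed behavior (the real implementation in sglang.srt.utils may differ):

def add_prefix(name: str, prefix: str) -> str:
    # Assumed behavior: return the bare name when there is no parent prefix,
    # otherwise join with a dot, e.g.
    # add_prefix("gate_up_proj", "model.layers.0.mlp") -> "model.layers.0.mlp.gate_up_proj"
    return name if not prefix else f"{prefix}.{name}"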
@@ -57,14 +58,14 @@ class XverseMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.down_proj",
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -128,14 +129,14 @@ class XverseAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
@@ -152,6 +153,7 @@
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -202,14 +204,14 @@ class XverseDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = XverseMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -246,6 +248,7 @@ class XverseModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -254,11 +257,15 @@ class XverseModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
                 XverseDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -295,12 +302,17 @@ class XverseForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = XverseModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.model = XverseModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
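With XverseForCausalLM passing "model" down and the lm_head registered under its own prefix, every quantizable module ends up with a stable fully qualified name ("model.layers.3.self_attn.qkv_proj", "lm_head", ...). A hedged sketch of how a dynamic-override lookup keyed on those prefixes could work (the rule format is an assumption for illustration, not this commit's exact code):

import re
from typing import Optional

def resolve_override(prefix: str, dynamic: dict) -> Optional[dict]:
    # Return the first override whose regex matches the module's qualified name.
    for pattern, override in dynamic.items():
        if re.search(pattern, prefix):
            return override
    return None

rules = {r"^lm_head$": {"bits": 8}, r"layers\.(0|1)\.": {"bits": 8}}
print(resolve_override("lm_head", rules))                       # -> {'bits': 8}
print(resolve_override("model.layers.1.mlp.down_proj", rules))  # -> {'bits': 8}
print(resolve_override("model.layers.5.mlp.down_proj", rules))  # -> None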