[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
commit 56a724eba3
parent 583d6af71b
@@ -40,7 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda_available
 
 if is_cuda_available():
     from sgl_kernel import bmm_fp8
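This change threads a `prefix` argument through every module of sglang's MiniCPM3 model so that each quantizable linear is registered under its fully qualified name (e.g. `model.layers.0.mlp.gate_up_proj`). GPTQModel's dynamic per-module quantization and lm_head quantization both key off these names. The `add_prefix` helper imported above is not shown in this diff; below is a minimal sketch of what such a helper does — assumed behavior only, the real implementation in `sglang.srt.utils` may differ:

```python
# Minimal sketch of an add_prefix-style helper (assumed behavior; the real
# function lives in sglang.srt.utils and is not shown in this diff).
def add_prefix(name: str, prefix: str) -> str:
    """Join a child module name onto a dotted parent prefix."""
    return name if not prefix else f"{prefix}.{name}"

assert add_prefix("gate_up_proj", "model.layers.0.mlp") == "model.layers.0.mlp.gate_up_proj"
assert add_prefix("model", "") == "model"  # empty prefix: name passes through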
@@ -53,6 +53,7 @@ class MiniCPM3MLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -60,12 +61,14 @@ class MiniCPM3MLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
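With `prefix` now reaching both MLP projections, a dynamic-quantization backend can decide per module whether (and how) to quantize. Here is a hedged sketch of the kind of regex lookup GPTQModel's `dynamic` option performs; the rule syntax (`-:` to exclude a module) follows GPTQModel's documentation as I understand it and should be verified against the library:

```python
import re
from typing import Optional

# Illustrative dynamic-rule table: regex keys match fully qualified module
# names; "-:" marks modules to exclude from quantization (assumed syntax).
DYNAMIC_RULES = {
    r".*\.mlp\.down_proj": {"bits": 8, "group_size": 64},
    r"-:.*\.kv_a_proj_with_mqa": {},
}

def resolve_overrides(module_name: str) -> Optional[dict]:
    """Return per-module overrides, {} for the global default, None to skip."""
    for raw_pattern, overrides in DYNAMIC_RULES.items():
        exclude = raw_pattern.startswith("-:")
        pattern = raw_pattern[2:] if exclude else raw_pattern
        if re.fullmatch(pattern, module_name):
            return None if exclude else overrides
    return {}

print(resolve_overrides("model.layers.0.mlp.down_proj"))                 # {'bits': 8, 'group_size': 64}
print(resolve_overrides("model.layers.0.self_attn.kv_a_proj_with_mqa"))  # None (skipped)
```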
@@ -107,6 +110,7 @@ class MiniCPM3Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -131,6 +135,7 @@ class MiniCPM3Attention(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -138,6 +143,7 @@ class MiniCPM3Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -145,6 +151,7 @@ class MiniCPM3Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -152,6 +159,7 @@ class MiniCPM3Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -159,6 +167,7 @@ class MiniCPM3Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -166,6 +175,7 @@ class MiniCPM3Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -182,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -250,6 +261,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -274,6 +286,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -281,6 +294,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -288,6 +302,7 @@ class MiniCPM3AttentionMLA(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -295,6 +310,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -302,6 +318,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -309,6 +326,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -325,6 +343,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            prefix=add_prefix("attn", prefix),
         )
 
         self.w_kc = None
@@ -405,6 +424,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -429,6 +449,7 @@ class MiniCPM3DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         else:
             self.self_attn = MiniCPM3Attention(
@@ -447,12 +468,14 @@ class MiniCPM3DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         self.mlp = MiniCPM3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
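Because each constructor passes its own prefix down to its children, names compose hierarchically as construction descends the module tree. A short illustration, reusing the `add_prefix` sketch from earlier:

```python
# Prefix composition through the module hierarchy (illustrative; relies on
# the add_prefix sketch above, not the actual sglang implementation).
model_prefix = add_prefix("model", "")                # "model"
layer_prefix = add_prefix("layers.0", model_prefix)   # "model.layers.0"
attn_prefix = add_prefix("self_attn", layer_prefix)   # "model.layers.0.self_attn"
print(add_prefix("q_b_proj", attn_prefix))            # model.layers.0.self_attn.q_b_proj
```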
@@ -494,6 +517,7 @@ class MiniCPM3Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -503,10 +527,16 @@ class MiniCPM3Model(nn.Module):
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
+                MiniCPM3DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
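The f-string prefix `layers.{i}` gives every decoder layer a distinct name, which is what makes layer-granular quantization rules possible. A hedged example of a GPTQModel-style `dynamic` mapping that exploits this; treat the exact key syntax (`+:` to override settings, `-:` to exclude) as an assumption to be checked against GPTQModel's docs:

```python
# Example dynamic config varying quantization by layer index (assumed syntax,
# illustrative of GPTQModel's regex-keyed "dynamic" option).
dynamic = {
    # first two decoder layers are quantization-sensitive: keep them at 8 bits
    r"+:model\.layers\.[01]\..*": {"bits": 8, "group_size": 64},
    # never quantize the small latent-attention down-projections
    r"-:.*\.kv_a_proj_with_mqa": {},
}
```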
@@ -542,19 +572,23 @@ class MiniCPM3ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
 
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(config, quant_config=quant_config)
+        self.model = MiniCPM3Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
             )
 
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
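Passing a prefix to `ParallelLMHead` is the piece that enables lm_head quantization: the head's weights can now be matched by name just like any decoder linear. A hedged sketch of producing such a checkpoint with GPTQModel; the `lm_head` flag and the load/quantize/save flow follow GPTQModel's public README as I understand it, and exact signatures may have drifted:

```python
# Hedged sketch: quantize a model, including its lm_head, with GPTQModel.
from gptqmodel import GPTQModel, QuantizeConfig

quant_config = QuantizeConfig(bits=4, group_size=128, lm_head=True)
model = GPTQModel.load("openbmb/MiniCPM3-4B", quant_config)

# A real run needs a representative calibration set, not a single sentence.
calibration_dataset = ["sglang is a fast serving framework for LLMs."]
model.quantize(calibration_dataset)
model.save("MiniCPM3-4B-gptq-4bit")
```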