[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)

Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Author: Qubitium-ModelCloud
Date: 2025-03-05 17:11:00 +08:00 (committed via GitHub)
Parent: 583d6af71b
Commit: 56a724eba3
56 changed files with 1988 additions and 282 deletions
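
This change threads a fully qualified module name (a prefix string) through every layer constructor in sglang's DeepseekV2 model, so that per-module ("dynamic") GPTQModel quantization rules and lm_head quantization can be resolved by weight name instead of by hard-coded special cases. For orientation, a quantize config on the GPTQModel side might look like the sketch below; the bits/group_size/lm_head/dynamic keys follow GPTQModel's documented QuantizeConfig, but treat the exact regex convention as an assumption rather than something this diff shows.

# Hypothetical GPTQModel-style quantize config (illustrative sketch only).
# "dynamic" maps module-name regexes to overrides; by GPTQModel convention a
# "+:" rule overrides settings for matching modules and a "-:" rule excludes
# them from quantization entirely.
quantize_config = {
    "bits": 4,
    "group_size": 128,
    "lm_head": True,  # quantize the output projection as well
    "dynamic": {
        r"+:.*\.experts\..*": {"bits": 8},   # wider bits for expert weights
        r"-:.*\.kv_a_proj_with_mqa$": {},    # skip this projection entirely
    },
}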

@@ -63,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
is_hip_ = is_hip()
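
The newly imported add_prefix helper joins a child module's name onto its parent's qualified name. Inferred from its usage throughout this diff (the actual sglang implementation may differ in detail), a minimal sketch:

def add_prefix(name: str, prefix: str) -> str:
    # Sketch inferred from usage below, not copied from sglang: an empty
    # parent prefix yields the bare name, otherwise dot-join the two.
    return name if not prefix else f"{prefix}.{name}"

# add_prefix("gate_up_proj", "model.layers.0.mlp")
# -> "model.layers.0.mlp.gate_up_proj"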
@@ -79,10 +79,15 @@ class DeepseekV2MLP(nn.Module):
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
+prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
-hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+hidden_size,
+[intermediate_size] * 2,
+bias=False,
+quant_config=quant_config,
+prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
@@ -90,6 +95,7 @@ class DeepseekV2MLP(nn.Module):
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("down_proj", prefix),
)
if hidden_act != "silu":
raise ValueError(
@@ -106,7 +112,11 @@ class DeepseekV2MLP(nn.Module):
class MoEGate(nn.Module):
-def __init__(self, config):
+def __init__(
+self,
+config,
+prefix: str = "",
+):
super().__init__()
self.weight = nn.Parameter(
torch.empty((config.n_routed_experts, config.hidden_size))
@@ -129,6 +139,7 @@ class DeepseekV2MoE(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -147,7 +158,7 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
-self.gate = MoEGate(config=config)
+self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
@@ -161,6 +172,7 @@ class DeepseekV2MoE(nn.Module):
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
prefix=add_prefix("experts", prefix),
)
if config.n_shared_experts is not None:
@@ -171,6 +183,7 @@ class DeepseekV2MoE(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -217,6 +230,7 @@ class DeepseekV2Attention(nn.Module):
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
+prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
@@ -241,6 +255,7 @@ class DeepseekV2Attention(nn.Module):
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
@@ -248,6 +263,7 @@ class DeepseekV2Attention(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ColumnParallelLinear(
@@ -255,6 +271,7 @@ class DeepseekV2Attention(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -262,8 +279,7 @@ class DeepseekV2Attention(nn.Module):
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
-# FIXME: quick fix for skip quantization
-prefix=f"self_attn.kv_a_proj_with_mqa",
+prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
@@ -271,6 +287,7 @@ class DeepseekV2Attention(nn.Module):
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = RowParallelLinear(
@@ -278,6 +295,7 @@ class DeepseekV2Attention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope_wrapper(
@@ -303,6 +321,7 @@ class DeepseekV2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
)
def forward(
@@ -368,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
use_dp=False,
+prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
@@ -394,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ReplicatedLinear(
@@ -401,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ReplicatedLinear(
@@ -408,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_b_proj = ReplicatedLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = ReplicatedLinear(
@@ -421,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
else:
# For tensor parallel attention
@@ -430,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
@@ -437,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ColumnParallelLinear(
@@ -444,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = RowParallelLinear(
@@ -457,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -464,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
-# FIXME: quick fix for skip quantization
-prefix=f"self_attn.kv_a_proj_with_mqa",
+prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
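
Both kv_a_proj_with_mqa construction sites above previously hard-coded prefix="self_attn.kv_a_proj_with_mqa" as a "quick fix for skip quantization". With real prefixes threaded through, the same exclusion can live in the quantize config instead; for example (regex convention assumed from GPTQModel's dynamic overrides, not shown in this diff):

# One dynamic rule replaces the removed FIXME hack across all layers:
dynamic = {r"-:model\.layers\..*\.self_attn\.kv_a_proj_with_mqa$": {}}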
@@ -496,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module):
num_kv_heads=1,
layer_id=layer_id,
v_head_dim=self.kv_lora_rank,
prefix=add_prefix("attn_mqa", prefix),
)
self.attn_mha = RadixAttention(
@@ -505,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module):
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
v_head_dim=self.v_head_dim,
prefix=add_prefix("attn_mha", prefix),
)
self.w_kc = None
@@ -848,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module):
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
is_nextn: bool = False,
+prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -880,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config,
layer_id=layer_id,
use_dp=self.enable_dp_attention,
prefix=add_prefix("self_attn", prefix),
)
else:
self.self_attn = DeepseekV2Attention(
@@ -898,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
prefix=add_prefix("self_attn", prefix),
)
if is_nextn or (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
-self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
+self.mlp = DeepseekV2MoE(
+config=config,
+quant_config=quant_config,
+prefix=add_prefix("mlp", prefix),
+)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
@@ -962,6 +1001,7 @@ class DeepseekV2Model(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.padding_id = config.pad_token_id
@@ -978,6 +1018,7 @@ class DeepseekV2Model(nn.Module):
config,
layer_id,
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix),
)
for layer_id in range(config.num_hidden_layers)
]
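
Because DeepseekV2Model contributes "model" (see the final hunk) and each decoder layer contributes layers.{layer_id}, the nested add_prefix calls compose into names that match the checkpoint's tensor paths. Using the add_prefix sketch from earlier:

prefix = add_prefix("model", "")            # "model"
prefix = add_prefix("layers.0", prefix)     # "model.layers.0"
prefix = add_prefix("mlp", prefix)          # "model.layers.0.mlp"
print(add_prefix("gate_up_proj", prefix))   # model.layers.0.mlp.gate_up_proj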
@@ -1008,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
-self.model = DeepseekV2Model(config, quant_config)
+self.model = DeepseekV2Model(
+config, quant_config, prefix=add_prefix("model", prefix)
+)
if global_server_args_dict["enable_dp_attention"]:
self.lm_head = ReplicatedLinear(
config.hidden_size,
config.vocab_size,
bias=False,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
self.lm_head = ParallelLMHead(
-config.vocab_size, config.hidden_size, quant_config=quant_config
+config.vocab_size,
+config.hidden_size,
+quant_config=quant_config,
+prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
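
Downstream, each linear layer hands its prefix to the quantization config when selecting a quant method (sglang inherits vLLM's get_quant_method(layer, prefix) hook). That is what lets a dynamic GPTQ config pick per-module settings, and what makes lm_head quantizable now that ParallelLMHead receives a prefix too. A sketch of how such a lookup might resolve a name; resolve_overrides and its matching scheme are assumptions for illustration, not an sglang or GPTQModel API:

import re
from typing import Optional

def resolve_overrides(dynamic: dict, prefix: str) -> Optional[dict]:
    # Hypothetical helper: map a module's qualified name to its dynamic
    # overrides. None means "leave unquantized"; {} means global defaults.
    for pattern, overrides in dynamic.items():
        if pattern.startswith("-:") and re.search(pattern[2:], prefix):
            return None          # excluded from quantization
        if pattern.startswith("+:") and re.search(pattern[2:], prefix):
            return overrides     # per-module bits/group_size override
    return {}                    # no rule matched: global defaults

# resolve_overrides(dynamic, "model.layers.3.self_attn.kv_a_proj_with_mqa")
# -> None, i.e. skipped -- the behavior the removed FIXME hack used to force.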