[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)

Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Author: Qubitium-ModelCloud
Date: 2025-03-05 17:11:00 +08:00
Committed by: GitHub
Parent: 583d6af71b
Commit: 56a724eba3
56 changed files with 1988 additions and 282 deletions
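
This change threads a `prefix` string through every Olmo2 module constructor so each quantizable submodule is registered under its full dotted name (e.g. `model.layers.0.self_attn.qkv_proj`), and passes `quant_config` into the attention projections and `ParallelLMHead`. GPTQModel's dynamic quantization keys its per-module overrides on exactly these dotted names, and `lm_head` quantization needs the head to be addressable the same way. A hedged sketch of producing such a checkpoint with GPTQModel (API names follow GPTQModel's documented `QuantizeConfig` as of early 2025; verify against your installed version):

```python
from gptqmodel import GPTQModel, QuantizeConfig

# Per-module ("dynamic") overrides, keyed by regex over dotted module names.
# "+:" (or no prefix) overrides the base config for matching modules;
# "-:" excludes matching modules from quantization entirely.
dynamic = {
    r"+:model\.layers\.\d+\.mlp\..*": {"bits": 8, "group_size": 64},
    r"-:model\.layers\.0\.self_attn\..*": {},
}

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    dynamic=dynamic,
    lm_head=True,  # also quantize the output projection (lm_head)
)

# Placeholder calibration data; use a real calibration set in practice.
calibration_data = ["GPTQModel quantizes transformer checkpoints."]

model = GPTQModel.load("allenai/OLMo-2-1124-7B", quant_config)
model.quantize(calibration_data)
model.save("olmo2-7b-gptq-w4-dynamic")
```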

python/sglang/srt/models/olmo2.py

@@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers
 
 
 class Olmo2Attention(nn.Module):
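
The new `add_prefix` import is the only utility change in this file; it joins a child module's name onto its parent's prefix. A minimal sketch of the helper's behavior (the real implementation lives in `sglang.srt.utils`):

```python
def add_prefix(name: str, prefix: str) -> str:
    # add_prefix("qkv_proj", "model.layers.0.self_attn")
    #   -> "model.layers.0.self_attn.qkv_proj"
    return name if not prefix else f"{prefix}.{name}"
```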
@@ -60,6 +60,7 @@ class Olmo2Attention(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
 
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
         # Attention output projection.
@@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module):
             self.head_dim * self.total_num_heads,
             self.hidden_size,
             bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
 
     def _apply_qk_norm(
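
With `quant_config` and prefixes threaded into both attention projections (and the MLP projections below), every quantizable submodule of a decoder layer carries a unique dotted name that a dynamic rule can match. Illustrative names for layer 0, assuming the model is built with an empty root prefix:

```python
# Dotted names produced by the add_prefix chain (illustrative):
#   model.layers.0.self_attn.qkv_proj
#   model.layers.0.self_attn.o_proj
#   model.layers.0.mlp.gate_up_proj
#   model.layers.0.mlp.down_proj
# Example dynamic rule: keep every o_proj in full precision.
skip_o_proj = r"-:model\.layers\.\d+\.self_attn\.o_proj"
```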
@@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
 
         # Activation function.
@@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
 
     def forward(
@@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         # Attention block.
-        self.self_attn = Olmo2Attention(config, layer_id, quant_config)
+        self.self_attn = Olmo2Attention(
+            config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+        )
 
         # MLP block.
-        self.mlp = Olmo2MLP(config, quant_config)
+        self.mlp = Olmo2MLP(config, quant_config, prefix=add_prefix("mlp", prefix))
 
         # RMSNorm
         self.post_attention_layernorm = RMSNorm(
@@ -254,12 +266,15 @@ class Olmo2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
 
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size, config.hidden_size
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -267,7 +282,9 @@ class Olmo2Model(nn.Module):
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
+                prefix=prefix,
             ),
+            prefix=add_prefix("layers", prefix),
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
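
`make_layers` now receives `prefix=add_prefix("layers", prefix)` and hands each decoder layer an index-qualified prefix, which the layer factory forwards as `prefix=prefix`. A hypothetical simplification of the helper (the real one in `sglang.srt.utils` also handles pipeline-parallel partitioning):

```python
import torch.nn as nn

def make_layers(num_layers, layer_fn, prefix=""):
    # Layer idx is built with prefix "model.layers.<idx>", so its
    # submodules resolve to e.g. "model.layers.3.self_attn.qkv_proj".
    return nn.ModuleList(
        [layer_fn(idx=idx, prefix=f"{prefix}.{idx}") for idx in range(num_layers)]
    )
```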
@@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.model = Olmo2Model(config, quant_config)
+        self.model = Olmo2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
@@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module):
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
                 quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
 
         self.logits_processor = LogitsProcessor(config)
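
Passing `quant_config` and the `"lm_head"` prefix into `ParallelLMHead` is what enables lm_head quantization: when `tie_word_embeddings` is false, the head is a standalone projection that quantization rules can now address by name (with tied embeddings, the head shares `embed_tokens` and is not separately quantized). A hedged sketch of the relevant `quantize_config.json` fields in such a checkpoint, written as a Python dict (schema follows GPTQModel's serialization as I understand it; exact fields may vary by version):

```python
quantize_config = {
    "bits": 4,
    "group_size": 128,
    "lm_head": True,  # checkpoint carries quantized lm_head weights
    "dynamic": {
        # keep the first layer's attention unquantized
        r"-:model\.layers\.0\.self_attn\..*": {},
    },
}
```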