[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
This commit is contained in:
committed by GitHub
parent 583d6af71b
commit 56a724eba3
@@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
-from sglang.srt.utils import make_layers
|
||||
+from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
class Olmo2Attention(nn.Module):
|
||||
@@ -60,6 +60,7 @@ class Olmo2Attention(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module):
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
@@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
# Attention output projection.
|
||||
@@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module):
|
||||
self.head_dim * self.total_num_heads,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
def _apply_qk_norm(
|
||||
@@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
|
||||
# Activation function.
|
||||
@@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
-self.self_attn = Olmo2Attention(config, layer_id, quant_config)
|
||||
+self.self_attn = Olmo2Attention(
|
||||
+config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
|
||||
+)
|
||||
|
||||
# MLP block.
|
||||
-self.mlp = Olmo2MLP(config, quant_config)
|
||||
+self.mlp = Olmo2MLP(config, quant_config, prefix=add_prefix("mlp", prefix))
|
||||
|
||||
# RMSNorm
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -254,12 +266,15 @@ class Olmo2Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
-config.vocab_size, config.hidden_size
|
||||
+config.vocab_size,
|
||||
+config.hidden_size,
|
||||
+prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -267,7 +282,9 @@ class Olmo2Model(nn.Module):
|
||||
layer_id=idx,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
-self.model = Olmo2Model(config, quant_config)
|
||||
+self.model = Olmo2Model(
|
||||
+config, quant_config, prefix=add_prefix("model", prefix)
|
||||
+)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
@@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user