[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
parent 583d6af71b
commit 56a724eba3
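The diff below threads an explicit module-name prefix through every Exaone submodule (including lm_head) via the add_prefix helper imported from sglang.srt.utils, so quantization code can match individual layers by their fully qualified names. A minimal sketch of the assumed helper semantics follows; the real implementation lives in sglang.srt.utils, and this body is an illustration, not the actual code:

# Minimal sketch, assuming add_prefix joins a child name onto its parent's
# dotted prefix (the real helper is defined in sglang.srt.utils).
def add_prefix(name: str, prefix: str) -> str:
    # add_prefix("qkv_proj", "transformer.h.0.attn") -> "transformer.h.0.attn.qkv_proj"
    # An empty parent prefix returns just the name.
    return name if not prefix else f"{prefix}.{name}"

With every layer named this way, a per-module ("dynamic") quantization config can include or exclude specific projections, and lm_head becomes addressable like any other quantized module.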
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class ExaoneGatedMLP(nn.Module):
@@ -56,14 +57,14 @@ class ExaoneGatedMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_proj",
+            prefix=add_prefix("c_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -130,14 +131,14 @@ class ExaoneAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.out_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
+            prefix=add_prefix("out_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -201,14 +202,14 @@ class ExaoneDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = ExaoneGatedMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.activation_function,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         rms_norm_eps = config.layer_norm_epsilon
         self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
@@ -244,6 +245,7 @@ class ExaoneModel(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -256,7 +258,10 @@ class ExaoneModel(nn.Module):
         self.h = nn.ModuleList(
             [
                 ExaoneDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.h.{i}"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -293,12 +298,17 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = ExaoneModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.transformer = ExaoneModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
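For context on the feature named in the title: GPTQModel checkpoints can carry per-module ("dynamic") quantization overrides keyed by module-name patterns, plus an option to quantize lm_head itself, and the prefix plumbing above gives sglang stable layer names to resolve such overrides against. Below is a hedged sketch of producing such a checkpoint; the exact QuantizeConfig fields, the +:/-: pattern syntax, and the load/quantize/save calls are assumptions about the GPTQModel API and are not part of this diff:

# Hedged sketch only: field names, pattern syntax, and method names are
# assumptions about the GPTQModel API and may differ from the released library.
from gptqmodel import GPTQModel, QuantizeConfig  # assumed import path

# A tiny illustrative calibration set; real quantization needs representative text.
calibration_dataset = [
    "EXAONE is a family of instruction-tuned language models.",
    "Quantization reduces weight precision to shrink memory use.",
]

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    # Per-module overrides matched against fully qualified names such as
    # "transformer.h.0.mlp.gate_up_proj" (the names produced by add_prefix).
    dynamic={
        r"+:.*\.mlp\..*": {"bits": 8, "group_size": 64},  # wider bits for MLP projections
        r"-:.*\.out_proj": {},                             # leave out_proj unquantized
    },
    lm_head=True,  # also quantize the lm_head projection
)

model = GPTQModel.load("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", quant_config)
model.quantize(calibration_dataset)
model.save("EXAONE-3.0-7.8B-Instruct-gptq-4bit")

A checkpoint produced this way would then be served by sglang, with the per-module and lm_head settings looked up by the prefixed layer names introduced in this commit.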