[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
This commit is contained in:
committed by
GitHub
parent
583d6af71b
commit
56a724eba3
@@ -35,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
@@ -44,6 +45,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -69,6 +71,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_attn", prefix),
|
||||
)
|
||||
|
||||
self.c_proj = RowParallelLinear(
|
||||
@@ -76,6 +79,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
@@ -83,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
scaling=self.scale,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -111,6 +116,7 @@ class GPTBigMLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@@ -119,12 +125,14 @@ class GPTBigMLP(nn.Module):
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_fc", prefix),
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
self.act = get_act_fn(
|
||||
config.activation_function, quant_config, intermediate_size
|
||||
@@ -144,15 +152,20 @@ class GPTBigCodeBlock(nn.Module):
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
|
||||
self.attn = GPTBigCodeAttention(
|
||||
layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||
self.mlp = GPTBigMLP(
|
||||
inner_dim, config, quant_config, prefix=add_prefix("mlp", prefix)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -181,6 +194,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -190,12 +204,17 @@ class GPTBigCodeModel(nn.Module):
|
||||
lora_vocab = 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.wte = VocabParallelEmbedding(
|
||||
self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
|
||||
self.vocab_size,
|
||||
self.embed_dim,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
prefix=add_prefix("wte", prefix),
|
||||
)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
GPTBigCodeBlock(i, config, quant_config)
|
||||
GPTBigCodeBlock(
|
||||
i, config, quant_config, prefix=add_prefix(f"h.{i}", prefix)
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -235,13 +254,16 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPTBigCodeModel(config, quant_config)
|
||||
self.transformer = GPTBigCodeModel(
|
||||
config, quant_config, prefix=add_prefix("transformer", prefix)
|
||||
)
|
||||
self.lm_head = self.transformer.wte
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
Reference in New Issue
Block a user