Update model_loader deps and qqq quantization deps (#2220) (#2318)

Co-authored-by: HandH1998 <1335248067@qq.com>
Author: Yineng Zhang
Date: 2024-12-02 23:22:13 +08:00 (committed by GitHub)
Parent: 33deca81b5
Commit: 85e1a6f3aa
58 changed files with 2363 additions and 366 deletions
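In short, the commit swaps a dependency: the weight-loading helper that model files previously imported from vllm is now vendored under sglang.srt.model_loader, and the vllm-specific cache_config / lora_config constructor parameters are dropped. A minimal sketch of the import change as it appears in the gpt_bigcode excerpt below (the explanatory comments are ours, not part of the commit):

# Old dependency (removed in this commit):
#   from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# New, vendored location used by the touched model files:
from sglang.srt.model_loader.weight_utils import default_weight_loader

# default_weight_loader copies a checkpoint tensor into the target parameter;
# load_weights implementations fall back to it when a parameter does not
# define its own weight_loader.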


@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader

 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention

         self.embed_dim = config.hidden_size
-        lora_vocab = (
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.lora_config = lora_config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
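After this change the GPTBigCode classes take only the Hugging Face config and an optional quant_config, so callers no longer thread cache_config or lora_config through. A rough usage sketch under two assumptions not shown in the diff: that the file lives at sglang/srt/models/gpt_bigcode.py, and that the distributed/tensor-parallel setup sglang's model loader normally performs has already happened.

from transformers import GPTBigCodeConfig

# Hypothetical module path; the diff above does not show the file name.
from sglang.srt.models.gpt_bigcode import GPTBigCodeForCausalLM

config = GPTBigCodeConfig()  # illustrative defaults, not a real checkpoint
# New signature: no cache_config / lora_config arguments.
model = GPTBigCodeForCausalLM(config, quant_config=None)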