Co-authored-by: HandH1998 <1335248067@qq.com>
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
 
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
 
         self.embed_dim = config.hidden_size
-        lora_vocab = (
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
         self.config = config
-        self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
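A minimal usage sketch for context, not part of the commit: after this change the GPTBigCode classes take only config plus an optional quant_config, so a StarCoder-family checkpoint loads through sglang's usual offline entry point with no cache_config or LoRA plumbing. The entry point, sampling parameters, and model path below are assumptions for illustration and are untested here.

import sglang as sgl

# Model path is illustrative; any GPTBigCode / StarCoder-family checkpoint applies.
llm = sgl.Engine(model_path="bigcode/starcoderbase-1b")

# Plain greedy completion; quantization only comes into play if the checkpoint
# or server config requests it via quant_config.
out = llm.generate("def fibonacci(n):", {"temperature": 0, "max_new_tokens": 32})
print(out["text"])

llm.shutdown()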