@@ -330,7 +330,7 @@ class ModelRunner:
|
|||||||
self.token_to_kv_pool = TokenToKVPool(
|
self.token_to_kv_pool = TokenToKVPool(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
head_num=self.model_config.num_key_value_heads // self.tp_size,
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
)
|
)
|
||||||
@@ -446,11 +446,20 @@ def import_model_classes():
|
|||||||
model_arch_name_to_cls[tmp.__name__] = tmp
|
model_arch_name_to_cls[tmp.__name__] = tmp
|
||||||
else:
|
else:
|
||||||
model_arch_name_to_cls[entry.__name__] = entry
|
model_arch_name_to_cls[entry.__name__] = entry
|
||||||
|
|
||||||
|
# compat: some models such as chatglm has incorrect class set in config.json
|
||||||
|
# usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
|
||||||
|
if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
|
||||||
|
for remap in module.EntryClassRemapping:
|
||||||
|
if isinstance(remap, tuple) and len(remap) == 2:
|
||||||
|
model_arch_name_to_cls[remap[0]] = remap[1]
|
||||||
|
|
||||||
return model_arch_name_to_cls
|
return model_arch_name_to_cls
|
||||||
|
|
||||||
|
|
||||||
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
|
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||||
model_arch_name_to_cls = import_model_classes()
|
model_arch_name_to_cls = import_model_classes()
|
||||||
|
|
||||||
if model_arch not in model_arch_name_to_cls:
|
if model_arch not in model_arch_name_to_cls:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported architectures: {model_arch}. "
|
f"Unsupported architectures: {model_arch}. "
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
@@ -18,7 +19,7 @@ class ModelConfig:
|
|||||||
self.model_overide_args = model_overide_args
|
self.model_overide_args = model_overide_args
|
||||||
self.hf_config = get_config(self.path, trust_remote_code, revision,
|
self.hf_config = get_config(self.path, trust_remote_code, revision,
|
||||||
model_overide_args=model_overide_args)
|
model_overide_args=model_overide_args)
|
||||||
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
self.context_len = context_length
|
self.context_len = context_length
|
||||||
else:
|
else:
|
||||||
@@ -43,4 +44,69 @@ class ModelConfig:
|
|||||||
self.num_key_value_heads = self.num_attention_heads
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
self.hidden_size = self.hf_config.hidden_size
|
self.hidden_size = self.hf_config.hidden_size
|
||||||
self.num_hidden_layers = self.hf_config.num_hidden_layers
|
self.num_hidden_layers = self.hf_config.num_hidden_layers
|
||||||
self.vocab_size = self.hf_config.vocab_size
|
self.vocab_size = self.hf_config.vocab_size
|
||||||
|
|
||||||
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||||
|
def get_total_num_kv_heads(self) -> int:
|
||||||
|
"""Returns the total number of KV heads."""
|
||||||
|
# For GPTBigCode & Falcon:
|
||||||
|
# NOTE: for falcon, when new_decoder_architecture is True, the
|
||||||
|
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||||
|
# KV heads.
|
||||||
|
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||||
|
new_decoder_arch_falcon = (
|
||||||
|
self.hf_config.model_type in falcon_model_types
|
||||||
|
and getattr(self.hf_config, "new_decoder_architecture", False))
|
||||||
|
if not new_decoder_arch_falcon and getattr(self.hf_text_config,
|
||||||
|
"multi_query", False):
|
||||||
|
# Multi-query attention, only one KV head.
|
||||||
|
# Currently, tensor parallelism is not supported in this case.
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# For DBRX and MPT
|
||||||
|
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
||||||
|
return getattr(self.hf_config.attn_config, "kv_n_heads",
|
||||||
|
self.hf_config.num_attention_heads)
|
||||||
|
|
||||||
|
attributes = [
|
||||||
|
# For Falcon:
|
||||||
|
"n_head_kv",
|
||||||
|
"num_kv_heads",
|
||||||
|
# For LLaMA-2:
|
||||||
|
"num_key_value_heads",
|
||||||
|
# For ChatGLM:
|
||||||
|
"multi_query_group_num",
|
||||||
|
]
|
||||||
|
for attr in attributes:
|
||||||
|
num_kv_heads = getattr(self.hf_text_config, attr, None)
|
||||||
|
if num_kv_heads is not None:
|
||||||
|
return num_kv_heads
|
||||||
|
|
||||||
|
# For non-grouped-query attention models, the number of KV heads is
|
||||||
|
# equal to the number of attention heads.
|
||||||
|
return self.hf_text_config.num_attention_heads
|
||||||
|
|
||||||
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
|
||||||
|
def get_num_kv_heads(self, tensor_parallel_size) -> int:
|
||||||
|
"""Returns the number of KV heads per GPU."""
|
||||||
|
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||||
|
# If tensor parallelism is used, we divide the number of KV heads by
|
||||||
|
# the tensor parallel size. We will replicate the KV heads in the
|
||||||
|
# case where the number of KV heads is smaller than the tensor
|
||||||
|
# parallel size so each GPU has at least one KV head.
|
||||||
|
return max(1,
|
||||||
|
total_num_kv_heads // tensor_parallel_size)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_text_config(config: PretrainedConfig):
|
||||||
|
"""Get the "sub" config relevant to llm for multi modal models.
|
||||||
|
No op for pure text models.
|
||||||
|
"""
|
||||||
|
if hasattr(config, "text_config"):
|
||||||
|
# The code operates under the assumption that text_config should have
|
||||||
|
# `num_attention_heads` (among others). Assert here to fail early
|
||||||
|
# if transformers config doesn't align with this assumption.
|
||||||
|
assert hasattr(config.text_config, "num_attention_heads")
|
||||||
|
return config.text_config
|
||||||
|
else:
|
||||||
|
return config
|
||||||
|
|||||||
390
python/sglang/srt/models/chatglm.py
Normal file
390
python/sglang/srt/models/chatglm.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/THUDM/ChatGLM2-6B
|
||||||
|
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import LoraConfig
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class GLMAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
layer_id: int = 0,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.multi_query_attention = config.multi_query_attention
|
||||||
|
self.total_num_kv_heads = (config.multi_query_group_num
|
||||||
|
if config.multi_query_attention else
|
||||||
|
config.num_attention_heads)
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = config.hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
self.query_key_value = QKVParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.dense = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=config.add_bias_linear,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
|
||||||
|
rope_ratio = getattr(config, "rope_ratio", 1.0)
|
||||||
|
max_positions = getattr(config, "seq_length", 8192)
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim // 2,
|
||||||
|
max_position=max_positions,
|
||||||
|
base=10000 * rope_ratio,
|
||||||
|
is_neox_style=False,
|
||||||
|
)
|
||||||
|
self.attn = RadixAttention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
layer_id=layer_id)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.query_key_value(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(position_ids, q, k)
|
||||||
|
context_layer = self.attn(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
input_metadata,
|
||||||
|
)
|
||||||
|
attn_output, _ = self.dense(context_layer)
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class GLMMLP(nn.Module):
|
||||||
|
"""MLP.
|
||||||
|
|
||||||
|
MLP will take the input with h hidden state, project it to 4*h
|
||||||
|
hidden dimension, perform nonlinear transformation, and project the
|
||||||
|
state back into h hidden dimension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.add_bias = config.add_bias_linear
|
||||||
|
|
||||||
|
# Project to 4h.
|
||||||
|
self.dense_h_to_4h = MergedColumnParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
[config.ffn_hidden_size] * 2,
|
||||||
|
bias=config.add_bias_linear,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.activation_func = SiluAndMul()
|
||||||
|
|
||||||
|
# Project back to h.
|
||||||
|
self.dense_4h_to_h = RowParallelLinear(
|
||||||
|
config.ffn_hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=config.add_bias_linear,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
# [s, b, 4hp]
|
||||||
|
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
|
||||||
|
intermediate_parallel = self.activation_func(intermediate_parallel)
|
||||||
|
# [s, b, h]
|
||||||
|
output, _ = self.dense_4h_to_h(intermediate_parallel)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class GLMBlock(nn.Module):
|
||||||
|
"""A single transformer layer.
|
||||||
|
|
||||||
|
Transformer layer takes input with size [s, b, h] and returns an
|
||||||
|
output of the same size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
layer_id: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.apply_residual_connection_post_layernorm = (
|
||||||
|
config.apply_residual_connection_post_layernorm)
|
||||||
|
|
||||||
|
self.fp32_residual_connection = config.fp32_residual_connection
|
||||||
|
|
||||||
|
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||||
|
# Layernorm on the input data.
|
||||||
|
self.input_layernorm = layer_norm_func(config.hidden_size,
|
||||||
|
eps=config.layernorm_epsilon)
|
||||||
|
|
||||||
|
# Self attention.
|
||||||
|
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
|
||||||
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
|
||||||
|
# Layernorm on the attention output
|
||||||
|
self.post_attention_layernorm = layer_norm_func(
|
||||||
|
config.hidden_size, eps=config.layernorm_epsilon)
|
||||||
|
|
||||||
|
# MLP
|
||||||
|
self.mlp = GLMMLP(config, quant_config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# hidden_states: [num_tokens, h]
|
||||||
|
# Layer norm at the beginning of the transformer layer.
|
||||||
|
layernorm_output = self.input_layernorm(hidden_states)
|
||||||
|
# Self attention.
|
||||||
|
attention_output = self.self_attention(
|
||||||
|
hidden_states=layernorm_output,
|
||||||
|
position_ids=position_ids,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Residual connection.
|
||||||
|
if self.apply_residual_connection_post_layernorm:
|
||||||
|
residual = layernorm_output
|
||||||
|
else:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
layernorm_input = residual + attention_output
|
||||||
|
|
||||||
|
# Layer norm post the self attention.
|
||||||
|
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||||
|
|
||||||
|
# Second residual connection.
|
||||||
|
if self.apply_residual_connection_post_layernorm:
|
||||||
|
residual = layernorm_output
|
||||||
|
else:
|
||||||
|
residual = layernorm_input
|
||||||
|
|
||||||
|
output = self.mlp(layernorm_output) + residual
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class GLMTransformer(nn.Module):
|
||||||
|
"""Transformer class."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.post_layer_norm = config.post_layer_norm
|
||||||
|
|
||||||
|
# Number of layers.
|
||||||
|
self.num_layers = config.num_layers
|
||||||
|
|
||||||
|
# Transformer layers.
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
GLMBlock(config, i, cache_config, quant_config)
|
||||||
|
for i in range(self.num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
if self.post_layer_norm:
|
||||||
|
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||||
|
# Final layer norm before output.
|
||||||
|
self.final_layernorm = layer_norm_func(
|
||||||
|
config.hidden_size, eps=config.layernorm_epsilon)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
position_ids=position_ids,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
)
|
||||||
|
# Final layer norm.
|
||||||
|
if self.post_layer_norm:
|
||||||
|
hidden_states = self.final_layernorm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
||||||
|
config.hidden_size)
|
||||||
|
|
||||||
|
self.num_layers = config.num_layers
|
||||||
|
self.multi_query_group_num = config.multi_query_group_num
|
||||||
|
self.kv_channels = config.kv_channels
|
||||||
|
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||||
|
|
||||||
|
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||||
|
config.hidden_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.embedding(input_ids)
|
||||||
|
|
||||||
|
# Run encoder.
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
hidden_states=inputs_embeds,
|
||||||
|
position_ids=position_ids,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMForCausalLM(nn.Module):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"query_key_value": ["query_key_value"],
|
||||||
|
"dense_h_to_4h": ["dense_h_to_4h"]
|
||||||
|
}
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"query_key_value",
|
||||||
|
"dense",
|
||||||
|
"dense_h_to_4h",
|
||||||
|
"dense_4h_to_h",
|
||||||
|
]
|
||||||
|
embedding_modules = {}
|
||||||
|
embedding_padding_modules = []
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: ChatGLMConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoraConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config: ChatGLMConfig = config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
||||||
|
8192)
|
||||||
|
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||||
|
self.lm_head = self.transformer.output_layer
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.transformer(input_ids, positions,
|
||||||
|
input_metadata)
|
||||||
|
return self.logits_processor(
|
||||||
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_pos_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if "word_embeddings" in name:
|
||||||
|
name = name.replace(".word_embeddings", "")
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
EntryClass = ChatGLMForCausalLM
|
||||||
|
# compat: glm model.config class == ChatGLMModel
|
||||||
|
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
|
||||||
Reference in New Issue
Block a user