From a8c787d2b316c1672d9c626e38496066c71d8adb Mon Sep 17 00:00:00 2001
From: Qubitium <417764+Qubitium@users.noreply.github.com>
Date: Wed, 12 Jun 2024 07:39:52 +0800
Subject: [PATCH] Add ChatGLM Model Support (#516)

Co-authored-by: ZX
---
 .../srt/managers/controller/model_runner.py |  11 +-
 python/sglang/srt/model_config.py           |  70 +++-
 python/sglang/srt/models/chatglm.py         | 390 ++++++++++++++++++
 3 files changed, 468 insertions(+), 3 deletions(-)
 create mode 100644 python/sglang/srt/models/chatglm.py

diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index bc622208d..991700bc9 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -330,7 +330,7 @@ class ModelRunner:
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
             dtype=torch.float16,
-            head_num=self.model_config.num_key_value_heads // self.tp_size,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
@@ -446,11 +446,20 @@ def import_model_classes():
                         model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
                     model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm have an incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name", EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(module.EntryClassRemapping, list):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
     return model_arch_name_to_cls


 def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     model_arch_name_to_cls = import_model_classes()
+
     if model_arch not in model_arch_name_to_cls:
         raise ValueError(
             f"Unsupported architectures: {model_arch}. "
diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py
index dfeac0a25..3c0062bae 100644
--- a/python/sglang/srt/model_config.py
+++ b/python/sglang/srt/model_config.py
@@ -1,6 +1,7 @@
 from typing import Optional

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
+from transformers import PretrainedConfig


 class ModelConfig:
@@ -18,7 +19,7 @@ class ModelConfig:
         self.model_overide_args = model_overide_args
         self.hf_config = get_config(self.path, trust_remote_code, revision,
                                     model_overide_args=model_overide_args)
-
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -43,4 +44,69 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
\ No newline at end of file
+        self.vocab_size = self.hf_config.vocab_size
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False))
+        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
+                                                   "multi_query", False):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["dbrx", "mpt"]:
+            return getattr(self.hf_config.attn_config, "kv_n_heads",
+                           self.hf_config.num_attention_heads)
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_text_config.num_attention_heads
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
+    def get_num_kv_heads(self, tensor_parallel_size) -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1,
+                   total_num_kv_heads // tensor_parallel_size)
+
+
+def get_hf_text_config(config: PretrainedConfig):
+    """Get the "sub" config relevant to llm for multi modal models.
+    No op for pure text models.
+    """
+    if hasattr(config, "text_config"):
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(config.text_config, "num_attention_heads")
+        return config.text_config
+    else:
+        return config
diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py
new file mode 100644
index 000000000..83c0ef750
--- /dev/null
+++ b/python/sglang/srt/models/chatglm.py
@@ -0,0 +1,390 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+"""Inference-only ChatGLM model compatible with THUDM weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from peft import LoraConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from torch import nn
+from torch.nn import LayerNorm
+
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs import ChatGLMConfig
+
+
+class GLMAttention(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        layer_id: int = 0,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.multi_query_attention = config.multi_query_attention
+        self.total_num_kv_heads = (config.multi_query_group_num
+                                   if config.multi_query_attention else
+                                   config.num_attention_heads)
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.add_bias_linear or config.add_qkv_bias,
+            quant_config=quant_config,
+        )
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
+        rope_ratio = getattr(config, "rope_ratio", 1.0)
+        max_positions = getattr(config, "seq_length", 8192)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = RadixAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads,
+                                   layer_id=layer_id)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            input_metadata,
+        )
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class GLMMLP(nn.Module):
+    """MLP.
+
+    MLP will take the input with h hidden state, project it to 4*h
+    hidden dimension, perform nonlinear transformation, and project the
+    state back into h hidden dimension.
+    """
+
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        self.add_bias = config.add_bias_linear
+
+        # Project to 4h.
+        self.dense_h_to_4h = MergedColumnParallelLinear(
+            config.hidden_size,
+            [config.ffn_hidden_size] * 2,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+        self.activation_func = SiluAndMul()
+
+        # Project back to h.
+        self.dense_4h_to_h = RowParallelLinear(
+            config.ffn_hidden_size,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+    def forward(self, hidden_states):
+        # [s, b, 4hp]
+        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, h]
+        output, _ = self.dense_4h_to_h(intermediate_parallel)
+        return output
+
+
+class GLMBlock(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        layer_id: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm)
+
+        self.fp32_residual_connection = config.fp32_residual_connection
+
+        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+        # Layernorm on the input data.
+        self.input_layernorm = layer_norm_func(config.hidden_size,
+                                               eps=config.layernorm_epsilon)
+
+        # Self attention.
+        self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+        self.hidden_dropout = config.hidden_dropout
+
+        # Layernorm on the attention output
+        self.post_attention_layernorm = layer_norm_func(
+            config.hidden_size, eps=config.layernorm_epsilon)
+
+        # MLP
+        self.mlp = GLMMLP(config, quant_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # hidden_states: [num_tokens, h]
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output = self.self_attention(
+            hidden_states=layernorm_output,
+            position_ids=position_ids,
+            input_metadata=input_metadata,
+        )
+
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        layernorm_input = residual + attention_output
+
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        # Second residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        output = self.mlp(layernorm_output) + residual
+
+        return output
+
+
+class GLMTransformer(nn.Module):
+    """Transformer class."""
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.post_layer_norm = config.post_layer_norm
+
+        # Number of layers.
+        self.num_layers = config.num_layers
+
+        # Transformer layers.
+        self.layers = nn.ModuleList([
+            GLMBlock(config, i, cache_config, quant_config)
+            for i in range(self.num_layers)
+        ])
+
+        if self.post_layer_norm:
+            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.final_layernorm = layer_norm_func(
+                config.hidden_size, eps=config.layernorm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                input_metadata=input_metadata,
+            )
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class ChatGLMModel(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
+                                                config.hidden_size)
+
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
+        self.encoder = GLMTransformer(config, cache_config, quant_config)
+
+        self.output_layer = ParallelLMHead(config.padded_vocab_size,
+                                           config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embedding(input_ids)
+
+        # Run encoder.
+        hidden_states = self.encoder(
+            hidden_states=inputs_embeds,
+            position_ids=position_ids,
+            input_metadata=input_metadata,
+        )
+        return hidden_states
+
+
+class ChatGLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: ChatGLMConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoraConfig] = None,
+    ):
+        super().__init__()
+        self.config: ChatGLMConfig = config
+        self.quant_config = quant_config
+        self.max_position_embeddings = getattr(config, "max_sequence_length",
+                                               8192)
+        self.transformer = ChatGLMModel(config, cache_config, quant_config)
+        self.lm_head = self.transformer.output_layer
+        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions,
+                                         input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_pos_emb.inv_freq" in name:
+                continue
+            if "word_embeddings" in name:
+                name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+EntryClass = ChatGLMForCausalLM
+# compat: glm model.config class == ChatGLMModel
+EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
\ No newline at end of file
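
For context, a minimal standalone sketch (not part of the patch) of what the two hooks above do: import_model_classes() maps a module's own EntryClass name and then applies EntryClassRemapping, so the "ChatGLMModel" architecture string that ChatGLM ships in config.json resolves to ChatGLMForCausalLM; and ModelConfig.get_num_kv_heads() divides the total KV heads across tensor-parallel ranks, replicating when there are fewer KV heads than ranks. The stub class and the multi_query_group_num value of 2 below are illustrative assumptions, not taken from the patch.

# Illustrative sketch only: replays the remapping lookup and the per-GPU
# KV-head computation added in this patch. The stub class and the example
# multi_query_group_num value are assumptions for demonstration.


class ChatGLMForCausalLM:  # stand-in for the real entry class
    pass


EntryClass = ChatGLMForCausalLM
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]

# What import_model_classes() effectively builds for this module:
model_arch_name_to_cls = {EntryClass.__name__: EntryClass}
for remap in EntryClassRemapping:
    if isinstance(remap, tuple) and len(remap) == 2:
        model_arch_name_to_cls[remap[0]] = remap[1]

# The architecture string from ChatGLM's config.json now resolves.
assert model_arch_name_to_cls["ChatGLMModel"] is ChatGLMForCausalLM


def num_kv_heads_per_gpu(total_num_kv_heads: int, tp_size: int) -> int:
    # Same rule as ModelConfig.get_num_kv_heads(): divide across TP ranks,
    # but replicate so every GPU keeps at least one KV head.
    return max(1, total_num_kv_heads // tp_size)


# With a ChatGLM-style multi_query_group_num of 2 (example value):
assert num_kv_heads_per_gpu(2, 1) == 2
assert num_kv_heads_per_gpu(2, 2) == 1
assert num_kv_heads_per_gpu(2, 4) == 1  # fewer KV heads than ranks, replicated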