From 94e05770db538cadce18f5c201572067ab87840e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 22 Jan 2024 21:17:05 -0800 Subject: [PATCH] Fix after QWen support (#82) --- python/sglang/lang/chat_template.py | 5 +- .../srt/managers/detokenizer_manager.py | 3 +- python/sglang/srt/models/qwen.py | 79 +++++++++---------- python/sglang/srt/utils.py | 2 +- 4 files changed, 46 insertions(+), 43 deletions(-) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 579cc845b..5ea9786b8 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -168,7 +168,10 @@ def match_llama2_chat(model_path: str): @register_chat_template_matching_function def match_chat_ml(model_path: str): - if "tinyllama" in model_path.lower(): + model_path = model_path.lower() + if "tinyllama" in model_path: + return get_chat_template("chatml") + if "qwen" in model_path and "chat" in model_path: return get_chat_template("chatml") diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index b572585af..6dd3d79a7 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -55,7 +55,8 @@ class DetokenizerManager: first_token = self.tokenizer.convert_ids_to_tokens( int(output_tokens[i][0]) ) - first_token = first_token.decode("utf-8") + if not isinstance(first_token, str): + first_token = first_token.decode("utf-8") if first_token.startswith("▁"): output_strs[i] = " " + output_strs[i] diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index e89bfb48c..ba59d5bb6 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -5,7 +5,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata from torch import nn -from vllm.transformers_utils.configs.qwen import QWenConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -26,9 +25,10 @@ from vllm.model_executor.weight_utils import ( default_weight_loader, hf_model_weights_iterator, ) +from vllm.transformers_utils.configs.qwen import QWenConfig + class QWenMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -49,8 +49,10 @@ class QWenMLP(nn.Module): input_is_parallel=True, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -59,31 +61,28 @@ class QWenMLP(nn.Module): x, _ = self.c_proj(x) return x -class QWenAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - max_position_embeddings: int, - layer_id: int = 0, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None): +class QWenAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads # pylint: disable=invalid-name self.c_attn = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - bias=True + hidden_size, self.head_dim, self.total_num_heads, bias=True ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -120,20 +119,22 @@ class QWenAttention(nn.Module): output, _ = self.c_proj(attn_output) return output -class QWenBlock(nn.Module): - def __init__(self, config: QWenConfig,layer_id): +class QWenBlock(nn.Module): + def __init__(self, config: QWenConfig, layer_id): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - self.attn = QWenAttention(config.hidden_size, - config.num_attention_heads, - config.max_position_embeddings, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - layer_id=layer_id) + self.attn = QWenAttention( + config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + layer_id=layer_id, + ) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -161,10 +162,10 @@ class QWenBlock(nn.Module): hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states - -class QWenModel(nn.Module): - def __init__(self, config:QWenConfig): + +class QWenModel(nn.Module): + def __init__(self, config: QWenConfig): super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -175,7 +176,8 @@ class QWenModel(nn.Module): config.hidden_size, ) self.h = nn.ModuleList( - [QWenBlock(config, i) for i in range(config.num_hidden_layers)]) + [QWenBlock(config, i) for i in range(config.num_hidden_layers)] + ) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( @@ -195,26 +197,23 @@ class QWenModel(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states -class QWenLMHeadModel(nn.Module): - def __init__(self, config: QWenConfig,linear_method=None): +class QWenLMHeadModel(nn.Module): + def __init__(self, config: QWenConfig, linear_method=None): super().__init__() self.config = config self.transformer = QWenModel(config) vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ParallelLMHead( - vocab_size, - config.hidden_size - ) + self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata + input_metadata: InputMetadata, ): - hidden_states = self.transformer(input_ids, positions,input_metadata) + hidden_states = self.transformer(input_ids, positions, input_metadata) next_tokens = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 6822e9521..8c5876602 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -216,4 +216,4 @@ def load_image(image_file): else: image = Image.open(BytesIO(base64.b64decode(image_file))) - return image \ No newline at end of file + return image