diff --git a/README.md b/README.md
index 17720f2a0..ab3f51f09 100644
--- a/README.md
+++ b/README.md
@@ -316,6 +316,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
 - Mixtral
 - LLaVA
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+- Qwen
 - AWQ quantization

 ## Benchmark And Performance
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 41bc23b43..fb569890e 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -61,7 +61,6 @@ class RadixAttention(nn.Module):
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
-
         extend_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous(),
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 27fdeb749..b572585af 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -55,6 +55,9 @@ class DetokenizerManager:
                 first_token = self.tokenizer.convert_ids_to_tokens(
                     int(output_tokens[i][0])
                 )
+                # Qwen's tiktoken-based tokenizer returns bytes; Llama tokenizers return str.
+                if isinstance(first_token, bytes):
+                    first_token = first_token.decode("utf-8")
                 if first_token.startswith("▁"):
                     output_strs[i] = " " + output_strs[i]
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 071ec4efe..d08796e4d 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -240,6 +240,7 @@ class ModelRunner:
         from sglang.srt.models.llama2 import LlamaForCausalLM
         from sglang.srt.models.llava import LlavaLlamaForCausalLM
         from sglang.srt.models.mixtral import MixtralForCausalLM
+        from sglang.srt.models.qwen import QWenLMHeadModel

         # Select model class
         architectures = getattr(self.model_config.hf_config, "architectures", [])
@@ -258,6 +259,9 @@ class ModelRunner:
             if arch == "MixtralForCausalLM":
                 model_class = MixtralForCausalLM
                 break
+            if arch == "QWenLMHeadModel":
+                model_class = QWenLMHeadModel
+                break

         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py
index 9c7b41e8c..5f8aa50ce 100644
--- a/python/sglang/srt/model_config.py
+++ b/python/sglang/srt/model_config.py
@@ -20,8 +20,10 @@ class ModelConfig:
         # Unify the config keys for hf_config
         self.context_len = get_context_length(self.hf_config)
         self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
-        self.num_key_value_heads = self.hf_config.num_key_value_heads
         self.num_attention_heads = self.hf_config.num_attention_heads
+        self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
         self.vocab_size = self.hf_config.vocab_size
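The `num_key_value_heads` fallback above is what lets configs without that field (Qwen's, for example, which uses plain multi-head attention) run unchanged, while grouped-query models keep their explicit value. A minimal sketch of the intended behavior; the `SimpleNamespace` stand-ins for the HF config and the helper function are illustrative only, not part of this change:

```python
from types import SimpleNamespace


def resolve_num_kv_heads(hf_config):
    # Mirror of the ModelConfig change: fall back to MHA when the field is absent.
    num_kv = getattr(hf_config, "num_key_value_heads", None)
    return num_kv if num_kv is not None else hf_config.num_attention_heads


print(resolve_num_kv_heads(SimpleNamespace(num_attention_heads=32)))  # 32 -> MHA (e.g. Qwen)
print(resolve_num_kv_heads(SimpleNamespace(num_attention_heads=64, num_key_value_heads=8)))  # 8 -> GQA
```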
diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py
new file mode 100644
index 000000000..e89bfb48c
--- /dev/null
+++ b/python/sglang/srt/models/qwen.py
@@ -0,0 +1,261 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from torch import nn
+from vllm.transformers_utils.configs.qwen import QWenConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
+
+class QWenMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str = "silu",
+    ):
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            2 * [intermediate_size],
+            bias=False,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.c_proj(x)
+        return x
+
+
+class QWenAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = hidden_size // self.total_num_heads
+
+        # pylint: disable=invalid-name
+        self.c_attn = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+        )
+        self.c_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.c_proj(attn_output)
+        return output
+
+
+class QWenBlock(nn.Module):
+    def __init__(self, config: QWenConfig, layer_id):
+        super().__init__()
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        self.attn = QWenAttention(
+            config.hidden_size,
+            config.num_attention_heads,
+            config.max_position_embeddings,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            layer_id=layer_id,
+        )
+
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        hidden_states = self.attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class QWenModel(nn.Module):
+    def __init__(self, config: QWenConfig):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
+        self.wte = VocabParallelEmbedding(
+            vocab_size,
+            config.hidden_size,
+        )
+        self.h = nn.ModuleList(
+            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
+        )
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+            )
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class QWenLMHeadModel(nn.Module):
+    def __init__(self, config: QWenConfig, linear_method=None):
+        super().__init__()
+        self.config = config
+        self.transformer = QWenModel(config)
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
+        self.lm_head = ParallelLMHead(
+            vocab_size,
+            config.hidden_size,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ):
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        next_tokens = self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+        return next_tokens
+
+    _column_parallel_weights = []
+    _row_parallel_weights = ["c_proj.weight"]
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w2", 0),
+            ("gate_up_proj", "w1", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision
+        ):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
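For reviewers of `load_weights` above: Qwen checkpoints store the MLP gate/up projections as separate `w2`/`w1` tensors, which are loaded into the two shards of the merged `gate_up_proj` parameter. A small self-contained sketch of that name remapping; the helper function and the example checkpoint names are hypothetical illustrations, not part of the model code:

```python
# Mirrors the stacked_params_mapping loop in load_weights above.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "w2", 0),
    ("gate_up_proj", "w1", 1),
]


def map_checkpoint_name(name: str):
    """Return (parameter name, shard id) for a Qwen checkpoint tensor name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # loaded directly, no merged shard


print(map_checkpoint_name("transformer.h.0.mlp.w2.weight"))       # shard 0 of gate_up_proj
print(map_checkpoint_name("transformer.h.0.mlp.w1.weight"))       # shard 1 of gate_up_proj
print(map_checkpoint_name("transformer.h.0.attn.c_attn.weight"))  # unchanged, shard None
```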
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 54e14fe67..6822e9521 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -108,9 +108,11 @@ def get_exception_traceback():
 def get_int_token_logit_bias(tokenizer, vocab_size):
     from transformers import LlamaTokenizer, LlamaTokenizerFast

+    # Workaround: the model's (padded) vocab size can be larger than tokenizer.vocab_size.
+    vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
-        ss = tokenizer.decode(t_id).strip()
+        ss = tokenizer.decode([t_id]).strip()
         if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
             logit_bias[t_id] = -1e5
     # else:
@@ -214,4 +216,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))

-    return image
+    return image
\ No newline at end of file
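As a usage note, the new model should be reachable through the README's existing launch convention, for example `python -m sglang.launch_server --model-path Qwen/Qwen-7B-Chat --port 30000`; the exact model path, and whether a `--trust-remote-code` flag is required for the Qwen tokenizer, are assumptions not confirmed by this diff. Below is a minimal sketch of `get_int_token_logit_bias` after the fix, run against a hypothetical mock tokenizer so it executes without downloading anything: single token ids are decoded as a one-element list, and the bias array is sized by `tokenizer.vocab_size` rather than the padded model vocab size, so only non-numeric tokens are suppressed.

```python
import numpy as np


class MockTokenizer:
    """Illustrative stand-in; real tokenizers come from transformers / tiktoken."""
    vocab = ["7", "42", "foo", "</s>"]
    vocab_size = len(vocab)
    eos_token_id = 3

    def decode(self, ids):
        return "".join(self.vocab[i] for i in ids)


def get_int_token_logit_bias(tokenizer, vocab_size):
    # The model's padded vocab can be larger than the tokenizer's; size by the latter.
    vocab_size = tokenizer.vocab_size
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
    for t_id in range(vocab_size):
        ss = tokenizer.decode([t_id]).strip()  # decode a single id as a list
        if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
            logit_bias[t_id] = -1e5
    return logit_bias


print(get_int_token_logit_bias(MockTokenizer(), vocab_size=8))  # only "foo" gets -1e5
```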