diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index d1074fb5a..3b89fa5b0 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -193,6 +193,7 @@ def match_chat_ml(model_path: str):
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
 
+
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 10c11f659..980a2cd20 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
-                logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
+                logprobs_cumsum = torch.cumsum(
+                    prefill_logprobs, dim=0, dtype=torch.float32
+                )
 
                 start = input_metadata.extend_start_loc.clone()
                 end = start + input_metadata.extend_seq_lens - 2
                 start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
                 end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
+                sum_logp = (
+                    logprobs_cumsum[end]
+                    - logprobs_cumsum[start]
+                    + prefill_logprobs[start]
+                )
                 normalized_logprobs = sum_logp / (
                     (input_metadata.extend_seq_lens - 1).clamp(min=1)
                 )
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index dabb515ed..9857ff4d3 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -13,14 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 
 
 class RadixAttention(nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        head_dim,
-        scaling,
-        num_kv_heads,
-        layer_id
-    ):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 39748c691..a2070a5b1 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -100,6 +100,7 @@ class BatchStrOut:
 class FlushCacheReq:
     pass
 
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index fe61ce8ca..444e2e872 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -11,8 +11,8 @@ import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
-from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
@@ -391,8 +391,12 @@ class ModelRpcServer(rpyc.Service):
         logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
-                self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
+            logits, (
+                prefill_logprobs,
+                normalized_logprobs,
+                last_logprobs,
+            ) = self.model_runner.forward(
+                batch, ForwardMode.EXTEND, batch.return_logprob
             )
             if prefill_logprobs is not None:
                 logprobs = prefill_logprobs.cpu().tolist()
@@ -407,7 +411,9 @@ class ModelRpcServer(rpyc.Service):
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+            last_logprobs = (
+                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+            )
 
         # Check finish condition
         pt = 0
@@ -482,7 +488,9 @@ class ModelRpcServer(rpyc.Service):
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
+            last_logprobs = last_logprobs[
+                torch.arange(len(reqs)), next_token_ids
+            ].tolist()
 
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
@@ -620,15 +628,16 @@ class ModelRpcClient:
         self.step = async_wrap("step")
 
 
-def start_model_process(port):
-    def _init_service(port):
-        t = ThreadedServer(
-            ModelRpcServer(),
-            port=port,
-            protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
-        )
-        t.start()
+def _init_service(port):
+    t = ThreadedServer(
+        ModelRpcServer(),
+        port=port,
+        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+    )
+    t.start()
 
+
+def start_model_process(port):
     proc = multiprocessing.Process(target=_init_service, args=(port,))
     proc.start()
     time.sleep(1)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 1dd2180e8..a68e99b1c 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -17,8 +17,8 @@ from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
 import sglang
-QUANTIONCONFIG_MAPPING = {'awq': AWQConfig,
-                          'gptq': GPTQConfig}
+
+QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
 
 logger = logging.getLogger("model_runner")
 
@@ -283,9 +283,13 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-            quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
+            quant_config_class = QUANTIONCONFIG_MAPPING.get(
+                hf_quant_config["quant_method"]
+            )
             if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
+                raise ValueError(
+                    f"Unsupported quantization method: {hf_quant_config['quant_method']}"
+                )
             quant_config = quant_config_class.from_config(hf_quant_config)
             logger.info(f"quant_config: {quant_config}")
             linear_method = quant_config.get_linear_method()
diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py
index 3089c7ac0..114aec717 100644
--- a/python/sglang/srt/models/qwen.py
+++ b/python/sglang/srt/models/qwen.py
@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-        linear_method: Optional[LinearMethodBase] = None
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
 
         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
-                hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                bias=True,
-                linear_method=linear_method
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-            linear_method=linear_method
+            linear_method=linear_method,
         )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
+        self.mlp = QWenMLP(
+            config.hidden_size,
+            config.intermediate_size // 2,
+            linear_method=linear_method,
+        )
 
     def forward(
         self,
@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
+            [
+                QWenBlock(config, i, linear_method=linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py
index d86a4a8ba..0d7a29dc7 100644
--- a/python/sglang/srt/models/yivl.py
+++ b/python/sglang/srt/models/yivl.py
@@ -4,14 +4,17 @@ from typing import List, Optional
 
 import torch
 import torch.nn as nn
+from sglang.srt.models.llava import (
+    LlavaLlamaForCausalLM,
+    clip_vision_embed_forward,
+    monkey_path_clip_vision_embed_forward,
+)
 from transformers import CLIPVisionModel, LlavaConfig
 from vllm.model_executor.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
-from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
-
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
     def __init__(self, *args, **kwargs):
@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         super().__init__(self.config)
 
         self.multi_modal_projector = YiVLMultiModalProjector(self.config)
-        self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "")  # Everything after "./"
+        self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
+            "./", ""
+        )  # Everything after "./"
 
     def load_weights(
         self,
@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
     ):
         # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
         self.vision_tower = CLIPVisionModel.from_pretrained(
-            model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
+            model_name_or_path,
+            torch_dtype=torch.float16,
+            subfolder=self.vision_tower_subfolder,
         ).cuda()
 
         self.vision_tower.eval()
@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
 
         monkey_path_clip_vision_embed_forward()
 
+
 class YiVLMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
         super().__init__()
 
-        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size, config.text_config.hidden_size
+        )
         self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
         self.act = nn.GELU()
-        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size
+        )
         self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
 
     def forward(self, image_features):
@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
         hidden_states = self.ln_2(hidden_states)
         return hidden_states
 
-EntryClass = YiVLForCausalLM
\ No newline at end of file
+
+EntryClass = YiVLForCausalLM
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index d92ee54b2..00fa03ece 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -63,6 +63,7 @@ chat_template_name = None
 # FIXME: Remove this once we drop support for pydantic 1.x
 IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
 
+
 def jsonify_pydantic_model(obj: BaseModel):
     if IS_PYDANTIC_1:
         return obj.json(ensure_ascii=False)
@@ -165,7 +166,7 @@ async def v1_completions(raw_request: Request):
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
 
-                if not stream_buffer: # The first chunk
+                if not stream_buffer:  # The first chunk
                     if request.echo:
                         # Prepend prompt in response text.
                         text = request.prompt + text
@@ -219,7 +220,9 @@ async def v1_completions(raw_request: Request):
         token_logprob_pos = prompt_tokens
 
     logprobs = (
-        await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+        await make_openai_style_logprobs(
+            ret["meta_info"]["token_logprob"][token_logprob_pos:]
+        )
         if request.logprobs is not None
         else None
     )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f33f86eb6..bcc29b782 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -114,7 +114,7 @@ class ServerArgs:
             "--max-prefill-num-token",
             type=int,
             default=ServerArgs.max_prefill_num_token,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
         parser.add_argument(
             "--tp-size",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 42c534abe..54274f366 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -259,4 +259,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
\ No newline at end of file
+    return image
diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py
index 7462661b7..fac301f6a 100644
--- a/test/srt/test_httpserver_decode.py
+++ b/test/srt/test_httpserver_decode.py
@@ -12,6 +12,7 @@ import argparse
 
 import requests
 
+
 def test_decode(url, return_logprob):
     response = requests.post(
         url + "/generate",
@@ -27,6 +28,7 @@ def test_decode(url, return_logprob):
     )
     print(response.json())
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py
index 5713c6380..a6d29c548 100644
--- a/test/srt/test_httpserver_decode_stream.py
+++ b/test/srt/test_httpserver_decode_stream.py
@@ -12,6 +12,7 @@ import json
 
 import requests
 
+
 def test_decode_stream(url, return_logprob):
     response = requests.post(
         url + "/generate",
@@ -39,7 +40,7 @@ def test_decode_stream(url, return_logprob):
             assert data["meta_info"]["prompt_logprob"] is not None
             assert data["meta_info"]["token_logprob"] is not None
             assert data["meta_info"]["normalized_prompt_logprob"] is not None
-            if prev == 0: # Skip prompt logprobs
+            if prev == 0:  # Skip prompt logprobs
                 prev = data["meta_info"]["prompt_tokens"]
             for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
                 print(f"{token_txt}\t{logprob}", flush=True)
@@ -50,6 +51,7 @@ def test_decode_stream(url, return_logprob):
             prev = len(output)
     print("")
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index b1351360d..01aa53e5b 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -64,9 +64,8 @@ def test_completion_stream(args, echo, logprobs):
             first = False
         if logprobs:
             print(
-                f"{r.choices[0].text:12s}\t"
-                f"{r.choices[0].logprobs.token_logprobs}",
-                flush=True
+                f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
+                flush=True,
             )
         else:
             print(r.choices[0].text, end="", flush=True)