From 690d162d9746e96d37cc62c5bf00d22f71c32583 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 14 May 2024 22:40:46 +0800 Subject: [PATCH] Format code (#441) --- benchmark/latency_throughput/test_latency.py | 2 +- python/sglang/backend/anthropic.py | 2 +- python/sglang/lang/interpreter.py | 6 ++- python/sglang/srt/flush_cache.py | 2 +- python/sglang/srt/managers/io_struct.py | 4 +- .../sglang/srt/managers/router/model_rpc.py | 20 +++++----- .../srt/managers/router/model_runner.py | 3 +- .../sglang/srt/managers/tokenizer_manager.py | 40 ++++++++++++------- python/sglang/srt/models/llava.py | 2 +- python/sglang/srt/models/llava_mistral.py | 7 +--- python/sglang/srt/models/llava_qwen.py | 7 +--- python/sglang/srt/openai_api_adapter.py | 5 ++- python/sglang/srt/openai_protocol.py | 3 +- python/sglang/srt/server.py | 9 +++-- python/sglang/srt/utils.py | 4 +- python/sglang/utils.py | 1 + test/srt/test_httpserver_reuse.py | 1 - 17 files changed, 68 insertions(+), 50 deletions(-) diff --git a/benchmark/latency_throughput/test_latency.py b/benchmark/latency_throughput/test_latency.py index 732ddf543..140b959ec 100644 --- a/benchmark/latency_throughput/test_latency.py +++ b/benchmark/latency_throughput/test_latency.py @@ -31,7 +31,7 @@ if __name__ == "__main__": url + "/generate", json={ "text": f"{a}, ", - #"input_ids": [[2] * 256] * 196, + # "input_ids": [[2] * 256] * 196, "sampling_params": { "temperature": 0, "max_new_tokens": max_new_tokens, diff --git a/python/sglang/backend/anthropic.py b/python/sglang/backend/anthropic.py index 330b2a412..d96d0f04f 100644 --- a/python/sglang/backend/anthropic.py +++ b/python/sglang/backend/anthropic.py @@ -74,4 +74,4 @@ class Anthropic(BaseBackend): **sampling_params.to_anthropic_kwargs(), ) as stream: for text in stream.text_stream: - yield text, {} \ No newline at end of file + yield text, {} diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index e33a9760b..cac2b714d 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -30,7 +30,11 @@ from sglang.lang.ir import ( SglVarScopeEnd, SglVideo, ) -from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback +from sglang.utils import ( + encode_image_base64, + encode_video_base64, + get_exception_traceback, +) def run_internal(state, program, func_args, func_kwargs, sync): diff --git a/python/sglang/srt/flush_cache.py b/python/sglang/srt/flush_cache.py index 3d695d44d..e962bb38b 100644 --- a/python/sglang/srt/flush_cache.py +++ b/python/sglang/srt/flush_cache.py @@ -13,4 +13,4 @@ if __name__ == "__main__": args = parser.parse_args() response = requests.get(args.url + "/flush_cache") - assert response.status_code == 200 \ No newline at end of file + assert response.status_code == 200 diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index db9655f64..a9f9ab2a1 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -32,7 +32,9 @@ class GenerateReqInput: def post_init(self): if self.text is None: - assert self.input_ids is not None, "Either text or input_ids should be provided" + assert ( + self.input_ids is not None + ), "Either text or input_ids should be provided" else: assert self.input_ids is None, "Either text or input_ids should be provided" diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index a4d84a6c7..e9b57d23c 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -24,7 +24,7 @@ from sglang.srt.managers.io_struct import ( FlushCacheReq, TokenizedGenerateReqInput, ) -from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason +from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req from sglang.srt.managers.router.model_runner import ModelRunner from sglang.srt.managers.router.radix_cache import RadixCache from sglang.srt.managers.router.scheduler import Scheduler @@ -37,7 +37,6 @@ from sglang.srt.utils import ( set_random_seed, ) - logger = logging.getLogger("model_rpc") vllm_default_logger.setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN) @@ -238,7 +237,9 @@ class ModelRpcServer: self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic) + throuhgput = self.num_generated_tokens / ( + time.time() - self.last_stats_tic + ) self.num_generated_tokens = 0 self.last_stats_tic = time.time() logger.info( @@ -401,12 +402,12 @@ class ModelRpcServer: f"#running_req: {running_req}. " f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%." ) - #logger.debug( + # logger.debug( # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. " - #) + # ) new_batch = Batch.init_new( can_run_list, @@ -440,11 +441,10 @@ class ModelRpcServer: # Only transfer the selected logprobs of the next token to CPU to reduce overhead. if last_logprobs is not None: - last_token_logprobs = ( - last_logprobs[ - torch.arange(len(batch.reqs), device=next_token_ids.device), - next_token_ids].tolist() - ) + last_token_logprobs = last_logprobs[ + torch.arange(len(batch.reqs), device=next_token_ids.device), + next_token_ids, + ].tolist() next_token_ids = next_token_ids.tolist() else: diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 48541932b..02cc74714 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -17,8 +17,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool -from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory - +from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model QUANTIZATION_CONFIG_MAPPING = { "awq": AWQConfig, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 282466ea2..f4cd4ad86 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -72,7 +72,8 @@ def get_pixel_values( image_hash = hash(image_data) if image_aspect_ratio == "pad": image = expand2square( - image, tuple(int(x * 255) for x in processor.image_processor.image_mean) + image, + tuple(int(x * 255) for x in processor.image_processor.image_mean), ) pixel_values = processor.image_processor(image)["pixel_values"][0] elif image_aspect_ratio == "anyres": @@ -208,10 +209,12 @@ class TokenizerManager: while True: await event.wait() - out = self.convert_logprob_style(state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, - obj.return_text_in_logprobs) + out = self.convert_logprob_style( + state.out_list[-1], + obj.return_logprob, + obj.top_logprobs_num, + obj.return_text_in_logprobs, + ) if self.server_args.log_requests and state.finished: logger.info(f"in={obj.text}, out={out}") @@ -275,10 +278,13 @@ class TokenizerManager: state = self.rid_to_state[rid] await state.event.wait() output_list.append( - self.convert_logprob_style(state.out_list[-1], - obj.return_logprob[i], - obj.top_logprobs_num[i], - obj.return_text_in_logprobs)) + self.convert_logprob_style( + state.out_list[-1], + obj.return_logprob[i], + obj.top_logprobs_num[i], + obj.return_text_in_logprobs, + ) + ) assert state.finished del self.rid_to_state[rid] @@ -311,7 +317,9 @@ class TokenizerManager: else: raise ValueError(f"Invalid object: {recv_obj}") - def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs): + def convert_logprob_style( + self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs + ): if return_logprob: ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens( ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs @@ -320,11 +328,15 @@ class TokenizerManager: ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ) if top_logprobs_num > 0: - ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs + ret["meta_info"]["prefill_top_logprobs"] = ( + self.detokenize_top_logprobs_tokens( + ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs + ) ) - ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + ret["meta_info"]["decode_top_logprobs"] = ( + self.detokenize_top_logprobs_tokens( + ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + ) ) return ret diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index d423541dd..abce92061 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -328,4 +328,4 @@ def monkey_path_clip_vision_embed_forward(): ) -EntryClass = LlavaLlamaForCausalLM \ No newline at end of file +EntryClass = LlavaLlamaForCausalLM diff --git a/python/sglang/srt/models/llava_mistral.py b/python/sglang/srt/models/llava_mistral.py index d14617cbd..2a42e2b5e 100644 --- a/python/sglang/srt/models/llava_mistral.py +++ b/python/sglang/srt/models/llava_mistral.py @@ -5,13 +5,9 @@ from typing import List, Optional import numpy as np import torch from torch import nn -from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig +from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.model_runner import InputMetadata @@ -21,6 +17,7 @@ from sglang.srt.mm_utils import ( unpad_image_shape, ) from sglang.srt.models.mistral import MistralForCausalLM +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlavaMistralForCausalLM(nn.Module): diff --git a/python/sglang/srt/models/llava_qwen.py b/python/sglang/srt/models/llava_qwen.py index c73ba7e95..2c60c5ef9 100644 --- a/python/sglang/srt/models/llava_qwen.py +++ b/python/sglang/srt/models/llava_qwen.py @@ -5,13 +5,9 @@ from typing import List, Optional import numpy as np import torch from torch import nn -from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config +from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.model_runner import InputMetadata @@ -21,6 +17,7 @@ from sglang.srt.mm_utils import ( unpad_image_shape, ) from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlavaQwenForCausalLM(nn.Module): diff --git a/python/sglang/srt/openai_api_adapter.py b/python/sglang/srt/openai_api_adapter.py index ced3eaa23..9d4b87acb 100644 --- a/python/sglang/srt/openai_api_adapter.py +++ b/python/sglang/srt/openai_api_adapter.py @@ -1,4 +1,5 @@ """Conversion between OpenAI APIs and native SRT APIs""" + import json import os @@ -31,9 +32,9 @@ from sglang.srt.openai_protocol import ( ) from sglang.srt.utils import jsonify_pydantic_model - chat_template_name = None + def load_chat_template_for_openai_api(chat_template_arg): global chat_template_name @@ -353,4 +354,4 @@ def to_openai_style_logprobs( if decode_top_logprobs is not None: append_top_logprobs(decode_top_logprobs) - return ret_logprobs \ No newline at end of file + return ret_logprobs diff --git a/python/sglang/srt/openai_protocol.py b/python/sglang/srt/openai_protocol.py index 0484de529..ac88b2dd5 100644 --- a/python/sglang/srt/openai_protocol.py +++ b/python/sglang/srt/openai_protocol.py @@ -1,4 +1,5 @@ """pydantic models for OpenAI API protocol""" + import time from typing import Dict, List, Optional, Union @@ -178,4 +179,4 @@ class ChatCompletionStreamResponse(BaseModel): object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionResponseStreamChoice] \ No newline at end of file + choices: List[ChatCompletionResponseStreamChoice] diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index d6eec0c90..f3a437ab0 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -30,15 +30,18 @@ from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.router.manager import start_router_process from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api_adapter import ( - v1_completions, v1_chat_completions, load_chat_template_for_openai_api) + load_chat_template_for_openai_api, + v1_chat_completions, + v1_completions, +) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( + API_KEY_HEADER_NAME, + APIKeyValidatorMiddleware, allocate_init_ports, assert_pkg_version, enable_show_time_cost, get_exception_traceback, - API_KEY_HEADER_NAME, - APIKeyValidatorMiddleware ) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 09849a547..9a1e6400d 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -275,7 +275,9 @@ def is_multimodal_model(model): if isinstance(model, ModelConfig): model_path = model.path.lower() - return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path + return ( + "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path + ) raise ValueError("unrecognized type") diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 51bb9b20b..365ec16f4 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -138,6 +138,7 @@ def encode_frame(frame): def encode_video_base64(video_path, num_frames=16): import cv2 + cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise IOError(f"Could not open video file:{video_path}") diff --git a/test/srt/test_httpserver_reuse.py b/test/srt/test_httpserver_reuse.py index 36804e4b7..ef866afc6 100644 --- a/test/srt/test_httpserver_reuse.py +++ b/test/srt/test_httpserver_reuse.py @@ -9,7 +9,6 @@ import argparse import requests - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="http://127.0.0.1")