Format code (#441)

This commit is contained in:
Liangsheng Yin
2024-05-14 22:40:46 +08:00
committed by GitHub
parent 664287b2a7
commit 690d162d97
17 changed files with 68 additions and 50 deletions

View File

@@ -74,4 +74,4 @@ class Anthropic(BaseBackend):
**sampling_params.to_anthropic_kwargs(),
) as stream:
for text in stream.text_stream:
yield text, {}
yield text, {}

View File

@@ -30,7 +30,11 @@ from sglang.lang.ir import (
SglVarScopeEnd,
SglVideo,
)
from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
from sglang.utils import (
encode_image_base64,
encode_video_base64,
get_exception_traceback,
)
def run_internal(state, program, func_args, func_kwargs, sync):

View File

@@ -13,4 +13,4 @@ if __name__ == "__main__":
args = parser.parse_args()
response = requests.get(args.url + "/flush_cache")
assert response.status_code == 200
assert response.status_code == 200

View File

@@ -32,7 +32,9 @@ class GenerateReqInput:
def post_init(self):
if self.text is None:
assert self.input_ids is not None, "Either text or input_ids should be provided"
assert (
self.input_ids is not None
), "Either text or input_ids should be provided"
else:
assert self.input_ids is None, "Either text or input_ids should be provided"

View File

@@ -24,7 +24,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
@@ -37,7 +37,6 @@ from sglang.srt.utils import (
set_random_seed,
)
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
@@ -238,7 +237,9 @@ class ModelRpcServer:
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
throuhgput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
@@ -401,12 +402,12 @@ class ModelRpcServer:
f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
)
#logger.debug(
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
#)
# )
new_batch = Batch.init_new(
can_run_list,
@@ -440,11 +441,10 @@ class ModelRpcServer:
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None:
last_token_logprobs = (
last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids].tolist()
)
last_token_logprobs = last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
else:

View File

@@ -17,8 +17,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
QUANTIZATION_CONFIG_MAPPING = {
"awq": AWQConfig,

View File

@@ -72,7 +72,8 @@ def get_pixel_values(
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
image,
tuple(int(x * 255) for x in processor.image_processor.image_mean),
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
@@ -208,10 +209,12 @@ class TokenizerManager:
while True:
await event.wait()
out = self.convert_logprob_style(state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs)
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
if self.server_args.log_requests and state.finished:
logger.info(f"in={obj.text}, out={out}")
@@ -275,10 +278,13 @@ class TokenizerManager:
state = self.rid_to_state[rid]
await state.event.wait()
output_list.append(
self.convert_logprob_style(state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs))
self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs,
)
)
assert state.finished
del self.rid_to_state[rid]
@@ -311,7 +317,9 @@ class TokenizerManager:
else:
raise ValueError(f"Invalid object: {recv_obj}")
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
):
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
@@ -320,11 +328,15 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["prefill_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
)
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["decode_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
)
return ret

View File

@@ -328,4 +328,4 @@ def monkey_path_clip_vision_embed_forward():
)
EntryClass = LlavaLlamaForCausalLM
EntryClass = LlavaLlamaForCausalLM

View File

@@ -5,13 +5,9 @@ from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaMistralForCausalLM(nn.Module):

View File

@@ -5,13 +5,9 @@ from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaQwenForCausalLM(nn.Module):

View File

@@ -1,4 +1,5 @@
"""Conversion between OpenAI APIs and native SRT APIs"""
import json
import os
@@ -31,9 +32,9 @@ from sglang.srt.openai_protocol import (
)
from sglang.srt.utils import jsonify_pydantic_model
chat_template_name = None
def load_chat_template_for_openai_api(chat_template_arg):
global chat_template_name
@@ -353,4 +354,4 @@ def to_openai_style_logprobs(
if decode_top_logprobs is not None:
append_top_logprobs(decode_top_logprobs)
return ret_logprobs
return ret_logprobs

View File

@@ -1,4 +1,5 @@
"""pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
@@ -178,4 +179,4 @@ class ChatCompletionStreamResponse(BaseModel):
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
choices: List[ChatCompletionResponseStreamChoice]

View File

@@ -30,15 +30,18 @@ from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
load_chat_template_for_openai_api,
v1_chat_completions,
v1_completions,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware,
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
get_exception_traceback,
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware
)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

View File

@@ -275,7 +275,9 @@ def is_multimodal_model(model):
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
return (
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
)
raise ValueError("unrecognized type")

View File

@@ -138,6 +138,7 @@ def encode_frame(frame):
def encode_video_base64(video_path, num_frames=16):
import cv2
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Could not open video file:{video_path}")