Format code (#441)
@@ -74,4 +74,4 @@ class Anthropic(BaseBackend):
             **sampling_params.to_anthropic_kwargs(),
         ) as stream:
             for text in stream.text_stream:
-                yield text, {}
+                yield text, {}

@@ -30,7 +30,11 @@ from sglang.lang.ir import (
     SglVarScopeEnd,
     SglVideo,
 )
-from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
+from sglang.utils import (
+    encode_image_base64,
+    encode_video_base64,
+    get_exception_traceback,
+)
 
 
 def run_internal(state, program, func_args, func_kwargs, sync):

@@ -13,4 +13,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     response = requests.get(args.url + "/flush_cache")
-    assert response.status_code == 200
+    assert response.status_code == 200

@@ -32,7 +32,9 @@ class GenerateReqInput:
     def post_init(self):
 
         if self.text is None:
-            assert self.input_ids is not None, "Either text or input_ids should be provided"
+            assert (
+                self.input_ids is not None
+            ), "Either text or input_ids should be provided"
         else:
             assert self.input_ids is None, "Either text or input_ids should be provided"

@@ -24,7 +24,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
+from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler

@@ -37,7 +37,6 @@ from sglang.srt.utils import (
     set_random_seed,
 )
 
-
 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)

@@ -238,7 +237,9 @@ class ModelRpcServer:
             self.token_to_kv_pool.available_size()
             + self.tree_cache.evictable_size()
         )
-        throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        throuhgput = self.num_generated_tokens / (
+            time.time() - self.last_stats_tic
+        )
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
         logger.info(

@@ -401,12 +402,12 @@ class ModelRpcServer:
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            #logger.debug(
+            # logger.debug(
             #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
             #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
             #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
             #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            #)
+            # )
 
             new_batch = Batch.init_new(
                 can_run_list,

@@ -440,11 +441,10 @@ class ModelRpcServer:
 
         # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
         if last_logprobs is not None:
-            last_token_logprobs = (
-                last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
-                    next_token_ids].tolist()
-            )
+            last_token_logprobs = last_logprobs[
+                torch.arange(len(batch.reqs), device=next_token_ids.device),
+                next_token_ids,
+            ].tolist()
 
             next_token_ids = next_token_ids.tolist()
         else:

@@ -17,8 +17,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
-
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
 
 QUANTIZATION_CONFIG_MAPPING = {
     "awq": AWQConfig,

@@ -72,7 +72,8 @@ def get_pixel_values(
     image_hash = hash(image_data)
     if image_aspect_ratio == "pad":
         image = expand2square(
-            image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+            image,
+            tuple(int(x * 255) for x in processor.image_processor.image_mean),
         )
         pixel_values = processor.image_processor(image)["pixel_values"][0]
     elif image_aspect_ratio == "anyres":

@@ -208,10 +209,12 @@ class TokenizerManager:
 
             while True:
                 await event.wait()
-                out = self.convert_logprob_style(state.out_list[-1],
-                                                 obj.return_logprob,
-                                                 obj.top_logprobs_num,
-                                                 obj.return_text_in_logprobs)
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
 
                 if self.server_args.log_requests and state.finished:
                     logger.info(f"in={obj.text}, out={out}")

@@ -275,10 +278,13 @@ class TokenizerManager:
             state = self.rid_to_state[rid]
             await state.event.wait()
             output_list.append(
-                self.convert_logprob_style(state.out_list[-1],
-                                           obj.return_logprob[i],
-                                           obj.top_logprobs_num[i],
-                                           obj.return_text_in_logprobs))
+                self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob[i],
+                    obj.top_logprobs_num[i],
+                    obj.return_text_in_logprobs,
+                )
+            )
             assert state.finished
             del self.rid_to_state[rid]

@@ -311,7 +317,9 @@ class TokenizerManager:
         else:
             raise ValueError(f"Invalid object: {recv_obj}")
 
-    def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs

@@ -320,11 +328,15 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                    )
                 )
-                ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
                 )
         return ret

@@ -328,4 +328,4 @@ def monkey_path_clip_vision_embed_forward():
     )
 
 
-EntryClass = LlavaLlamaForCausalLM
+EntryClass = LlavaLlamaForCausalLM

@@ -5,13 +5,9 @@ from typing import List, Optional
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig
+from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata

@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.mistral import MistralForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlavaMistralForCausalLM(nn.Module):

@@ -5,13 +5,9 @@ from typing import List, Optional
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config
+from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata

@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlavaQwenForCausalLM(nn.Module):

@@ -1,4 +1,5 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
+
 import json
 import os

@@ -31,9 +32,9 @@ from sglang.srt.openai_protocol import (
 )
 from sglang.srt.utils import jsonify_pydantic_model
 
-
 chat_template_name = None
 
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name

@@ -353,4 +354,4 @@ def to_openai_style_logprobs(
     if decode_top_logprobs is not None:
         append_top_logprobs(decode_top_logprobs)
 
-    return ret_logprobs
+    return ret_logprobs

@@ -1,4 +1,5 @@
 """pydantic models for OpenAI API protocol"""
+
 import time
 from typing import Dict, List, Optional, Union

@@ -178,4 +179,4 @@ class ChatCompletionStreamResponse(BaseModel):
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseStreamChoice]
+    choices: List[ChatCompletionResponseStreamChoice]

@@ -30,15 +30,18 @@ from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
-    v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     get_exception_traceback,
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware
 )
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -275,7 +275,9 @@ def is_multimodal_model(model):
 
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        return (
+            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        )
 
     raise ValueError("unrecognized type")

@@ -138,6 +138,7 @@ def encode_frame(frame):
 
 def encode_video_base64(video_path, num_frames=16):
     import cv2
+
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise IOError(f"Could not open video file:{video_path}")