Fix request abortion (#6184)
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import (
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -539,6 +540,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
     return Response(status_code=200)


+@app.post("/abort_request")
+async def abort_request(obj: AbortReq, request: Request):
+    """Abort a request."""
+    try:
+        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        return Response(status_code=200)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.post("/parse_function_call")
 async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
     """
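The new endpoint forwards the abort to the TokenizerManager and returns 200 unless an exception escapes; unknown rids are ignored inside `abort_request`, so the call is effectively idempotent. A minimal client-side sketch (server address and rid are hypothetical):

```python
import requests

# Hypothetical: an SGLang server on localhost:30000 and a rid captured from an
# earlier /generate call. AbortReq is sent as plain JSON with a "rid" field.
resp = requests.post(
    "http://localhost:30000/abort_request",
    json={"rid": "my-request-id"},
)
print(resp.status_code)  # 200 even if the rid is already gone
```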
@@ -1,8 +1,5 @@
 from __future__ import annotations

-import hashlib
-from enum import Enum, auto
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
   It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """

 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -134,9 +135,9 @@ class FINISH_LENGTH(BaseFinishReason):


 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error", status_code=None, err_type=None):
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type

@@ -441,11 +442,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
+        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids

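Per the comments above, an abort arriving mid-iteration only sets `to_abort`; the finish reason is attached at the end of the event loop, where `FINISH_ABORT` now falls back to "Aborted". A self-contained sketch of that contract (the placement of the final check is an assumption, not part of this diff):

```python
# Toy stand-in for Req showing the two-phase abort: flag now, finish later.
class ToyReq:
    def __init__(self):
        self.to_abort = False
        self.to_abort_message = None
        self.finished_reason = None

req = ToyReq()
req.to_abort = True  # set mid-loop; finished_reason deliberately stays None

# End of the event loop iteration (hypothetical consumer of the flag):
if req.to_abort and req.finished_reason is None:
    req.finished_reason = req.to_abort_message or "Aborted"

assert req.finished_reason == "Aborted"
```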
@@ -546,8 +549,6 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,15 +556,15 @@ class Req:
         #   start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0

-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1

+        self.metadata_buffer_index: int = -1
+        # The first output_id transferred from prefill instance.
+        self.transferred_output_id: Optional[int] = None
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -697,13 +698,29 @@ class Req:
         self.req_pool_idx = None
         self.already_computed = 0

+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )


 # Batch id
 bid = 0

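`offload_kv_cache` and `load_kv_cache` form a save/restore pair around a request's KV slots. A self-contained analogue with NumPy standing in for the pools (the real code copies KV tensors between GPU and CPU via `get_cpu_copy`/`load_cpu_copy`):

```python
import numpy as np

pool = np.arange(16.0)               # stand-in for the token-to-KV pool
token_indices = np.array([2, 3, 5])  # slots owned by this request

kv_cache_cpu = pool[token_indices].copy()  # offload: keep a CPU-side copy
pool[token_indices] = 0.0                  # slots can be reused meanwhile
pool[token_indices] = kv_cache_cpu         # load: restore the saved values
del kv_cache_cpu                           # copy dropped, as in load_kv_cache
```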
@@ -1447,7 +1464,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             i
             for i in range(len(self.reqs))
             if not self.reqs[i].finished()
-            and not self.reqs[i] in chunked_req_to_exclude
+            and self.reqs[i] not in chunked_req_to_exclude
         ]

         if keep_indices is None or len(keep_indices) == 0:

@@ -20,7 +20,6 @@ import signal
 import sys
 import threading
 import time
-import warnings
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
@@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -135,6 +130,7 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
@@ -907,19 +903,6 @@ class Scheduler(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

-        # Handle custom logit processor passed to the request
-        custom_logit_processor = recv_req.custom_logit_processor
-        if (
-            not self.server_args.enable_custom_logit_processor
-            and custom_logit_processor is not None
-        ):
-            logger.warning(
-                "The SGLang server is not configured to enable custom logit processor."
-                "The custom logit processor passed in will be ignored."
-                "Please set --enable-custom-logits-processor to enable this feature."
-            )
-            custom_logit_processor = None
-
         if recv_req.bootstrap_port is None:
             # Use default bootstrap port
             recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +918,7 @@ class Scheduler(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
-            custom_logit_processor=custom_logit_processor,
+            custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
@@ -1246,9 +1229,7 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
@@ -1256,9 +1237,7 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
             self.enable_metrics
@@ -1774,24 +1753,27 @@ class Scheduler(
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
             else:
                 self.watchdog_last_forward_ct = self.forward_ct
                 self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
-            f"{self.cur_batch.reqs=}, "
-            f"{self.token_to_kv_pool_allocator.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
-        )
-        # Wait for some time so that the parent process can print the error.
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
+                f"{self.cur_batch.reqs=}, "
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )
+
+        pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+        print(file=sys.stderr, flush=True)
+        print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)

@@ -1923,25 +1905,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in sorted(to_del, reverse=True):
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
             self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-        for req in self.running_batch.reqs:
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()

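Two things changed in the waiting-queue pass: the early `break` is gone, so every queued request whose rid matches the prefix is collected, and the collected indices are popped in reverse so earlier indices stay valid. A self-contained illustration of the reverse pop:

```python
waiting_queue = ["req-a", "req-b", "req-c", "req-d"]
to_del = [1, 3]  # indices collected in increasing order, as in the loop above

for i in reversed(to_del):
    waiting_queue.pop(i)  # pops index 3 ("req-d") first, so index 1 is still "req-b"

assert waiting_queue == ["req-a", "req-c"]
```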
@@ -15,6 +15,8 @@ if TYPE_CHECKING:
         Scheduler,
     )

+DEFAULT_FORCE_STREAM_INTERVAL = 50
+

 class SchedulerOutputProcessorMixin:
     """
@@ -512,19 +514,26 @@ class SchedulerOutputProcessorMixin:
                 if self.model_config.is_multimodal_gen and req.to_abort:
                     continue

-                if (
-                    req.finished()
-                    # If stream, follow the given stream_interval
-                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                    # always increase one-by-one.
-                    or (
-                        not req.stream
-                        and len(req.output_ids) % 50 == 0
-                        and not self.model_config.is_multimodal_gen
-                    )
-                ):
+                if req.finished():
+                    if req.finished_output:
+                        # With the overlap schedule, a request will try to output twice and hit this line twice
+                        # because of the one additional delayed token. This "continue" prevented the dummy output.
+                        continue
+                    req.finished_output = True
+                    should_output = True
+                else:
+                    if req.stream:
+                        stream_interval = (
+                            req.sampling_params.stream_interval or self.stream_interval
+                        )
+                        should_output = len(req.output_ids) % stream_interval == 0
+                    else:
+                        should_output = (
+                            len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
+                            and not self.model_config.is_multimodal_gen
+                        )
+
+                if should_output:
                     rids.append(req.rid)
                     finished_reasons.append(
                         req.finished_reason.to_json() if req.finished_reason else None
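The output gate is now split in two: finished requests emit exactly once (guarded by `finished_output`), and unfinished ones follow either the per-request `stream_interval` or the forced interval. A self-contained restatement of the unfinished branch (the function and its arguments are mine, not from the diff):

```python
DEFAULT_FORCE_STREAM_INTERVAL = 50

def should_output(num_output_ids: int, stream: bool, per_req_interval,
                  server_interval: int, is_multimodal_gen: bool = False) -> bool:
    # Mirrors the else-branch above for requests that have not finished yet.
    if stream:
        interval = per_req_interval or server_interval
        return num_output_ids % interval == 0
    return (num_output_ids % DEFAULT_FORCE_STREAM_INTERVAL == 0
            and not is_multimodal_gen)

assert should_output(8, stream=True, per_req_interval=4, server_interval=16)
assert not should_output(49, stream=False, per_req_interval=None, server_interval=16)
```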
@@ -288,6 +288,7 @@ class TokenizerManager:
                 ),
                 self._handle_batch_output,
             ),
+            (AbortReq, self._handle_abort_req),
             (OpenSessionReqOutput, self._handle_open_session_req_output),
             (
                 UpdateWeightFromDiskReqOutput,
@@ -341,13 +342,14 @@ class TokenizerManager:
             ]
         )

         # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        # for disaggregtion, start kv boostrap server on prefill
+        # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -482,6 +484,14 @@ class TokenizerManager:
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )

         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
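With the check moved into TokenizerManager, a request carrying a custom logit processor against a server started without the flag now fails loudly at the entrypoint instead of being silently dropped in the scheduler. A hypothetical client-side view (the payload is a placeholder, not a working serialized processor):

```python
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hello",
        "custom_logit_processor": "<serialized processor>",  # placeholder
        "sampling_params": {"max_new_tokens": 8},
    },
)
# Expect an error response mentioning --enable-custom-logits-processor
print(resp.status_code, resp.text)
```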
@@ -570,9 +580,9 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        self.send_to_scheduler.send_pyobj(tokenized_obj)

     async def _wait_one_response(
         self,
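Note the reorder in this hunk: the `ReqState` is registered in `rid_to_state` before the tokenized request is handed to the scheduler. With the old order, a scheduler reply (or an abort racing the send) could arrive while `rid_to_state` had no entry yet and would hit the "state was deleted" path in `_handle_batch_output`; registering first presumably closes that race.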
@@ -587,10 +597,11 @@ class TokenizerManager:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue

@@ -605,7 +616,6 @@ class TokenizerManager:
             else:
                 msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                 logger.info(msg)
-            del self.rid_to_state[obj.rid]

             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -625,10 +635,11 @@ class TokenizerManager:
                     yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )

     async def _handle_batch_request(
@@ -728,7 +739,6 @@ class TokenizerManager:
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

@@ -964,7 +974,7 @@ class TokenizerManager:
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(1)
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -1035,6 +1045,9 @@ class TokenizerManager:
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue

             # Build meta_info and return value
@@ -1098,6 +1111,7 @@ class TokenizerManager:
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]

             state.out_list.append(out_dict)
             state.event.set()
@@ -1246,6 +1260,9 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1325,3 +1342,15 @@ class _Communicator(Generic[T]):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
+# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
+# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
+#
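A sketch of the first table row (http + streaming): the client just drops the connection, and the background task from `create_abort_task` issues the abort after its 2-second sleep. Server address and payload are hypothetical:

```python
import requests

# Start a long streaming generation, read one chunk, then close the socket.
with requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Count to one million, slowly.",
        "stream": True,
        "sampling_params": {"max_new_tokens": 4096},
    },
    stream=True,
) as resp:
    next(resp.iter_lines())  # leaving the `with` block closes the connection

# The abort task notices the disconnect and calls TokenizerManager.abort_request;
# the scheduler answers with AbortReq, so rid_to_state is cleaned up in
# _handle_abort_req (waiting queue) or _handle_batch_output (running).
```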
@@ -50,6 +50,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
+        stream_interval: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -75,6 +76,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
+        self.stream_interval = stream_interval

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
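This is the other half of the `stream_interval` plumbing: the per-request value set here is what `req.sampling_params.stream_interval` reads in the output processor. A hypothetical streaming request overriding the server-wide interval:

```python
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Write a short story.",
        "stream": True,
        "sampling_params": {"max_new_tokens": 256, "stream_interval": 2},
    },
    stream=True,
)
for line in resp.iter_lines():
    print(line)
```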
@@ -27,6 +27,7 @@ class BenchArgs:
         "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
     )
     image: bool = False
+    many_images: bool = False
     stream: bool = False

     @staticmethod
@@ -48,6 +49,7 @@ class BenchArgs:
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
         parser.add_argument("--image", action="store_true")
+        parser.add_argument("--many-images", action="store_true")
         parser.add_argument("--stream", action="store_true")

     @classmethod
@@ -62,6 +64,17 @@ def send_one_prompt(args):
             "Human: Describe this image in a very short sentence.\n\nAssistant:"
         )
         image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
+    elif args.many_images:
+        args.prompt = (
+            "Human: I have one reference image and many images."
+            "Describe their relationship in a very short sentence.\n\nAssistant:"
+        )
+        image_data = [
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+        ]
     else:
         image_data = None

@@ -74,9 +87,6 @@ def send_one_prompt(args):
             "Write in a format of json.\nAssistant:"
         )
-        json_schema = (
-            '{"type": "object", "properties": {"population": {"type": "integer"}}}'
-        )
+        json_schema = "$$ANY$$"
     else:
         json_schema = None