Fix request abortion (#6184)
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
part: [0, 1, 2, 3, 4, 5, 6, 7]
|
part: [0, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -69,7 +69,7 @@ jobs:
|
|||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 8
|
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9
|
||||||
|
|
||||||
unit-test-backend-2-gpu:
|
unit-test-backend-2-gpu:
|
||||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
from sglang.srt.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
AbortReq,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ConfigureLoggingReq,
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
@@ -539,6 +540,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
|||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/abort_request")
|
||||||
|
async def abort_request(obj: AbortReq, request: Request):
|
||||||
|
"""Abort a request."""
|
||||||
|
try:
|
||||||
|
_global_state.tokenizer_manager.abort_request(rid=obj.rid)
|
||||||
|
return Response(status_code=200)
|
||||||
|
except Exception as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/parse_function_call")
|
@app.post("/parse_function_call")
|
||||||
async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
|
async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from enum import Enum, auto
|
|
||||||
|
|
||||||
# Copyright 2023-2024 SGLang Team
|
# Copyright 2023-2024 SGLang Team
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
|||||||
It will be transformed from CPU scheduler to GPU model runner.
|
It will be transformed from CPU scheduler to GPU model runner.
|
||||||
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||||
|
|
||||||
|
TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
from enum import Enum, auto
|
||||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -134,9 +135,9 @@ class FINISH_LENGTH(BaseFinishReason):
|
|||||||
|
|
||||||
|
|
||||||
class FINISH_ABORT(BaseFinishReason):
|
class FINISH_ABORT(BaseFinishReason):
|
||||||
def __init__(self, message="Unknown error", status_code=None, err_type=None):
|
def __init__(self, message=None, status_code=None, err_type=None):
|
||||||
super().__init__(is_error=True)
|
super().__init__(is_error=True)
|
||||||
self.message = message
|
self.message = message or "Aborted"
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.err_type = err_type
|
self.err_type = err_type
|
||||||
|
|
||||||
@@ -441,11 +442,13 @@ class Req:
|
|||||||
# Check finish
|
# Check finish
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.finished_reason = None
|
self.finished_reason = None
|
||||||
|
# Whether this request has finished output
|
||||||
|
self.finished_output = None
|
||||||
# If we want to abort the request in the middle of the event loop, set this to true
|
# If we want to abort the request in the middle of the event loop, set this to true
|
||||||
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
|
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
|
||||||
self.to_abort = False
|
self.to_abort = False
|
||||||
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
|
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
|
||||||
self.to_abort_message: str = "Unknown error"
|
self.to_abort_message: str = None
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.eos_token_ids = eos_token_ids
|
self.eos_token_ids = eos_token_ids
|
||||||
|
|
||||||
@@ -546,8 +549,6 @@ class Req:
|
|||||||
self.bootstrap_room: Optional[int] = bootstrap_room
|
self.bootstrap_room: Optional[int] = bootstrap_room
|
||||||
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
||||||
|
|
||||||
# used for warmup because we don't have a pair yet when init
|
|
||||||
self.skip_kv_transfer: bool = False
|
|
||||||
# the start index of the sent kv cache
|
# the start index of the sent kv cache
|
||||||
# We want to send it chunk by chunk for chunked prefill.
|
# We want to send it chunk by chunk for chunked prefill.
|
||||||
# After every chunk forward, we do the following:
|
# After every chunk forward, we do the following:
|
||||||
@@ -555,15 +556,15 @@ class Req:
|
|||||||
# start_send_idx = len(req.fill_ids)
|
# start_send_idx = len(req.fill_ids)
|
||||||
self.start_send_idx: int = 0
|
self.start_send_idx: int = 0
|
||||||
|
|
||||||
self.metadata_buffer_index: int = -1
|
|
||||||
# The first output_id transferred from prefill instance.
|
|
||||||
self.transferred_output_id: Optional[int] = None
|
|
||||||
|
|
||||||
# For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
|
# For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
|
||||||
# This is because kv is not ready in `process_prefill_chunk`.
|
# This is because kv is not ready in `process_prefill_chunk`.
|
||||||
# We use `tmp_end_idx` to store the end index of the kv cache to send.
|
# We use `tmp_end_idx` to store the end index of the kv cache to send.
|
||||||
self.tmp_end_idx: int = -1
|
self.tmp_end_idx: int = -1
|
||||||
|
|
||||||
|
self.metadata_buffer_index: int = -1
|
||||||
|
# The first output_id transferred from prefill instance.
|
||||||
|
self.transferred_output_id: Optional[int] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def seqlen(self):
|
def seqlen(self):
|
||||||
return len(self.origin_input_ids) + len(self.output_ids)
|
return len(self.origin_input_ids) + len(self.output_ids)
|
||||||
@@ -697,13 +698,29 @@ class Req:
|
|||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
self.already_computed = 0
|
self.already_computed = 0
|
||||||
|
|
||||||
|
def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
|
||||||
|
token_indices = req_to_token_pool.req_to_token[
|
||||||
|
self.req_pool_idx, : self.seqlen - 1
|
||||||
|
]
|
||||||
|
self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
|
||||||
|
|
||||||
|
def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
|
||||||
|
token_indices = req_to_token_pool.req_to_token[
|
||||||
|
self.req_pool_idx, : self.seqlen - 1
|
||||||
|
]
|
||||||
|
token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
|
||||||
|
del self.kv_cache_cpu
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
f"Req(rid={self.rid}, "
|
f"Req(rid={self.rid}, "
|
||||||
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
|
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
|
||||||
|
f"{self.grammar=}, "
|
||||||
|
f"{self.sampling_params=})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Batch id
|
||||||
bid = 0
|
bid = 0
|
||||||
|
|
||||||
|
|
||||||
@@ -1447,7 +1464,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
i
|
i
|
||||||
for i in range(len(self.reqs))
|
for i in range(len(self.reqs))
|
||||||
if not self.reqs[i].finished()
|
if not self.reqs[i].finished()
|
||||||
and not self.reqs[i] in chunked_req_to_exclude
|
and self.reqs[i] not in chunked_req_to_exclude
|
||||||
]
|
]
|
||||||
|
|
||||||
if keep_indices is None or len(keep_indices) == 0:
|
if keep_indices is None or len(keep_indices) == 0:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
|||||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||||
ForwardBatch,
|
|
||||||
ForwardMode,
|
|
||||||
PPProxyTensors,
|
|
||||||
)
|
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
@@ -135,6 +130,7 @@ from sglang.srt.utils import (
|
|||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
crash_on_warnings,
|
crash_on_warnings,
|
||||||
|
disable_request_logging,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
@@ -907,19 +903,6 @@ class Scheduler(
|
|||||||
fake_input_ids = [1] * seq_length
|
fake_input_ids = [1] * seq_length
|
||||||
recv_req.input_ids = fake_input_ids
|
recv_req.input_ids = fake_input_ids
|
||||||
|
|
||||||
# Handle custom logit processor passed to the request
|
|
||||||
custom_logit_processor = recv_req.custom_logit_processor
|
|
||||||
if (
|
|
||||||
not self.server_args.enable_custom_logit_processor
|
|
||||||
and custom_logit_processor is not None
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"The SGLang server is not configured to enable custom logit processor."
|
|
||||||
"The custom logit processor passed in will be ignored."
|
|
||||||
"Please set --enable-custom-logits-processor to enable this feature."
|
|
||||||
)
|
|
||||||
custom_logit_processor = None
|
|
||||||
|
|
||||||
if recv_req.bootstrap_port is None:
|
if recv_req.bootstrap_port is None:
|
||||||
# Use default bootstrap port
|
# Use default bootstrap port
|
||||||
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
|
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
|
||||||
@@ -935,7 +918,7 @@ class Scheduler(
|
|||||||
stream=recv_req.stream,
|
stream=recv_req.stream,
|
||||||
lora_path=recv_req.lora_path,
|
lora_path=recv_req.lora_path,
|
||||||
input_embeds=recv_req.input_embeds,
|
input_embeds=recv_req.input_embeds,
|
||||||
custom_logit_processor=custom_logit_processor,
|
custom_logit_processor=recv_req.custom_logit_processor,
|
||||||
return_hidden_states=recv_req.return_hidden_states,
|
return_hidden_states=recv_req.return_hidden_states,
|
||||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||||
bootstrap_host=recv_req.bootstrap_host,
|
bootstrap_host=recv_req.bootstrap_host,
|
||||||
@@ -1246,9 +1229,7 @@ class Scheduler(
|
|||||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||||
f"{self.tree_cache.evictable_size()=}\n"
|
f"{self.tree_cache.evictable_size()=}\n"
|
||||||
)
|
)
|
||||||
warnings.warn(msg)
|
raise ValueError(msg)
|
||||||
if crash_on_warnings():
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
|
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
|
||||||
msg = (
|
msg = (
|
||||||
@@ -1256,9 +1237,7 @@ class Scheduler(
|
|||||||
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
||||||
f"total_size={self.req_to_token_pool.size}\n"
|
f"total_size={self.req_to_token_pool.size}\n"
|
||||||
)
|
)
|
||||||
warnings.warn(msg)
|
raise ValueError(msg)
|
||||||
if crash_on_warnings():
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.enable_metrics
|
self.enable_metrics
|
||||||
@@ -1774,24 +1753,27 @@ class Scheduler(
|
|||||||
if self.cur_batch is not None:
|
if self.cur_batch is not None:
|
||||||
if self.watchdog_last_forward_ct == self.forward_ct:
|
if self.watchdog_last_forward_ct == self.forward_ct:
|
||||||
if current > self.watchdog_last_time + self.watchdog_timeout:
|
if current > self.watchdog_last_time + self.watchdog_timeout:
|
||||||
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.watchdog_last_forward_ct = self.forward_ct
|
self.watchdog_last_forward_ct = self.forward_ct
|
||||||
self.watchdog_last_time = current
|
self.watchdog_last_time = current
|
||||||
time.sleep(self.watchdog_timeout // 2)
|
time.sleep(self.watchdog_timeout // 2)
|
||||||
|
|
||||||
# Print batch size and memory pool info to check whether there are de-sync issues.
|
if not disable_request_logging():
|
||||||
logger.error(
|
# Print batch size and memory pool info to check whether there are de-sync issues.
|
||||||
f"{self.cur_batch.batch_size()=}, "
|
logger.error(
|
||||||
f"{self.cur_batch.reqs=}, "
|
f"{self.cur_batch.batch_size()=}, "
|
||||||
f"{self.token_to_kv_pool_allocator.available_size()=}, "
|
f"{self.cur_batch.reqs=}, "
|
||||||
f"{self.tree_cache.evictable_size()=}, "
|
f"{self.token_to_kv_pool_allocator.available_size()=}, "
|
||||||
)
|
f"{self.tree_cache.evictable_size()=}, "
|
||||||
# Wait for some time so that the parent process can print the error.
|
)
|
||||||
|
|
||||||
pyspy_dump_schedulers()
|
pyspy_dump_schedulers()
|
||||||
|
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
||||||
print(file=sys.stderr, flush=True)
|
print(file=sys.stderr, flush=True)
|
||||||
print(file=sys.stdout, flush=True)
|
print(file=sys.stdout, flush=True)
|
||||||
|
|
||||||
|
# Wait for some time so that the parent process can print the error.
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
self.parent_process.send_signal(signal.SIGQUIT)
|
self.parent_process.send_signal(signal.SIGQUIT)
|
||||||
|
|
||||||
@@ -1923,25 +1905,30 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def abort_request(self, recv_req: AbortReq):
|
def abort_request(self, recv_req: AbortReq):
|
||||||
|
# TODO(lmzheng): abort the requests in the grammar queue.
|
||||||
|
|
||||||
# Delete requests in the waiting queue
|
# Delete requests in the waiting queue
|
||||||
to_del = []
|
to_del = []
|
||||||
for i, req in enumerate(self.waiting_queue):
|
for i, req in enumerate(self.waiting_queue):
|
||||||
if req.rid.startswith(recv_req.rid):
|
if req.rid.startswith(recv_req.rid):
|
||||||
to_del.append(i)
|
to_del.append(i)
|
||||||
break
|
|
||||||
|
|
||||||
# Sort in reverse order to avoid index issues when deleting
|
# Sort in reverse order to avoid index issues when deleting
|
||||||
for i in sorted(to_del, reverse=True):
|
for i in reversed(to_del):
|
||||||
req = self.waiting_queue.pop(i)
|
req = self.waiting_queue.pop(i)
|
||||||
|
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
|
||||||
logger.debug(f"Abort queued request. {req.rid=}")
|
logger.debug(f"Abort queued request. {req.rid=}")
|
||||||
return
|
|
||||||
|
|
||||||
# Delete requests in the running batch
|
# Delete requests in the running batch
|
||||||
for req in self.running_batch.reqs:
|
if self.cur_batch is self.running_batch or self.cur_batch is None:
|
||||||
|
reqs = self.running_batch.reqs
|
||||||
|
else:
|
||||||
|
reqs = self.running_batch.reqs + self.cur_batch.reqs
|
||||||
|
|
||||||
|
for req in reqs:
|
||||||
if req.rid.startswith(recv_req.rid) and not req.finished():
|
if req.rid.startswith(recv_req.rid) and not req.finished():
|
||||||
logger.debug(f"Abort running request. {req.rid=}")
|
logger.debug(f"Abort running request. {req.rid=}")
|
||||||
req.to_abort = True
|
req.to_abort = True
|
||||||
return
|
|
||||||
|
|
||||||
def _pause_engine(self) -> Tuple[List[Req], int]:
|
def _pause_engine(self) -> Tuple[List[Req], int]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ if TYPE_CHECKING:
|
|||||||
Scheduler,
|
Scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DEFAULT_FORCE_STREAM_INTERVAL = 50
|
||||||
|
|
||||||
|
|
||||||
class SchedulerOutputProcessorMixin:
|
class SchedulerOutputProcessorMixin:
|
||||||
"""
|
"""
|
||||||
@@ -512,19 +514,26 @@ class SchedulerOutputProcessorMixin:
|
|||||||
if self.model_config.is_multimodal_gen and req.to_abort:
|
if self.model_config.is_multimodal_gen and req.to_abort:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if req.finished():
|
||||||
req.finished()
|
if req.finished_output:
|
||||||
# If stream, follow the given stream_interval
|
# With the overlap schedule, a request will try to output twice and hit this line twice
|
||||||
or (req.stream and len(req.output_ids) % self.stream_interval == 0)
|
# because of the one additional delayed token. This "continue" prevented the dummy output.
|
||||||
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
|
continue
|
||||||
# TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
|
req.finished_output = True
|
||||||
# always increase one-by-one.
|
should_output = True
|
||||||
or (
|
else:
|
||||||
not req.stream
|
if req.stream:
|
||||||
and len(req.output_ids) % 50 == 0
|
stream_interval = (
|
||||||
and not self.model_config.is_multimodal_gen
|
req.sampling_params.stream_interval or self.stream_interval
|
||||||
)
|
)
|
||||||
):
|
should_output = len(req.output_ids) % stream_interval == 0
|
||||||
|
else:
|
||||||
|
should_output = (
|
||||||
|
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
|
||||||
|
and not self.model_config.is_multimodal_gen
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_output:
|
||||||
rids.append(req.rid)
|
rids.append(req.rid)
|
||||||
finished_reasons.append(
|
finished_reasons.append(
|
||||||
req.finished_reason.to_json() if req.finished_reason else None
|
req.finished_reason.to_json() if req.finished_reason else None
|
||||||
|
|||||||
@@ -288,6 +288,7 @@ class TokenizerManager:
|
|||||||
),
|
),
|
||||||
self._handle_batch_output,
|
self._handle_batch_output,
|
||||||
),
|
),
|
||||||
|
(AbortReq, self._handle_abort_req),
|
||||||
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
||||||
(
|
(
|
||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
@@ -341,13 +342,14 @@ class TokenizerManager:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For pd disaggregtion
|
||||||
self.disaggregation_mode = DisaggregationMode(
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
self.server_args.disaggregation_mode
|
self.server_args.disaggregation_mode
|
||||||
)
|
)
|
||||||
self.transfer_backend = TransferBackend(
|
self.transfer_backend = TransferBackend(
|
||||||
self.server_args.disaggregation_transfer_backend
|
self.server_args.disaggregation_transfer_backend
|
||||||
)
|
)
|
||||||
# for disaggregtion, start kv boostrap server on prefill
|
# Start kv boostrap server on prefill
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
# only start bootstrap server on prefill tm
|
# only start bootstrap server on prefill tm
|
||||||
kv_bootstrap_server_class = get_kv_class(
|
kv_bootstrap_server_class = get_kv_class(
|
||||||
@@ -482,6 +484,14 @@ class TokenizerManager:
|
|||||||
session_params = (
|
session_params = (
|
||||||
SessionParams(**obj.session_params) if obj.session_params else None
|
SessionParams(**obj.session_params) if obj.session_params else None
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
obj.custom_logit_processor
|
||||||
|
and not self.server_args.enable_custom_logit_processor
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"The server is not configured to enable custom logit processor. "
|
||||||
|
"Please set `--enable-custom-logits-processor` to enable this feature."
|
||||||
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params)
|
sampling_params = SamplingParams(**obj.sampling_params)
|
||||||
sampling_params.normalize(self.tokenizer)
|
sampling_params.normalize(self.tokenizer)
|
||||||
@@ -570,9 +580,9 @@ class TokenizerManager:
|
|||||||
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
||||||
created_time: Optional[float] = None,
|
created_time: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
||||||
self.rid_to_state[obj.rid] = state
|
self.rid_to_state[obj.rid] = state
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
|
||||||
|
|
||||||
async def _wait_one_response(
|
async def _wait_one_response(
|
||||||
self,
|
self,
|
||||||
@@ -587,10 +597,11 @@ class TokenizerManager:
|
|||||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if request is not None and await request.is_disconnected():
|
if request is not None and await request.is_disconnected():
|
||||||
|
# Abort the request for disconnected requests (non-streaming, waiting queue)
|
||||||
self.abort_request(obj.rid)
|
self.abort_request(obj.rid)
|
||||||
|
# Use exception to kill the whole call stack and asyncio task
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Request is disconnected from the client side. "
|
f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
|
||||||
f"Abort request {obj.rid}"
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -605,7 +616,6 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
del self.rid_to_state[obj.rid]
|
|
||||||
|
|
||||||
# Check if this was an abort/error created by scheduler
|
# Check if this was an abort/error created by scheduler
|
||||||
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
||||||
@@ -625,10 +635,11 @@ class TokenizerManager:
|
|||||||
yield out
|
yield out
|
||||||
else:
|
else:
|
||||||
if request is not None and await request.is_disconnected():
|
if request is not None and await request.is_disconnected():
|
||||||
|
# Abort the request for disconnected requests (non-streaming, running)
|
||||||
self.abort_request(obj.rid)
|
self.abort_request(obj.rid)
|
||||||
|
# Use exception to kill the whole call stack and asyncio task
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Request is disconnected from the client side. "
|
f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
|
||||||
f"Abort request {obj.rid}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_batch_request(
|
async def _handle_batch_request(
|
||||||
@@ -728,7 +739,6 @@ class TokenizerManager:
|
|||||||
def abort_request(self, rid: str):
|
def abort_request(self, rid: str):
|
||||||
if rid not in self.rid_to_state:
|
if rid not in self.rid_to_state:
|
||||||
return
|
return
|
||||||
del self.rid_to_state[rid]
|
|
||||||
req = AbortReq(rid)
|
req = AbortReq(rid)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
@@ -964,7 +974,7 @@ class TokenizerManager:
|
|||||||
def create_abort_task(self, obj: GenerateReqInput):
|
def create_abort_task(self, obj: GenerateReqInput):
|
||||||
# Abort the request if the client is disconnected.
|
# Abort the request if the client is disconnected.
|
||||||
async def abort_request():
|
async def abort_request():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(2)
|
||||||
if obj.is_single:
|
if obj.is_single:
|
||||||
self.abort_request(obj.rid)
|
self.abort_request(obj.rid)
|
||||||
else:
|
else:
|
||||||
@@ -1035,6 +1045,9 @@ class TokenizerManager:
|
|||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
state = self.rid_to_state.get(rid, None)
|
state = self.rid_to_state.get(rid, None)
|
||||||
if state is None:
|
if state is None:
|
||||||
|
logger.error(
|
||||||
|
f"Received output for {rid=} but the state was deleted in TokenizerManager."
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Build meta_info and return value
|
# Build meta_info and return value
|
||||||
@@ -1098,6 +1111,7 @@ class TokenizerManager:
|
|||||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||||
state.finished_time = time.time()
|
state.finished_time = time.time()
|
||||||
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||||
|
del self.rid_to_state[rid]
|
||||||
|
|
||||||
state.out_list.append(out_dict)
|
state.out_list.append(out_dict)
|
||||||
state.event.set()
|
state.event.set()
|
||||||
@@ -1246,6 +1260,9 @@ class TokenizerManager:
|
|||||||
# Schedule the task to run in the background without awaiting it
|
# Schedule the task to run in the background without awaiting it
|
||||||
asyncio.create_task(asyncio.to_thread(background_task))
|
asyncio.create_task(asyncio.to_thread(background_task))
|
||||||
|
|
||||||
|
def _handle_abort_req(self, recv_obj):
|
||||||
|
self.rid_to_state.pop(recv_obj.rid)
|
||||||
|
|
||||||
def _handle_open_session_req_output(self, recv_obj):
|
def _handle_open_session_req_output(self, recv_obj):
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
recv_obj.session_id if recv_obj.success else None
|
recv_obj.session_id if recv_obj.success else None
|
||||||
@@ -1325,3 +1342,15 @@ class _Communicator(Generic[T]):
|
|||||||
self._result_values.append(recv_obj)
|
self._result_values.append(recv_obj)
|
||||||
if len(self._result_values) == self._fan_out:
|
if len(self._result_values) == self._fan_out:
|
||||||
self._result_event.set()
|
self._result_event.set()
|
||||||
|
|
||||||
|
|
||||||
|
# Note: request abort handling logic
|
||||||
|
# We should handle all of the following cases correctly.
|
||||||
|
#
|
||||||
|
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
|
||||||
|
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
|
||||||
|
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
|
||||||
|
# | http | yes | running | background task | fast api | del in _handle_batch_output |
|
||||||
|
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
|
||||||
|
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
|
||||||
|
#
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class SamplingParams:
|
|||||||
spaces_between_special_tokens: bool = True,
|
spaces_between_special_tokens: bool = True,
|
||||||
no_stop_trim: bool = False,
|
no_stop_trim: bool = False,
|
||||||
custom_params: Optional[Dict[str, Any]] = None,
|
custom_params: Optional[Dict[str, Any]] = None,
|
||||||
|
stream_interval: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.max_new_tokens = max_new_tokens
|
self.max_new_tokens = max_new_tokens
|
||||||
self.stop_strs = stop
|
self.stop_strs = stop
|
||||||
@@ -75,6 +76,7 @@ class SamplingParams:
|
|||||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||||
self.no_stop_trim = no_stop_trim
|
self.no_stop_trim = no_stop_trim
|
||||||
self.custom_params = custom_params
|
self.custom_params = custom_params
|
||||||
|
self.stream_interval = stream_interval
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
if 0 <= self.temperature < _SAMPLING_EPS:
|
if 0 <= self.temperature < _SAMPLING_EPS:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class BenchArgs:
|
|||||||
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
|
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
|
||||||
)
|
)
|
||||||
image: bool = False
|
image: bool = False
|
||||||
|
many_images: bool = False
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -48,6 +49,7 @@ class BenchArgs:
|
|||||||
parser.add_argument("--return-logprob", action="store_true")
|
parser.add_argument("--return-logprob", action="store_true")
|
||||||
parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
|
parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
|
||||||
parser.add_argument("--image", action="store_true")
|
parser.add_argument("--image", action="store_true")
|
||||||
|
parser.add_argument("--many-images", action="store_true")
|
||||||
parser.add_argument("--stream", action="store_true")
|
parser.add_argument("--stream", action="store_true")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -62,6 +64,17 @@ def send_one_prompt(args):
|
|||||||
"Human: Describe this image in a very short sentence.\n\nAssistant:"
|
"Human: Describe this image in a very short sentence.\n\nAssistant:"
|
||||||
)
|
)
|
||||||
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
|
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
|
||||||
|
elif args.many_images:
|
||||||
|
args.prompt = (
|
||||||
|
"Human: I have one reference image and many images."
|
||||||
|
"Describe their relationship in a very short sentence.\n\nAssistant:"
|
||||||
|
)
|
||||||
|
image_data = [
|
||||||
|
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||||
|
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||||
|
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||||
|
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
image_data = None
|
image_data = None
|
||||||
|
|
||||||
@@ -74,9 +87,6 @@ def send_one_prompt(args):
|
|||||||
"Write in a format of json.\nAssistant:"
|
"Write in a format of json.\nAssistant:"
|
||||||
)
|
)
|
||||||
json_schema = "$$ANY$$"
|
json_schema = "$$ANY$$"
|
||||||
json_schema = (
|
|
||||||
'{"type": "object", "properties": {"population": {"type": "integer"}}}'
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
json_schema = None
|
json_schema = None
|
||||||
|
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ class TestBenchServing(CustomTestCase):
|
|||||||
f"### test_vlm_online_latency\n"
|
f"### test_vlm_online_latency\n"
|
||||||
f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
|
f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
|
||||||
)
|
)
|
||||||
self.assertLess(res["median_e2e_latency_ms"], 16000)
|
self.assertLess(res["median_e2e_latency_ms"], 16500)
|
||||||
if os.getenv("SGLANG_AMD_CI") == "1":
|
if os.getenv("SGLANG_AMD_CI") == "1":
|
||||||
self.assertLess(res["median_ttft_ms"], 150)
|
self.assertLess(res["median_ttft_ms"], 150)
|
||||||
# TODO: not set yet, need AMD machine
|
# TODO: not set yet, need AMD machine
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ Usage:
|
|||||||
python3 test/srt/test_flashmla.py
|
python3 test/srt/test_flashmla.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
@@ -61,7 +60,7 @@ class TestFlashMLAAttnBackend(unittest.TestCase):
|
|||||||
metrics = run_eval_few_shot_gsm8k(args)
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
self.assertGreater(metrics["accuracy"], 0.62)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
|
||||||
class TestFlashMLAAttnLatency(unittest.TestCase):
|
class TestFlashMLAAttnLatency(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user