Correctly abort the failed grammar requests & Improve the handling of abort (#6803)
@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import tqdm

 from sglang.srt import two_batch_overlap
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
@@ -133,28 +132,27 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
+                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
             else:
                 capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
             # Since speculative decoding requires more cuda graph memory, we
             # capture less.
             capture_bs = (
-                list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+                list(range(1, 9))
+                + list(range(10, 33, 2))
+                + list(range(40, 64, 8))
+                + list(range(80, 161, 16))
             )

         gpu_mem = get_device_memory_capacity()
         if gpu_mem is not None and gpu_mem > 96 * 1024:
             capture_bs += list(range(160, 257, 8))
         if gpu_mem is not None and gpu_mem > 180 * 1000:
             capture_bs += list(range(256, 528, 16))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
-        # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
-            model_runner.req_to_token_pool.size
-        ]
+        capture_bs += [model_runner.req_to_token_pool.size]

     if server_args.enable_two_batch_overlap:
         capture_bs = [bs for bs in capture_bs if bs >= 2]
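As a quick, unofficial sketch (plain Python, list expressions copied from the hunk above), this is how the speculative-decoding capture list changes: the coarse 16-step spacing between 40 and 64 becomes an 8-step spacing, and 16-step spacing resumes from 80.

# Sketch only: reproduces the list expressions from the diff to show which
# batch sizes are gained and lost by the new speculative-decoding branch.
old_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
new_bs = (
    list(range(1, 9))
    + list(range(10, 33, 2))
    + list(range(40, 64, 8))      # 40, 48, 56
    + list(range(80, 161, 16))    # 80, 96, ..., 160
)
print(sorted(set(new_bs) - set(old_bs)))  # [48, 80, 96, 112, 128, 144, 160]
print(sorted(set(old_bs) - set(new_bs)))  # [72, 88, 104, 120, 136, 152]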
@@ -167,7 +165,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )
     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     capture_bs = list(sorted(set(capture_bs)))
-    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
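The only functional change in this hunk is the assertion message. A minimal illustration (hypothetical failure, not taken from the patch) of why the f"{capture_bs=}" suffix helps: the offending list is shown instead of a bare AssertionError.

# Hypothetical repro: an empty capture list now reports its own value.
capture_bs = []  # assumed bad input for illustration
assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
# AssertionError: capture_bs=[]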
@@ -918,7 +918,7 @@ class ModelRunner:

         if self.req_to_token_pool is None:
             self.req_to_token_pool = ReqToTokenPool(
-                size=max_num_reqs + 1,
+                size=max_num_reqs,
                 max_context_len=self.model_config.context_len + 4,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
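The pool sizing change appears to dovetail with the capture-size hunk above: with the pool sized exactly max_num_reqs (no extra +1 slot), only the pool size itself needs to be appended as an extra capture size, and the later <= filter keeps everything in range. A minimal sketch with made-up numbers (max_num_reqs and the initial list are hypothetical, not from the patch):

# Sketch with hypothetical values showing how the two changes line up.
max_num_reqs = 48                       # assumed example value
pool_size = max_num_reqs                # new behavior: size=max_num_reqs
capture_bs = [1, 2, 4, 8, 16, 32, 64]   # assumed example capture list
if max(capture_bs) > pool_size:
    capture_bs += [pool_size]           # new behavior: append only the pool size
capture_bs = sorted(set(bs for bs in capture_bs if bs <= pool_size))
print(capture_bs)                       # [1, 2, 4, 8, 16, 32, 48]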