Provide an argument to set the maximum batch size for cuda graph (#1809)
This commit is contained in:
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
|
|||||||
# This can prevent the server from being too conservative.
|
# This can prevent the server from being too conservative.
|
||||||
# Note that this only clips the estimation in the scheduler but does not change the stop
|
# Note that this only clips the estimation in the scheduler but does not change the stop
|
||||||
# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
|
# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
|
||||||
CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
|
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
|
||||||
|
os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SchedulePolicy:
|
class SchedulePolicy:
|
||||||
@@ -146,7 +148,7 @@ class PrefillAdder:
|
|||||||
[
|
[
|
||||||
min(
|
min(
|
||||||
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
||||||
CLIP_MAX_NEW_TOKENS,
|
CLIP_MAX_NEW_TOKENS_ESTIMATION,
|
||||||
)
|
)
|
||||||
* self.new_token_ratio
|
* self.new_token_ratio
|
||||||
for r in running_batch.reqs
|
for r in running_batch.reqs
|
||||||
@@ -186,7 +188,7 @@ class PrefillAdder:
|
|||||||
len(req.prefix_indices),
|
len(req.prefix_indices),
|
||||||
req.extend_input_len,
|
req.extend_input_len,
|
||||||
(
|
(
|
||||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
|
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
|
||||||
if not truncated
|
if not truncated
|
||||||
else 0
|
else 0
|
||||||
),
|
),
|
||||||
@@ -258,7 +260,7 @@ class PrefillAdder:
|
|||||||
self._prefill_one_req(
|
self._prefill_one_req(
|
||||||
0,
|
0,
|
||||||
req.extend_input_len,
|
req.extend_input_len,
|
||||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
|
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
@@ -276,7 +278,7 @@ class PrefillAdder:
|
|||||||
return self.add_one_req_ignore_eos(req)
|
return self.add_one_req_ignore_eos(req)
|
||||||
|
|
||||||
total_tokens = req.extend_input_len + min(
|
total_tokens = req.extend_input_len + min(
|
||||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
|
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
||||||
)
|
)
|
||||||
input_tokens = req.extend_input_len
|
input_tokens = req.extend_input_len
|
||||||
prefix_len = len(req.prefix_indices)
|
prefix_len = len(req.prefix_indices)
|
||||||
@@ -302,7 +304,10 @@ class PrefillAdder:
|
|||||||
self._prefill_one_req(
|
self._prefill_one_req(
|
||||||
prefix_len,
|
prefix_len,
|
||||||
input_tokens,
|
input_tokens,
|
||||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
|
min(
|
||||||
|
req.sampling_params.max_new_tokens,
|
||||||
|
CLIP_MAX_NEW_TOKENS_ESTIMATION,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
|
|||||||
@@ -113,12 +113,15 @@ class CudaGraphRunner:
|
|||||||
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
if self.model_runner.server_args.disable_cuda_graph_padding:
|
if model_runner.server_args.disable_cuda_graph_padding:
|
||||||
self.capture_bs = list(range(1, 32)) + [64, 128]
|
self.capture_bs = list(range(1, 32)) + [64, 128]
|
||||||
else:
|
else:
|
||||||
self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
|
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||||
self.capture_bs = [
|
self.capture_bs = [
|
||||||
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
|
bs
|
||||||
|
for bs in self.capture_bs
|
||||||
|
if bs <= model_runner.req_to_token_pool.size
|
||||||
|
and bs <= model_runner.server_args.max_cuda_graph_bs
|
||||||
]
|
]
|
||||||
self.compile_bs = (
|
self.compile_bs = (
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ class ServerArgs:
|
|||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
max_torch_compile_bs: int = 32
|
max_torch_compile_bs: int = 32
|
||||||
|
max_cuda_graph_bs: int = 160
|
||||||
torchao_config: str = ""
|
torchao_config: str = ""
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
triton_attention_reduce_in_fp32: bool = False
|
triton_attention_reduce_in_fp32: bool = False
|
||||||
@@ -624,6 +625,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.max_torch_compile_bs,
|
default=ServerArgs.max_torch_compile_bs,
|
||||||
help="Set the maximum batch size when using torch compile.",
|
help="Set the maximum batch size when using torch compile.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-cuda-graph-bs",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.max_cuda_graph_bs,
|
||||||
|
help="Set the maximum batch size for cuda graph.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torchao-config",
|
"--torchao-config",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
|
|||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
api_key=cls.api_key,
|
api_key=cls.api_key,
|
||||||
other_args=("--max-total-token", "1024", "--context-len", "8192"),
|
other_args=("--max-total-token", "1024", "--context-len", "8192"),
|
||||||
env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
|
env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
|
||||||
return_stdout_stderr=(cls.stdout, cls.stderr),
|
return_stdout_stderr=(cls.stdout, cls.stderr),
|
||||||
)
|
)
|
||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|||||||
Reference in New Issue
Block a user