Support page size > 1 + eagle (#4908)
This commit is contained in:
@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
if capture_bs is None:
|
||||
if server_args.speculative_algorithm is None:
|
||||
if server_args.disable_cuda_graph_padding:
|
||||
capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
|
||||
capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
|
||||
else:
|
||||
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
|
||||
else:
|
||||
# Since speculative decoding requires more cuda graph memory, we
|
||||
# capture less.
|
||||
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
|
||||
capture_bs = (
|
||||
list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
|
||||
)
|
||||
|
||||
if _is_hip:
|
||||
capture_bs += [i * 8 for i in range(21, 33)]
|
||||
capture_bs += list(range(160, 257, 8))
|
||||
|
||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||
|
||||
Reference in New Issue
Block a user