[Eagle] Remove the greedy branch and some redundant code (#4363)
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
else:
|
||||
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||
else:
|
||||
capture_bs = list(range(1, 33))
|
||||
# Since speculative decoding requires more cuda graph memory, we
|
||||
# capture less.
|
||||
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
|
||||
|
||||
if _is_hip:
|
||||
capture_bs += [i * 8 for i in range(21, 33)]
|
||||
@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||
# is very small. We add more values here to make sure we capture the maximum bs.
|
||||
capture_bs = list(
|
||||
sorted(
|
||||
set(
|
||||
capture_bs
|
||||
+ [model_runner.req_to_token_pool.size - 1]
|
||||
+ [model_runner.req_to_token_pool.size]
|
||||
)
|
||||
)
|
||||
)
|
||||
capture_bs += [model_runner.req_to_token_pool.size - 1] + [
|
||||
model_runner.req_to_token_pool.size
|
||||
]
|
||||
|
||||
capture_bs = list(sorted(set(capture_bs)))
|
||||
capture_bs = [
|
||||
bs
|
||||
for bs in capture_bs
|
||||
@@ -508,7 +505,9 @@ class CudaGraphRunner:
|
||||
self.raw_num_token = raw_num_token
|
||||
self.bs = bs
|
||||
|
||||
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
|
||||
def replay(
|
||||
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
|
||||
) -> LogitsProcessorOutput:
|
||||
if not skip_attn_backend_init:
|
||||
self.replay_prepare(forward_batch)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user