Simplify prepare_extend_after_decode (#6987)

This commit is contained in:
Lianmin Zheng
2025-06-09 16:39:21 -07:00
committed by GitHub
parent a968c888c0
commit dc0705a504
9 changed files with 140 additions and 176 deletions

View File

@@ -56,6 +56,16 @@ def get_is_capture_mode():
return is_capture_mode
@contextmanager
def model_capture_mode():
    """Context manager that flags CUDA-graph capture as in progress.

    Sets the module-level ``is_capture_mode`` flag to True for the duration
    of the ``with`` body and resets it to False on normal exit.

    NOTE(review): there is no try/finally, so if the body raises, the flag
    is left set to True — confirm that is acceptable for callers.
    """
    global is_capture_mode
    is_capture_mode = True
    yield
    is_capture_mode = False
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
@@ -291,22 +301,13 @@ class CudaGraphRunner:
# Capture
try:
with self.model_capture_mode():
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)
@contextmanager
def model_capture_mode(self):
    """Context manager that flags CUDA-graph capture as in progress.

    Toggles the module-level ``is_capture_mode`` flag (not instance state;
    ``self`` is unused) to True around the ``with`` body and back to False
    on normal exit.

    NOTE(review): no try/finally — an exception in the body leaves the
    flag set to True.
    """
    global is_capture_mode
    is_capture_mode = True
    yield
    is_capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
@@ -650,6 +651,8 @@ class CudaGraphRunner:
topk=self.model_runner.server_args.speculative_eagle_topk,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
capture_hidden_mode=CaptureHiddenMode.FULL,
seq_lens_sum=None,
seq_lens_cpu=None,
)
return spec_info