Simplify prepare_extend_after_decode (#6987)
This commit is contained in:
@@ -56,6 +56,16 @@ def get_is_capture_mode():
|
||||
return is_capture_mode
|
||||
|
||||
|
||||
@contextmanager
def model_capture_mode():
    """Flag the model as being in CUDA-graph capture mode for the `with` body.

    Sets the module-level ``is_capture_mode`` flag on entry and clears it on
    exit. The reset is done in a ``finally`` clause so the flag cannot be left
    stuck at ``True`` when the body raises — the caller wraps graph capture in
    this context and capture failures are raised through it.
    """
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        # Unconditional reset: without try/finally, a failed capture would
        # leave the whole process permanently marked as "capturing".
        is_capture_mode = False
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
for sub in model._modules.values():
|
||||
if isinstance(sub, CustomOp):
|
||||
@@ -291,22 +301,13 @@ class CudaGraphRunner:
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with self.model_capture_mode():
|
||||
with model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
@contextmanager
def model_capture_mode(self):
    """Mark this runner's model as being in CUDA-graph capture mode.

    Toggles the module-level ``is_capture_mode`` flag around the ``with``
    body. ``try/finally`` guarantees the flag is cleared even when the body
    raises — the caller runs ``self.capture()`` inside this context and
    re-raises capture failures, which would otherwise leave the flag set.
    """
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        # Always clear the flag, success or failure.
        is_capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
|
||||
@@ -650,6 +651,8 @@ class CudaGraphRunner:
|
||||
topk=self.model_runner.server_args.speculative_eagle_topk,
|
||||
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||
seq_lens_sum=None,
|
||||
seq_lens_cpu=None,
|
||||
)
|
||||
|
||||
return spec_info
|
||||
|
||||
Reference in New Issue
Block a user