diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index d0d89846..b4d4148d 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -438,3 +438,17 @@
 # patch Qwen3_5GatedDeltaNet._forward_core to use triton ops like `fused_recurrent_gated_delta_rule`.
 # Future Plan:
 # Remove this patch when all ops in _forward_core support both Qwen3_5 and Qwen3Next.
+#
+# ** 20. File: worker/patch_cudagraph.py **
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.v1.cudagraph_dispatcher.CudagraphDispatcher._create_padded_batch_descriptor`
+# Why:
+# vLLM's FULL cudagraph mode raises an error; we apply this patch to avoid it,
+# so that FULL mode can now be enabled.
+# How:
+# Dynamically replace the `_create_padded_batch_descriptor` function at runtime,
+# and change the condition of the `if` statement.
+# Related PR (if no, explain why):
+# https://github.com/vllm-project/vllm/pull/34880
+# Future Plan:
+# Remove this patch when vLLM merges the PR.
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 2493ecb2..1664c84a 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -43,3 +43,4 @@ import vllm_ascend.patch.worker.patch_routed_experts_capturer  # noqa
 import vllm_ascend.patch.worker.patch_npugraph_ex_triton  # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25  # noqa
 import vllm_ascend.patch.worker.patch_draft_quarot  # noqa
+import vllm_ascend.patch.worker.patch_cudagraph  # noqa
diff --git a/vllm_ascend/patch/worker/patch_cudagraph.py b/vllm_ascend/patch/worker/patch_cudagraph.py
new file mode 100644
index 00000000..bff74fdd
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_cudagraph.py
@@ -0,0 +1,38 @@
+from vllm.config import CUDAGraphMode
+from vllm.forward_context import BatchDescriptor
+from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
+
+
+def _create_padded_batch_descriptor(
+    self,
+    num_tokens: int,
+    uniform_decode: bool,
+    has_lora: bool,
+    num_active_loras: int = 0,
+) -> BatchDescriptor:
+    max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+    uniform_decode_query_len = self.uniform_decode_query_len
+    num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
+
+    # FULL mode should not be treated as uniform decode
+    if (
+        uniform_decode
+        and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL)
+        and self.cudagraph_mode != CUDAGraphMode.FULL
+    ):
+        num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
+        assert num_tokens_padded % uniform_decode_query_len == 0
+    else:
+        uniform_decode = False
+        num_reqs = min(num_tokens_padded, max_num_seqs)
+
+    return BatchDescriptor(
+        num_tokens=num_tokens_padded,
+        num_reqs=num_reqs,
+        uniform=uniform_decode,
+        has_lora=has_lora,
+        num_active_loras=num_active_loras,
+    )
+
+
+CudagraphDispatcher._create_padded_batch_descriptor = _create_padded_batch_descriptor
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index b821524a..b8e0c2f9 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -89,7 +89,7 @@ class AscendEagleProposer(EagleProposer):
 
         self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
         self.decode_threshold = 1 + self.num_speculative_tokens
-        self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 1, dtype=torch.int32)
+        self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 2, dtype=torch.int32)
         self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32)
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
@@ -362,7 +362,9 @@ class AscendEagleProposer(EagleProposer):
 
         model_positions = self._get_positions(num_tokens)
 
-        batch_size = num_tokens // (self.num_speculative_tokens + 1)  # if not is_profile else self.runner.max_num_reqs
+        batch_size = max(
+            num_tokens // (self.num_speculative_tokens + 1), 1
+        )  # if not is_profile else self.runner.max_num_reqs
         if is_profile:
             batch_size = min(batch_size, self.runner.max_num_reqs)
 
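For reviewers, here is a minimal, self-contained sketch of the two mechanisms the new patch_cudagraph.py relies on: rebinding a method on a class at import time, so every `CudagraphDispatcher` instance (existing or future) picks up the replacement through normal attribute lookup, and the tightened `if` condition that keeps a pure FULL cudagraph mode out of the uniform-decode path. The `CUDAGraphMode` flag and `Dispatcher` class below are simplified stand-ins for illustration, not vllm's real types.

```python
from enum import Flag


class CUDAGraphMode(Flag):
    # Simplified stand-in for vllm's CUDAGraphMode; values chosen so that
    # FULL_AND_PIECEWISE is the composite of the two base modes.
    PIECEWISE = 1
    FULL = 2
    FULL_AND_PIECEWISE = FULL | PIECEWISE

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        # True if this mode contains `mode` as a component.
        return bool(self & mode)


class Dispatcher:
    # Stand-in for vllm's CudagraphDispatcher.
    def __init__(self, mode: CUDAGraphMode) -> None:
        self.cudagraph_mode = mode

    def is_uniform(self, uniform_decode: bool) -> bool:
        # Original-style condition: any mode containing FULL keeps the
        # uniform-decode treatment.
        return uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL)


def _patched_is_uniform(self, uniform_decode: bool) -> bool:
    # Patched condition: a *pure* FULL mode is no longer treated as
    # uniform decode; only composite modes containing FULL are.
    return (
        uniform_decode
        and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL)
        and self.cudagraph_mode != CUDAGraphMode.FULL
    )


# Importing vllm_ascend.patch.worker.patch_cudagraph performs exactly this
# kind of class-level assignment against vllm's real CudagraphDispatcher.
Dispatcher.is_uniform = _patched_is_uniform

if __name__ == "__main__":
    print(Dispatcher(CUDAGraphMode.FULL).is_uniform(True))                # False
    print(Dispatcher(CUDAGraphMode.FULL_AND_PIECEWISE).is_uniform(True))  # True
```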