diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 66a63a4..63b56fd 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -105,20 +105,6 @@
 #    Future Plan:
 #       Revert it when the related pr is merged in vllm and vllm-ascend.
 #
-# ** File: worker/patch_common/patch_eagle.py **
-#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#   1. `vllm.v1.spec_decode.eagle.prepare_inputs`
-#    Why:
-#       We need to use the patched `prepare_input_kernel` in `eagle.prepare_inputs`.
-#       The mainly reason to overwrite `prepare_input_kernel` is this is a triton
-#       kernel, ascend is now not support triton kernel.
-#    How:
-#       Re-implementation the `prepare_input_kernel` triton kernel by pytorch
-#    Related PR (if no, explain why):
-#       - Ascend doesn't support triton
-#    Future Plan:
-#       Revert it when the ascend support triton kernel.
-#
 # ** File: worker/patch_common/patch_sampler.py **
 #    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #   1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
@@ -168,4 +154,4 @@
 #    Related PR (if no, explain why):
 #       This is the problem in vllm-ascend
 #    Future Plan:
-#       Remove this patch once pytorch 2.7.0 is supported for vllm ascend.
+#       Remove this patch once pytorch 2.7.0 is supported for vllm ascend.
\ No newline at end of file
diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py
index 08a4b60..d78b6dc 100644
--- a/vllm_ascend/patch/worker/patch_common/__init__.py
+++ b/vllm_ascend/patch/worker/patch_common/__init__.py
@@ -19,7 +19,6 @@
 # patch files.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_eagle  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
diff --git a/vllm_ascend/patch/worker/patch_common/patch_eagle.py b/vllm_ascend/patch/worker/patch_common/patch_eagle.py
deleted file mode 100644
index f1c2a62..0000000
--- a/vllm_ascend/patch/worker/patch_common/patch_eagle.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-import torch
-from vllm.v1.spec_decode.eagle import EagleProposer
-
-
-def prepare_inputs(
-    # [batch_size + 1]
-    cu_target_query_lens: torch.Tensor,
-    # [batch_size]
-    num_rejected_tokens: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    # cu_target_query_lens: [0, a, a + b, a + b + c]
-    # num_rejected_tokens: [n1, n2, n3]
-    # num_tokens_per_req: [a - n1, b - n2, c - n3]
-    # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-    # token_indices: [0, 1, ..., a - n1 - 1,
-    #                 a, a + 1, ..., a + b - n2 - 1,
-    #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
-
-    # [0, a, a + b, a + b + c] -> [a, b, c]
-    query_len_per_req = (cu_target_query_lens[1:] - cu_target_query_lens[:-1])
-    # [a, b, c] -> [a - n1, b - n2, c - n3]
-    num_tokens_per_req = query_len_per_req - num_rejected_tokens
-
-    cu_num_tokens = torch.empty_like(cu_target_query_lens)
-    torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-    cu_num_tokens[0] = 0
-
-    # FIXME(woosuk): Avoid synchronization.
-    num_tokens = cu_num_tokens[-1].item()
-    token_indices = torch.empty(
-        num_tokens,
-        dtype=torch.int32,
-        device=cu_num_tokens.device,
-    )
-
-    BLOCK_SIZE = 1024
-    prepare_input_pytorch(
-        token_indices,
-        cu_target_query_lens,
-        cu_num_tokens,
-        block_size=BLOCK_SIZE,
-    )
-    return cu_num_tokens, token_indices
-
-
-def prepare_input_pytorch(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,
-                          cu_num_tokens: torch.Tensor, block_size: int):
-    num_pids = cu_num_tokens.shape[0] - 1
-
-    for pid in range(num_pids):
-        start_pos = cu_num_tokens[pid].item()
-        end_pos = cu_num_tokens[pid + 1].item()
-        num_tokens = end_pos - start_pos
-
-        index_start = cu_query_lens[pid].item()
-        num_blocks = (num_tokens + block_size - 1) // block_size
-
-        for i in range(num_blocks):
-            offset = i * block_size + torch.arange(
-                0, block_size,
-                dtype=out_ptr.dtype,
-                device=cu_query_lens.device)
-            global_indices = start_pos + offset
-            values = index_start + offset
-            mask = offset < num_tokens
-            out_ptr[global_indices[mask]] = values[mask]
-
-
-EagleProposer.prepare_inputs = prepare_inputs