diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_prepare_inputs_padded.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_prepare_inputs_padded.py
new file mode 100644
index 00000000..2a84efbd
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_prepare_inputs_padded.py
@@ -0,0 +1,80 @@
+import pytest
+import torch
+from vllm.triton_utils import triton
+
+from vllm_ascend.ops.triton.spec_decode.utils import \
+    prepare_inputs_padded_kernel
+from vllm_ascend.ops.triton.triton_utils import (get_vectorcore_num,
+                                                 init_device_properties_triton)
+from vllm_ascend.spec_decode.eagle_proposer import \
+    _PREPARE_INPUTS_BLOCK_SIZE as BLOCK_SIZE
+
+
+def prepare_inputs_padded_ref(
+    cu_num_draft_tokens,
+    valid_sampled_tokens_count,
+    query_start_loc,
+):
+    num_draft_tokens = torch.cat([
+        cu_num_draft_tokens[0:1],
+        cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1],
+    ])
+
+    num_rejected_tokens = torch.where(
+        num_draft_tokens > 0,
+        num_draft_tokens + 1 - valid_sampled_tokens_count,
+        torch.zeros_like(num_draft_tokens),
+    )
+
+    token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected_tokens
+
+    return token_indices_to_sample.to(torch.int32)
+
+
+@pytest.mark.parametrize("num_reqs", [1, 7, 32, 128, 2048])
+def test_prepare_inputs_padded(num_reqs):
+    init_device_properties_triton()
+    device = "npu"
+    torch.manual_seed(0)
+
+    draft_lens = torch.randint(1,
+                               6, (num_reqs, ),
+                               device=device,
+                               dtype=torch.int32)
+
+    cu_num_draft_tokens = torch.cumsum(draft_lens, dim=0).to(torch.int32)
+
+    valid_sampled_tokens_count = torch.zeros_like(draft_lens)
+    for i in range(num_reqs):
+        valid_sampled_tokens_count[i] = torch.randint(0, draft_lens[i] + 2,
+                                                      (1, )).item()
+
+    seq_lens = draft_lens + 1
+    query_start_loc = torch.zeros(num_reqs + 1,
+                                  device=device,
+                                  dtype=torch.int32)
+    query_start_loc[1:] = torch.cumsum(seq_lens, dim=0)
+
+    # Run PyTorch reference
+    out_ref = prepare_inputs_padded_ref(cu_num_draft_tokens,
+                                        valid_sampled_tokens_count,
+                                        query_start_loc)
+
+    # Run Triton kernel
+    out_tri = torch.empty(num_reqs, dtype=torch.int32, device=device)
+
+    num_blocks_needed = triton.cdiv(num_reqs, BLOCK_SIZE)
+    num_vector_core = get_vectorcore_num()
+    grid_size = min(num_blocks_needed, num_vector_core)
+    grid = (grid_size, )
+
+    prepare_inputs_padded_kernel[grid](
+        cu_num_draft_tokens,
+        valid_sampled_tokens_count,
+        query_start_loc,
+        out_tri,
+        num_reqs,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    torch.testing.assert_close(out_tri, out_ref)
diff --git a/vllm_ascend/ops/triton/spec_decode/__init__.py b/vllm_ascend/ops/triton/spec_decode/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py
new file mode 100644
index 00000000..c0588502
--- /dev/null
+++ b/vllm_ascend/ops/triton/spec_decode/utils.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/spec_decode/utils.py
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def prepare_inputs_padded_kernel(
+    cu_num_draft_tokens_ptr,  # [num_reqs]
+    valid_sampled_tokens_count_ptr,  # [num_reqs]
+    query_start_loc_gpu_ptr,  # [num_reqs + 1]
+    token_indices_to_sample_ptr,  # [num_reqs] (output)
+    num_reqs,  # tl.int32
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    num_programs = tl.num_programs(axis=0)
+
+    # Grid-stride loop: each program walks over multiple BLOCK_SIZE chunks.
+    block_start_step = num_programs * BLOCK_SIZE
+
+    for block_start in tl.range(pid * BLOCK_SIZE, num_reqs, block_start_step):
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < num_reqs
+
+        # Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
+        # cumulative sum (first entry is the first value, not zero).
+        cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)
+
+        prev_indices = offsets - 1
+        has_prev = offsets > 0
+        cu_draft_prev = tl.load(
+            cu_num_draft_tokens_ptr + prev_indices,
+            mask=mask & has_prev,
+            other=0,
+        )
+
+        num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev,
+                                    cu_draft_curr)
+
+        valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets,
+                              mask=mask)
+        num_rejected = num_draft_tokens + 1 - valid_count
+        num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)
+
+        # query_start_loc[req_idx + 1] is the start position of the next request,
+        # which is one past the last token of this request.
+        q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1,
+                                 mask=mask) - 1
+
+        index_to_sample = q_last_tok_idx - num_rejected
+        tl.store(token_indices_to_sample_ptr + offsets,
+                 index_to_sample,
+                 mask=mask)
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index d1ce1edf..1e45af53 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -14,6 +14,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import HAS_TRITON, triton
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -29,6 +30,9 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_attn_params)
 from vllm_ascend.ops.rotary_embedding import update_cos_sin
+from vllm_ascend.ops.triton.spec_decode.utils import \
+    prepare_inputs_padded_kernel
+from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
 from vllm_ascend.utils import shared_expert_dp_enabled
 
 PADDING_SLOT_ID = -1
@@ -37,6 +41,9 @@
 _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
 _FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
 
+# Use a small fixed block size, since `num_reqs` is never very large.
+_PREPARE_INPUTS_BLOCK_SIZE = 4
+
 
 class EagleProposer(VllmEagleProposer):
 
@@ -737,17 +744,51 @@ class EagleProposer(VllmEagleProposer):
         used as padding and filtered out later by `token_indices_to_sample`.
         No blocking CPU operations should be introduced in this function.
""" - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1], - ]) + if HAS_TRITON: + num_reqs = common_attn_metadata.num_reqs + device = valid_sampled_tokens_count.device - num_rejected_tokens_gpu = torch.where( - num_draft_tokens_gpu > 0, - num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu), - ) + if num_reqs != spec_decode_metadata.cu_num_draft_tokens.shape[0]: + # TODO: This is a serious issue and should be taken care of ASAP + # In short, why input_batch.num_reqs != attn_metadata.num_reqs? + # Previously in #4963, we modified `query_start_loc`, but this + # problem remains unsolved. + num_reqs = spec_decode_metadata.cu_num_draft_tokens.shape[0] + + token_indices_to_sample = torch.empty((num_reqs, ), + dtype=torch.int32, + device=device) + + num_blocks_needed = triton.cdiv(num_reqs, + _PREPARE_INPUTS_BLOCK_SIZE) + num_vector_core = get_vectorcore_num() + grid_size = min(num_blocks_needed, num_vector_core) + grid = (grid_size, ) + + prepare_inputs_padded_kernel[grid]( + spec_decode_metadata.cu_num_draft_tokens, + valid_sampled_tokens_count, + common_attn_metadata.query_start_loc, + token_indices_to_sample, + num_reqs, + BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE, + ) + else: + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1], + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu), + ) + + query_start_loc = common_attn_metadata.query_start_loc[ + 1:1 + num_rejected_tokens_gpu.shape[0]] + token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -781,8 +822,4 @@ class EagleProposer(VllmEagleProposer): seq_lens=common_attn_metadata.seq_lens, max_seq_len=0) - query_start_loc = common_attn_metadata.query_start_loc[ - 1:1 + num_rejected_tokens_gpu.shape[0]] - token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu - return spec_common_attn_metadata, token_indices, token_indices_to_sample