[Feat][Spec] Optimize token index calculation in spec decode with Triton kernel (#5356)
### What this PR does / why we need it? Replace multiple PyTorch operations with a fused Triton kernel to determine token indices for sampling during speculative decoding. This reduces kernel launch overhead and memory traffic, improving overall performance on Ascend hardware. --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -0,0 +1,80 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
from vllm_ascend.ops.triton.spec_decode.utils import \
|
||||||
|
prepare_inputs_padded_kernel
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import (get_vectorcore_num,
|
||||||
|
init_device_properties_triton)
|
||||||
|
from vllm_ascend.spec_decode.eagle_proposer import \
|
||||||
|
_PREPARE_INPUTS_BLOCK_SIZE as BLOCK_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_inputs_padded_ref(
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
valid_sampled_tokens_count,
|
||||||
|
query_start_loc,
|
||||||
|
):
|
||||||
|
num_draft_tokens = torch.cat([
|
||||||
|
cu_num_draft_tokens[0:1],
|
||||||
|
cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1],
|
||||||
|
])
|
||||||
|
|
||||||
|
num_rejected_tokens = torch.where(
|
||||||
|
num_draft_tokens > 0,
|
||||||
|
num_draft_tokens + 1 - valid_sampled_tokens_count,
|
||||||
|
torch.zeros_like(num_draft_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
|
token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected_tokens
|
||||||
|
|
||||||
|
return token_indices_to_sample.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_reqs", [1, 7, 32, 128, 2048])
|
||||||
|
def test_prepare_inputs_padded(num_reqs):
|
||||||
|
init_device_properties_triton()
|
||||||
|
device = "npu"
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
draft_lens = torch.randint(1,
|
||||||
|
6, (num_reqs, ),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
cu_num_draft_tokens = torch.cumsum(draft_lens, dim=0).to(torch.int32)
|
||||||
|
|
||||||
|
valid_sampled_tokens_count = torch.zeros_like(draft_lens)
|
||||||
|
for i in range(num_reqs):
|
||||||
|
valid_sampled_tokens_count[i] = torch.randint(0, draft_lens[i] + 2,
|
||||||
|
(1, )).item()
|
||||||
|
|
||||||
|
seq_lens = draft_lens + 1
|
||||||
|
query_start_loc = torch.zeros(num_reqs + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
query_start_loc[1:] = torch.cumsum(seq_lens, dim=0)
|
||||||
|
|
||||||
|
# Run PyTorch reference
|
||||||
|
out_ref = prepare_inputs_padded_ref(cu_num_draft_tokens,
|
||||||
|
valid_sampled_tokens_count,
|
||||||
|
query_start_loc)
|
||||||
|
|
||||||
|
# Run Triton kernel
|
||||||
|
out_tri = torch.empty(num_reqs, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
num_blocks_needed = triton.cdiv(num_reqs, BLOCK_SIZE)
|
||||||
|
num_vector_core = get_vectorcore_num()
|
||||||
|
grid_size = min(num_blocks_needed, num_vector_core)
|
||||||
|
grid = (grid_size, )
|
||||||
|
|
||||||
|
prepare_inputs_padded_kernel[grid](
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
valid_sampled_tokens_count,
|
||||||
|
query_start_loc,
|
||||||
|
out_tri,
|
||||||
|
num_reqs,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(out_tri, out_ref)
|
||||||
0
vllm_ascend/ops/triton/spec_decode/__init__.py
Normal file
0
vllm_ascend/ops/triton/spec_decode/__init__.py
Normal file
68
vllm_ascend/ops/triton/spec_decode/utils.py
Normal file
68
vllm_ascend/ops/triton/spec_decode/utils.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/spec_decode/utils.py
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def prepare_inputs_padded_kernel(
|
||||||
|
cu_num_draft_tokens_ptr, # [num_reqs]
|
||||||
|
valid_sampled_tokens_count_ptr, # [num_reqs]
|
||||||
|
query_start_loc_gpu_ptr, # [num_reqs + 1]
|
||||||
|
token_indices_to_sample_ptr, # [num_reqs] (output)
|
||||||
|
num_reqs, # tl.int32
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
num_programs = tl.num_programs(axis=0)
|
||||||
|
|
||||||
|
# Grid-Stride Loop:
|
||||||
|
block_start_step = num_programs * BLOCK_SIZE
|
||||||
|
|
||||||
|
for block_start in tl.range(pid * BLOCK_SIZE, num_reqs, block_start_step):
|
||||||
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < num_reqs
|
||||||
|
|
||||||
|
# Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
|
||||||
|
# cumulative sum (first entry is the first value, not zero).
|
||||||
|
cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)
|
||||||
|
|
||||||
|
prev_indices = offsets - 1
|
||||||
|
has_prev = offsets > 0
|
||||||
|
cu_draft_prev = tl.load(
|
||||||
|
cu_num_draft_tokens_ptr + prev_indices,
|
||||||
|
mask=mask & has_prev,
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev,
|
||||||
|
cu_draft_curr)
|
||||||
|
|
||||||
|
valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets,
|
||||||
|
mask=mask)
|
||||||
|
num_rejected = num_draft_tokens + 1 - valid_count
|
||||||
|
num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)
|
||||||
|
|
||||||
|
# query_start_loc[req_idx + 1] is the start position of the next request,
|
||||||
|
# which is one past the last token of this request.
|
||||||
|
q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1,
|
||||||
|
mask=mask) - 1
|
||||||
|
|
||||||
|
index_to_sample = q_last_tok_idx - num_rejected
|
||||||
|
tl.store(token_indices_to_sample_ptr + offsets,
|
||||||
|
index_to_sample,
|
||||||
|
mask=mask)
|
||||||
@@ -14,6 +14,7 @@ from vllm.model_executor.model_loader import get_model
|
|||||||
from vllm.model_executor.models import supports_multimodal
|
from vllm.model_executor.models import supports_multimodal
|
||||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
|
from vllm.triton_utils import HAS_TRITON, triton
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@@ -29,6 +30,9 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
|||||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||||
update_attn_params)
|
update_attn_params)
|
||||||
from vllm_ascend.ops.rotary_embedding import update_cos_sin
|
from vllm_ascend.ops.rotary_embedding import update_cos_sin
|
||||||
|
from vllm_ascend.ops.triton.spec_decode.utils import \
|
||||||
|
prepare_inputs_padded_kernel
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||||
from vllm_ascend.utils import shared_expert_dp_enabled
|
from vllm_ascend.utils import shared_expert_dp_enabled
|
||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
PADDING_SLOT_ID = -1
|
||||||
@@ -37,6 +41,9 @@ _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
|||||||
|
|
||||||
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
||||||
|
|
||||||
|
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
||||||
|
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
||||||
|
|
||||||
|
|
||||||
class EagleProposer(VllmEagleProposer):
|
class EagleProposer(VllmEagleProposer):
|
||||||
|
|
||||||
@@ -737,17 +744,51 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
used as padding and filtered out later by `token_indices_to_sample`.
|
used as padding and filtered out later by `token_indices_to_sample`.
|
||||||
No blocking CPU operations should be introduced in this function.
|
No blocking CPU operations should be introduced in this function.
|
||||||
"""
|
"""
|
||||||
num_draft_tokens_gpu = torch.cat([
|
if HAS_TRITON:
|
||||||
spec_decode_metadata.cu_num_draft_tokens[0:1],
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
spec_decode_metadata.cu_num_draft_tokens[1:] -
|
device = valid_sampled_tokens_count.device
|
||||||
spec_decode_metadata.cu_num_draft_tokens[:-1],
|
|
||||||
])
|
|
||||||
|
|
||||||
num_rejected_tokens_gpu = torch.where(
|
if num_reqs != spec_decode_metadata.cu_num_draft_tokens.shape[0]:
|
||||||
num_draft_tokens_gpu > 0,
|
# TODO: This is a serious issue and should be taken care of ASAP
|
||||||
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
|
# In short, why input_batch.num_reqs != attn_metadata.num_reqs?
|
||||||
torch.zeros_like(num_draft_tokens_gpu),
|
# Previously in #4963, we modified `query_start_loc`, but this
|
||||||
)
|
# problem remains unsolved.
|
||||||
|
num_reqs = spec_decode_metadata.cu_num_draft_tokens.shape[0]
|
||||||
|
|
||||||
|
token_indices_to_sample = torch.empty((num_reqs, ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
num_blocks_needed = triton.cdiv(num_reqs,
|
||||||
|
_PREPARE_INPUTS_BLOCK_SIZE)
|
||||||
|
num_vector_core = get_vectorcore_num()
|
||||||
|
grid_size = min(num_blocks_needed, num_vector_core)
|
||||||
|
grid = (grid_size, )
|
||||||
|
|
||||||
|
prepare_inputs_padded_kernel[grid](
|
||||||
|
spec_decode_metadata.cu_num_draft_tokens,
|
||||||
|
valid_sampled_tokens_count,
|
||||||
|
common_attn_metadata.query_start_loc,
|
||||||
|
token_indices_to_sample,
|
||||||
|
num_reqs,
|
||||||
|
BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
num_draft_tokens_gpu = torch.cat([
|
||||||
|
spec_decode_metadata.cu_num_draft_tokens[0:1],
|
||||||
|
spec_decode_metadata.cu_num_draft_tokens[1:] -
|
||||||
|
spec_decode_metadata.cu_num_draft_tokens[:-1],
|
||||||
|
])
|
||||||
|
|
||||||
|
num_rejected_tokens_gpu = torch.where(
|
||||||
|
num_draft_tokens_gpu > 0,
|
||||||
|
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
|
||||||
|
torch.zeros_like(num_draft_tokens_gpu),
|
||||||
|
)
|
||||||
|
|
||||||
|
query_start_loc = common_attn_metadata.query_start_loc[
|
||||||
|
1:1 + num_rejected_tokens_gpu.shape[0]]
|
||||||
|
token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu
|
||||||
|
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
|
|
||||||
@@ -781,8 +822,4 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
seq_lens=common_attn_metadata.seq_lens,
|
seq_lens=common_attn_metadata.seq_lens,
|
||||||
max_seq_len=0)
|
max_seq_len=0)
|
||||||
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc[
|
|
||||||
1:1 + num_rejected_tokens_gpu.shape[0]]
|
|
||||||
token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu
|
|
||||||
|
|
||||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||||
|
|||||||
Reference in New Issue
Block a user