[Feat][Spec] Optimize token index calculation in spec decode with Triton kernel (#5356)
### What this PR does / why we need it?

Replace multiple PyTorch operations with a fused Triton kernel to determine token indices for sampling during speculative decoding. This reduces kernel launch overhead and memory traffic, improving overall performance on Ascend hardware.

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
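The kernel itself (`prepare_inputs_padded_kernel`, imported from `vllm_ascend.ops.triton.spec_decode.utils`) is not part of this diff. For orientation, here is a minimal sketch of the computation it fuses, derived from the eager fallback kept in the `else` branch below; the sketch's name, argument layout, and persistent-loop structure are assumptions, not the actual implementation:

```python
import triton
import triton.language as tl


@triton.jit
def prepare_inputs_padded_kernel_sketch(
        cu_num_draft_tokens_ptr,  # [num_reqs] inclusive prefix sum of draft tokens per request
        valid_sampled_tokens_count_ptr,  # [num_reqs] accepted tokens per request (incl. bonus)
        query_start_loc_ptr,  # [num_reqs + 1] start offset of each request in the flat batch
        token_indices_to_sample_ptr,  # [num_reqs] output: flat index of the token to sample
        num_reqs,
        BLOCK_SIZE: tl.constexpr):
    # Persistent grid: each program strides over blocks of requests.
    for block_start in range(
            tl.program_id(0) * BLOCK_SIZE, num_reqs,
            tl.num_programs(0) * BLOCK_SIZE):
        offs = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < num_reqs

        # num_draft[i] = cu[i] - cu[i - 1], treating cu[-1] as 0.
        cu = tl.load(cu_num_draft_tokens_ptr + offs, mask=mask, other=0)
        cu_prev = tl.load(cu_num_draft_tokens_ptr + offs - 1,
                          mask=mask & (offs > 0), other=0)
        num_draft = cu - cu_prev

        # Rejected = drafted + 1 (bonus) - accepted; requests without drafts reject nothing.
        valid = tl.load(valid_sampled_tokens_count_ptr + offs, mask=mask, other=0)
        num_rejected = tl.where(num_draft > 0, num_draft + 1 - valid, 0)

        # Last accepted position of each request in the flattened token sequence.
        q_end = tl.load(query_start_loc_ptr + offs + 1, mask=mask, other=0)
        tl.store(token_indices_to_sample_ptr + offs, q_end - 1 - num_rejected,
                 mask=mask)
```

Each program handles blocks of `_PREPARE_INPUTS_BLOCK_SIZE` requests and strides by the launch grid, which is why the diff clamps the grid to `min(num_blocks_needed, num_vector_core)`.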
```diff
@@ -14,6 +14,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import HAS_TRITON, triton
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
```
```diff
@@ -29,6 +30,9 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_attn_params)
 from vllm_ascend.ops.rotary_embedding import update_cos_sin
+from vllm_ascend.ops.triton.spec_decode.utils import \
+    prepare_inputs_padded_kernel
+from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
 from vllm_ascend.utils import shared_expert_dp_enabled
 
 PADDING_SLOT_ID = -1
```
```diff
@@ -37,6 +41,9 @@ _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
 
 _FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
 
+# Currently we will fix block size to a small one since `num_reqs` can't be too large
+_PREPARE_INPUTS_BLOCK_SIZE = 4
+
 
 class EagleProposer(VllmEagleProposer):
 
```
```diff
@@ -737,17 +744,51 @@ class EagleProposer(VllmEagleProposer):
         used as padding and filtered out later by `token_indices_to_sample`.
         No blocking CPU operations should be introduced in this function.
         """
-        num_draft_tokens_gpu = torch.cat([
-            spec_decode_metadata.cu_num_draft_tokens[0:1],
-            spec_decode_metadata.cu_num_draft_tokens[1:] -
-            spec_decode_metadata.cu_num_draft_tokens[:-1],
-        ])
+        if HAS_TRITON:
+            num_reqs = common_attn_metadata.num_reqs
+            device = valid_sampled_tokens_count.device
 
-        num_rejected_tokens_gpu = torch.where(
-            num_draft_tokens_gpu > 0,
-            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
-            torch.zeros_like(num_draft_tokens_gpu),
-        )
+            if num_reqs != spec_decode_metadata.cu_num_draft_tokens.shape[0]:
+                # TODO: This is a serious issue and should be taken care of ASAP
+                # In short, why input_batch.num_reqs != attn_metadata.num_reqs?
+                # Previously in #4963, we modified `query_start_loc`, but this
+                # problem remains unsolved.
+                num_reqs = spec_decode_metadata.cu_num_draft_tokens.shape[0]
+
+            token_indices_to_sample = torch.empty((num_reqs, ),
+                                                  dtype=torch.int32,
+                                                  device=device)
+
+            num_blocks_needed = triton.cdiv(num_reqs,
+                                            _PREPARE_INPUTS_BLOCK_SIZE)
+            num_vector_core = get_vectorcore_num()
+            grid_size = min(num_blocks_needed, num_vector_core)
+            grid = (grid_size, )
+
+            prepare_inputs_padded_kernel[grid](
+                spec_decode_metadata.cu_num_draft_tokens,
+                valid_sampled_tokens_count,
+                common_attn_metadata.query_start_loc,
+                token_indices_to_sample,
+                num_reqs,
+                BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE,
+            )
+        else:
+            num_draft_tokens_gpu = torch.cat([
+                spec_decode_metadata.cu_num_draft_tokens[0:1],
+                spec_decode_metadata.cu_num_draft_tokens[1:] -
+                spec_decode_metadata.cu_num_draft_tokens[:-1],
+            ])
+
+            num_rejected_tokens_gpu = torch.where(
+                num_draft_tokens_gpu > 0,
+                num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
+                torch.zeros_like(num_draft_tokens_gpu),
+            )
+
+            query_start_loc = common_attn_metadata.query_start_loc[
+                1:1 + num_rejected_tokens_gpu.shape[0]]
+            token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu
 
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
 
```
```diff
@@ -781,8 +822,4 @@ class EagleProposer(VllmEagleProposer):
             seq_lens=common_attn_metadata.seq_lens,
             max_seq_len=0)
 
-        query_start_loc = common_attn_metadata.query_start_loc[
-            1:1 + num_rejected_tokens_gpu.shape[0]]
-        token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu
-
         return spec_common_attn_metadata, token_indices, token_indices_to_sample
```
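For intuition, here is a tiny worked example of the index arithmetic that both the kernel launch and the eager fallback implement (hypothetical numbers, not taken from the PR):

```python
import torch

# Two requests, each with 3 draft tokens (cumulative counts) and 1 bonus token.
cu_num_draft_tokens = torch.tensor([3, 6], dtype=torch.int32)
valid_sampled_tokens_count = torch.tensor([2, 4], dtype=torch.int32)  # accepted incl. bonus
query_start_loc = torch.tensor([0, 4, 8], dtype=torch.int32)  # 4 = 1 target + 3 draft tokens each

num_draft = torch.diff(cu_num_draft_tokens,
                       prepend=torch.zeros(1, dtype=torch.int32))     # [3, 3]
num_rejected = torch.where(num_draft > 0,
                           num_draft + 1 - valid_sampled_tokens_count,
                           torch.zeros_like(num_draft))               # [2, 0]
token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected      # [1, 7]
print(token_indices_to_sample)
```

Request 0 rejects two of its speculative positions, so sampling continues from flat index 1; request 1 accepts everything and samples from its last position, index 7.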