[Feat][Spec] Optimize token index calculation in spec decode with Triton kernel (#5356)

### What this PR does / why we need it?
Replace multiple PyTorch operations with a fused Triton kernel to
determine token indices for sampling during speculative decoding. This
reduces kernel launch overhead and memory traffic, improving overall
performance on Ascend hardware.

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2026-01-05 16:51:29 +08:00
committed by GitHub
parent 8ffe3f5d78
commit 755caeb06e
4 changed files with 199 additions and 14 deletions