[Refactor][Triton] Move reject sample triton kernels into ops/triton (#5324)

### What this PR does / why we need it?
This PR moves reject sample related triton kernels into `ops/triton`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with existing test.


- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-12-29 16:15:41 +08:00
committed by GitHub
parent e7e1a7dc05
commit 28b7614322
3 changed files with 403 additions and 356 deletions

View File

@@ -165,14 +165,13 @@ class TestAscendRejectionSampler(TestBase):
# Test Triton kernel path
with patch("vllm_ascend.sample.rejection_sampler.HAS_TRITON", True):
with patch("vllm_ascend.sample.rejection_sampler.expand_kernel"
with patch("vllm_ascend.sample.rejection_sampler.expand_triton"
) as mock_triton:
expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
# grid = triton.cdiv(n, BLOCK_SIZE) = triton.cdiv(3, 2) = 2
mock_triton.__getitem__.assert_called_once_with((2, ))
call_args = mock_triton.__getitem__.return_value.call_args[0]
assert (call_args[1] == x).all()
assert (call_args[2] == cu_num_tokens).all()
mock_triton.assert_called_once()
call_args = mock_triton.call_args[0]
assert (call_args[2] == x).all()
assert (call_args[3] == cu_num_tokens).all()
# Run actual function
with patch("vllm_ascend.sample.rejection_sampler.HAS_TRITON", False):