[Refactor][Triton] Move reject sample triton kernels into ops/triton (#5324)
### What this PR does / why we need it?
This PR moves the rejection-sampling-related Triton kernels into `ops/triton`.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI passed with the existing tests.
- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -165,14 +165,13 @@ class TestAscendRejectionSampler(TestBase):
|
||||
|
||||
# Test Triton kernel path
|
||||
with patch("vllm_ascend.sample.rejection_sampler.HAS_TRITON", True):
|
||||
with patch("vllm_ascend.sample.rejection_sampler.expand_kernel"
|
||||
with patch("vllm_ascend.sample.rejection_sampler.expand_triton"
|
||||
) as mock_triton:
|
||||
expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
|
||||
# grid = triton.cdiv(n, BLOCK_SIZE) = triton.cdiv(3, 2) = 2
|
||||
mock_triton.__getitem__.assert_called_once_with((2, ))
|
||||
call_args = mock_triton.__getitem__.return_value.call_args[0]
|
||||
assert (call_args[1] == x).all()
|
||||
assert (call_args[2] == cu_num_tokens).all()
|
||||
mock_triton.assert_called_once()
|
||||
call_args = mock_triton.call_args[0]
|
||||
assert (call_args[2] == x).all()
|
||||
assert (call_args[3] == cu_num_tokens).all()
|
||||
|
||||
# Run actual function
|
||||
with patch("vllm_ascend.sample.rejection_sampler.HAS_TRITON", False):
|
||||
|
||||
Reference in New Issue
Block a user