[Refactor]Refactor sampler (#2050)
Refactor the Sampler implementation from the patch-based approach to
inheriting from the vLLM Sampler interface.
Next step: make the `TopKTopPSampler` op in vLLM support the custom-op
registration mechanism.
- vLLM version: v0.10.0
- vLLM main:
61a6905ab0
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
32
tests/ut/sample/test_sampler.py
Normal file
32
tests/ut/sample/test_sampler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
|
||||
|
||||
|
||||
class TestAscendSampler(TestBase):
    """Unit tests covering construction of ``AscendSampler``."""

    def test_init_with_raw_logprobs(self):
        """Check the attributes set when built with ``logprobs_mode="raw_logprobs"``.

        The sampler must report the requested logprobs mode and expose a
        ``topk_topp_sampler`` attribute that is an ``AscendTopKTopPSampler``.
        """
        requested_mode = "raw_logprobs"
        ascend_sampler = AscendSampler(logprobs_mode=requested_mode)

        self.assertEqual(ascend_sampler.logprobs_mode, requested_mode)
        self.assertTrue(hasattr(ascend_sampler, "topk_topp_sampler"))
        self.assertIsInstance(
            ascend_sampler.topk_topp_sampler,
            AscendTopKTopPSampler,
        )
|
||||
|
||||
|
||||
class TestAscendTopKTopPSampler(TestBase):
    """Unit tests for ``AscendTopKTopPSampler``."""

    @mock.patch("torch_npu.npu_top_k_top_p")
    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
        """``forward_native`` must invoke the NPU top-k/top-p op exactly once.

        ``torch_npu.npu_top_k_top_p`` is replaced by a mock, so the test only
        checks the call itself: one invocation, with the arguments in the
        order ``(logits, p, k)``.
        """
        # The mock needs a tensor return value for the sampler to consume.
        mock_npu_op.return_value = torch.randn(1, 3)
        topk_topp_sampler = AscendTopKTopPSampler()

        seeded_generator = torch.Generator()
        seeded_generator.manual_seed(42)
        generator_map = {0: seeded_generator}

        input_logits = torch.tensor([[1.0, 2.0, 3.0]])
        top_k = torch.tensor([2])
        top_p = torch.tensor([0.9])

        topk_topp_sampler.forward_native(input_logits, generator_map, top_k, top_p)

        # Note the argument order expected by the op: p comes before k.
        mock_npu_op.assert_called_once_with(input_logits, top_p, top_k)
|
||||
Reference in New Issue
Block a user