Files
xc-llm-ascend/tests/ut/distributed/kv_transfer/test_simple_buffer.py
Yikun Jiang 5f0b42e414 [FOLLOWUP] Use base test to avoid patch everywhere (#1634)
### What this PR does / why we need it?
Use the base test class to avoid patching everywhere.

Followup here: https://github.com/vllm-project/vllm-ascend/pull/1566

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
The unit-test (UT) CI passed.

- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
2025-07-22 09:03:40 +08:00

72 lines
2.1 KiB
Python

import zlib
from unittest.mock import MagicMock
import torch
from tests.ut.base import TestBase
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
int32_hash)
class MockSimplePipe:
    """Stand-in for the pipe object handed to ``SimpleBuffer`` in tests.

    Every pipe operation is a separate ``MagicMock`` so tests can both
    script return values (via ``side_effect``) and inspect the calls that
    the buffer made, without touching any real transfer backend.
    """

    def __init__(self):
        # Fixed cluster id; the tests in this module never vary it.
        self.cluster_id = 0
        # One independent mock per pipe operation so that call records
        # for sending, receiving and deallocation never mix.
        self.send_tensor = MagicMock()
        self.recv_tensor = MagicMock()
        self.deallocate_buffer = MagicMock()
class TestSimpleBuffer(TestBase):
    """Unit tests for ``SimpleBuffer`` driven through a mocked pipe."""

    def setUp(self):
        """Create a fresh mocked pipe and a buffer bound to it."""
        self.pipe = MockSimplePipe()
        self.buffer = SimpleBuffer(self.pipe)

    def _configure_buffer(self):
        """Set the buffer attributes shared by the insert/drop_select tests.

        The values line up with the tensor shapes used in those tests
        (e.g. the ``(2, 3, 4, 5)`` key/value tensors and the ``(3, 6)``
        hidden tensor).
        """
        self.buffer.num_layers = 2
        self.buffer.num_heads = 4
        self.buffer.head_size = 5
        self.buffer.hidden_size = 6
        self.buffer.dtype = torch.float32

    def test_int32_hash(self):
        """``int32_hash`` of a string equals Adler-32 of its bytes."""
        self.assertEqual(int32_hash("test"), zlib.adler32(b"test"))

    def test_insert(self):
        """``insert`` forwards data to the pipe via ``send_tensor``."""
        input_tokens = torch.tensor([1, 2, 3])
        roi = torch.tensor([1, 0, 1])
        key = torch.randn(2, 3, 4, 5)
        value = torch.randn(2, 3, 4, 5)
        hidden = torch.randn(3, 6)
        self._configure_buffer()

        self.buffer.insert(input_tokens, roi, key, value, hidden, "req1")

        self.pipe.send_tensor.assert_called()

    def test_drop_select(self):
        """``drop_select`` returns three tensors followed by ``None``."""
        input_tokens = torch.tensor([1, 2, 3])
        roi = None
        self._configure_buffer()
        # Three scripted pipe replies, consumed in order. Each reply is a
        # (handle, tensor) pair; the handle is an opaque mock here.
        # NOTE(review): presumably key, value, hidden in that order --
        # confirm against SimpleBuffer.drop_select.
        self.pipe.recv_tensor.side_effect = [
            (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
            (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
            (MagicMock(), torch.randn(1, 3, 6)),
        ]

        result = self.buffer.drop_select(input_tokens, roi, "req1")

        self.assertEqual(len(result), 4)
        self.assertIsInstance(result[0], torch.Tensor)
        self.assertIsInstance(result[1], torch.Tensor)
        self.assertIsInstance(result[2], torch.Tensor)
        self.assertIsNone(result[3])
        # The flat received tensor must come back in the 4-D layout.
        self.assertEqual(result[0].shape, (2, 3, 4, 5))

    def test_close(self):
        """Smoke test: ``close`` runs on a fresh buffer without raising."""
        self.buffer.close()