[cherry-pick][Test]repair for test_compute_slot_mapping (#7836)
### What this PR does / why we need it? repair for test_compute_slot_mapping Signed-off-by: ZT-AIA <1028681969@qq.com>
This commit is contained in:
@@ -5,7 +5,6 @@ from vllm.v1.worker.gpu.block_table import _compute_slot_mappings_kernel as \
|
|||||||
ref_compute_slot_mappings_kernel
|
ref_compute_slot_mappings_kernel
|
||||||
from vllm_ascend.worker.v2.block_table import _compute_slot_mappings_kernel as \
|
from vllm_ascend.worker.v2.block_table import _compute_slot_mappings_kernel as \
|
||||||
ascend_compute_slot_mappings_kernel
|
ascend_compute_slot_mappings_kernel
|
||||||
from vllm.v1.worker.gpu.block_table import _load_ptr, _make_ptr_tensor
|
|
||||||
|
|
||||||
def test_compute_slot_mapping_npu_kernel():
|
def test_compute_slot_mapping_npu_kernel():
|
||||||
|
|
||||||
@@ -52,7 +51,7 @@ def test_compute_slot_mapping_npu_kernel():
|
|||||||
for i in range(num_kv_cache_groups):
|
for i in range(num_kv_cache_groups):
|
||||||
block_table = torch.randint(0, 320, (max_num_reqs, max_num_blocks), dtype=torch.int32, device=device)
|
block_table = torch.randint(0, 320, (max_num_reqs, max_num_blocks), dtype=torch.int32, device=device)
|
||||||
block_tables.append(block_table)
|
block_tables.append(block_table)
|
||||||
block_table_ptrs = _make_ptr_tensor(block_tables)
|
block_table_ptrs = torch.tensor([t.data_ptr() for t in block_tables], dtype=torch.uint64, device=device)
|
||||||
block_table_strides = torch.tensor([320], dtype=torch.int32, device=device)
|
block_table_strides = torch.tensor([320], dtype=torch.int32, device=device)
|
||||||
|
|
||||||
block_sizes_tensor = torch.tensor([128], dtype=torch.int32, device=device)
|
block_sizes_tensor = torch.tensor([128], dtype=torch.int32, device=device)
|
||||||
|
|||||||
Reference in New Issue
Block a user