# File: enginex-mlu370-vllm/vllm-v0.6.2/tests/kernels/test_advance_step.py
# Snapshot taken 2026-02-04 17:22:39 +08:00 (91 lines, 3.5 KiB, Python)
import pytest
import torch
import torch_mlu
from vllm import _mlu_ops as mlu_ops
from typing import Tuple
@pytest.mark.parametrize("num_seqs, num_queries", [(20, 17), (17, 20), (256, 224)])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("TILE_SIZE", [8, 64, 256])
@pytest.mark.parametrize("device", ["mlu"])
def test_advance_step(num_seqs, num_queries, block_size, TILE_SIZE, device):
    """Check ``mlu_ops.advance_step`` against a pure-PyTorch reference.

    ``advance_step`` performs the per-decode-step bookkeeping for the
    scheduler's first ``num_queries`` sequences, in place:
      * copies the newly sampled token ids into ``input_tokens``,
      * increments ``seq_lens``,
      * recomputes ``input_positions`` (new length - 1),
      * derives ``slot_mapping`` from ``block_tables``.
    The remaining ``num_seqs - num_queries`` entries must be untouched.
    """
    if num_seqs < num_queries:
        # advance_step assumes every query maps to an existing sequence.
        pytest.skip(
            f"Skipping invalid case since num_seqs({num_seqs}) "
            f"is smaller than num_queries({num_queries})."
        )

    def torch_impl(input_tokens: torch.Tensor,
                   sampled_token_ids: torch.Tensor,
                   input_positions: torch.Tensor,
                   seq_lens: torch.Tensor,
                   slot_mapping: torch.Tensor,
                   block_tables: torch.Tensor,
                   num_seqs: int,
                   num_queries: int,
                   block_size: int,
                   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reference implementation; returns updated copies of the four
        in/out tensors (tokens, seq_lens, positions, slot_mapping)."""
        # Get updated input_tokens: the first num_queries entries take the
        # freshly sampled token ids; the tail is left as-is.
        sampled_token_ids = sampled_token_ids[:num_queries]
        torch_input_tokens = torch.clone(input_tokens)
        torch_input_tokens[:num_queries] = sampled_token_ids
        # Get updated seq_lens: each active sequence grew by one token.
        torch_seq_lens = torch.clone(seq_lens)
        torch_seq_lens[:num_queries] += 1
        # Get updated input_positions: next position is the new length - 1.
        torch_input_positions = torch.clone(input_positions)
        torch_input_positions[:num_queries] = torch_seq_lens[:num_queries] - 1
        # Get updated slot_mapping: slot = physical_block * block_size + offset.
        torch_slot_mapping = torch.clone(slot_mapping)
        block_index = torch_input_positions[:num_queries] // block_size
        block_offset = torch_input_positions[:num_queries] % block_size
        # NOTE(review): this indexes column 0 of block_tables rather than
        # gathering per-row at block_index. That is only equivalent because
        # this test starts all seq_lens at 1, so every new position lands in
        # logical block 0 (position 1 < block_size) — confirm before reusing
        # this reference for longer sequences.
        indices = [slice(0, num_queries)] + [0] * (block_tables.ndim - 1)
        intermediate_block_table = block_tables[tuple(indices)]
        slot_num = intermediate_block_table * block_size + block_offset
        torch_slot_mapping[:num_queries] = slot_num
        return (torch_input_tokens, torch_seq_lens, torch_input_positions, torch_slot_mapping)

    # --- Build inputs on the device under test. ---
    block_tables_inner_size = 2  # blocks per sequence in the (num_seqs, 2) table
    input_tokens = torch.zeros(num_seqs, dtype=torch.int64, device=device)
    sampled_token_ids = torch.arange(num_queries, dtype=torch.int64, device=device)
    input_positions = torch.empty(num_seqs, dtype=torch.int32, device=device)
    seq_lens = torch.ones(num_seqs, dtype=torch.int32, device=device)
    slot_mapping = torch.empty(num_seqs, dtype=torch.int32, device=device)
    block_tables = torch.arange(
        num_seqs * block_tables_inner_size,
        dtype=torch.int32,
        device=device
    ).view(num_seqs, block_tables_inner_size)

    # Reference result (out-of-place copies).
    torch_input_tokens, torch_seq_lens, torch_input_positions, torch_slot_mapping = torch_impl(
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        num_seqs,
        num_queries,
        block_size
    )

    # Kernel under test mutates the tensors in place; sampled ids are
    # passed as a (num_queries, 1) column as the op expects.
    mlu_ops.advance_step(
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids.view(-1, 1),
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        TILE_SIZE=TILE_SIZE
    )

    # All compared tensors are integer dtype, so require exact equality.
    assert torch.equal(torch_input_tokens, input_tokens)
    assert torch.equal(torch_seq_lens, seq_lens)
    assert torch.equal(torch_input_positions, input_positions)
    assert torch.equal(torch_slot_mapping, slot_mapping)