import pytest
|
|
import torch
|
|
import torch_mlu
|
|
from vllm import _mlu_ops as mlu_ops
|
|
|
|
from typing import Tuple
|
|
|
|
@pytest.mark.parametrize("num_seqs, num_queries", [(20, 17), (17, 20), (256, 224)])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("TILE_SIZE", [8, 64, 256])
@pytest.mark.parametrize("device", ["mlu"])
def test_advance_step(num_seqs, num_queries, block_size, TILE_SIZE, device):
    """Compare ``mlu_ops.advance_step`` against a pure-PyTorch reference.

    ``advance_step`` fuses the per-decode-step bookkeeping updates done in
    place on the device tensors: append the sampled token to ``input_tokens``,
    bump ``seq_lens``, recompute ``input_positions``, and derive the new
    ``slot_mapping`` from ``block_tables``.
    """
    if num_seqs < num_queries:
        pytest.skip(
            f"Skipping invalid case since num_seqs({num_seqs}) "
            f"is smaller than num_queries({num_queries})."
        )

    def torch_impl(input_tokens: torch.Tensor,
                   sampled_token_ids: torch.Tensor,
                   input_positions: torch.Tensor,
                   seq_lens: torch.Tensor,
                   slot_mapping: torch.Tensor,
                   block_tables: torch.Tensor,
                   num_seqs: int,
                   num_queries: int,
                   block_size: int,
                   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reference implementation of advance_step using plain torch ops.

        Every input tensor is cloned, so the caller's tensors are left
        untouched (the MLU kernel, by contrast, updates them in place).
        Only the first ``num_queries`` rows are advanced.
        """
        # Updated input_tokens: the freshly sampled token becomes the next input.
        sampled_token_ids = sampled_token_ids[:num_queries]
        torch_input_tokens = torch.clone(input_tokens)
        torch_input_tokens[:num_queries] = sampled_token_ids

        # Updated seq_lens: each active sequence grew by one token.
        torch_seq_lens = torch.clone(seq_lens)
        torch_seq_lens[:num_queries] += 1

        # Updated input_positions: zero-based position of the token just appended.
        torch_input_positions = torch.clone(input_positions)
        torch_input_positions[:num_queries] = torch_seq_lens[:num_queries] - 1

        # Updated slot_mapping: physical KV-cache slot of the new position.
        torch_slot_mapping = torch.clone(slot_mapping)
        block_index = torch_input_positions[:num_queries] // block_size
        block_offset = torch_input_positions[:num_queries] % block_size
        # Fix: the original reference indexed block_tables[:, 0] and ignored
        # block_index, which is only correct while every position still fits
        # in a sequence's first block (true here since seq_lens start at 1).
        # Select each query's block through block_index so the reference stays
        # correct for arbitrary positions as well.
        rows = torch.arange(num_queries, device=block_tables.device)
        block_numbers = block_tables[rows, block_index.long()]
        slot_num = block_numbers * block_size + block_offset
        torch_slot_mapping[:num_queries] = slot_num

        return (torch_input_tokens, torch_seq_lens, torch_input_positions,
                torch_slot_mapping)

    # Fixture tensors. seq_lens starts at 1, so after advancing, every new
    # position is 1 and still lives in block 0 of each row's block table.
    block_tables_inner_size = 2
    input_tokens = torch.zeros(num_seqs, dtype=torch.int64, device=device)
    sampled_token_ids = torch.arange(num_queries, dtype=torch.int64, device=device)
    input_positions = torch.empty(num_seqs, dtype=torch.int32, device=device)
    seq_lens = torch.ones(num_seqs, dtype=torch.int32, device=device)
    slot_mapping = torch.empty(num_seqs, dtype=torch.int32, device=device)
    block_tables = torch.arange(
        num_seqs * block_tables_inner_size,
        dtype=torch.int32,
        device=device
    ).view(num_seqs, block_tables_inner_size)

    # Reference result (out of place).
    torch_input_tokens, torch_seq_lens, torch_input_positions, torch_slot_mapping = torch_impl(
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        num_seqs,
        num_queries,
        block_size
    )

    # Kernel under test (in place); sampled_token_ids is expected 2-D (n, 1).
    mlu_ops.advance_step(
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids.view(-1, 1),
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        TILE_SIZE=TILE_SIZE
    )

    assert torch.allclose(torch_input_tokens, input_tokens)
    assert torch.allclose(torch_seq_lens, seq_lens)
    assert torch.allclose(torch_input_positions, input_positions)
    assert torch.allclose(torch_slot_mapping, slot_mapping)
|