import pytest
import torch
from typing import Tuple


def ref_advance_step(
    input_tokens: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    seq_lens: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_tables: torch.Tensor,
    num_seqs: int,
    num_queries: int,
    block_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for ``mlu_ops.advance_step``.

    Advances the first ``num_queries`` of ``num_seqs`` decoding sequences by
    one step: appends the sampled token, bumps the sequence length, recomputes
    the position, and derives the new KV-cache slot from the block table.

    Inputs are not mutated; updated clones are returned as
    ``(input_tokens, seq_lens, input_positions, slot_mapping)``.
    """
    # Updated input_tokens: the freshly sampled token for each active query.
    sampled_token_ids = sampled_token_ids[:num_queries]
    out_tokens = torch.clone(input_tokens)
    out_tokens[:num_queries] = sampled_token_ids

    # Updated seq_lens: each active sequence grew by one token.
    out_seq_lens = torch.clone(seq_lens)
    out_seq_lens[:num_queries] += 1

    # Updated input_positions: position of the token appended this step.
    out_positions = torch.clone(input_positions)
    out_positions[:num_queries] = out_seq_lens[:num_queries] - 1

    # Updated slot_mapping: slot = physical_block * block_size + offset.
    # The physical block for query i is block_tables[i, position // block_size]
    # (NOT column 0 — hard-coding column 0 is only correct while every
    # position is still inside the first block).
    out_slot_mapping = torch.clone(slot_mapping)
    block_index = (out_positions[:num_queries] // block_size).long()
    block_offset = out_positions[:num_queries] % block_size
    rows = torch.arange(num_queries, device=block_tables.device)
    physical_blocks = block_tables[rows, block_index]
    out_slot_mapping[:num_queries] = physical_blocks * block_size + block_offset

    return (out_tokens, out_seq_lens, out_positions, out_slot_mapping)


@pytest.mark.parametrize("num_seqs, num_queries", [(20, 17), (17, 20), (256, 224)])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("TILE_SIZE", [8, 64, 256])
@pytest.mark.parametrize("device", ["mlu"])
def test_advance_step(num_seqs, num_queries, block_size, TILE_SIZE, device):
    """Check mlu_ops.advance_step against the pure-PyTorch reference."""
    if num_seqs < num_queries:
        pytest.skip(
            f"Skipping invalid case since num_seqs({num_seqs}) "
            f"is smaller than num_queries({num_queries})."
        )
    # MLU-specific modules are imported lazily (and skipped when absent) so
    # this file can still be collected — and the reference implementation
    # reused — on machines without MLU hardware.
    pytest.importorskip("torch_mlu")
    mlu_ops = pytest.importorskip("vllm._mlu_ops")

    block_tables_inner_size = 2
    input_tokens = torch.zeros(num_seqs, dtype=torch.int64, device=device)
    sampled_token_ids = torch.arange(num_queries, dtype=torch.int64, device=device)
    input_positions = torch.empty(num_seqs, dtype=torch.int32, device=device)
    seq_lens = torch.ones(num_seqs, dtype=torch.int32, device=device)
    slot_mapping = torch.empty(num_seqs, dtype=torch.int32, device=device)
    block_tables = torch.arange(
        num_seqs * block_tables_inner_size, dtype=torch.int32, device=device
    ).view(num_seqs, block_tables_inner_size)

    # Reference result (computed before the in-place kernel call).
    torch_input_tokens, torch_seq_lens, torch_input_positions, torch_slot_mapping = (
        ref_advance_step(
            input_tokens,
            sampled_token_ids,
            input_positions,
            seq_lens,
            slot_mapping,
            block_tables,
            num_seqs,
            num_queries,
            block_size,
        )
    )

    # Kernel under test updates its tensor arguments in place.
    mlu_ops.advance_step(
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids.view(-1, 1),
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        TILE_SIZE=TILE_SIZE,
    )

    # All tensors are integral — compare exactly (torch.equal also checks shape).
    assert torch.equal(torch_input_tokens, input_tokens)
    assert torch.equal(torch_seq_lens, seq_lens)
    assert torch.equal(torch_input_positions, input_positions)
    assert torch.equal(torch_slot_mapping, slot_mapping)