99 lines
3.2 KiB
Python
99 lines
3.2 KiB
Python
import gc
|
|
from typing import Tuple
|
|
|
|
import pytest
|
|
import torch
|
|
import torch_npu # noqa: F401
|
|
|
|
import vllm_ascend.platform # noqa: F401
|
|
from vllm_ascend.utils import enable_custom_op
|
|
|
|
# Runs at import time, i.e. before the test in this module resolves
# torch.ops._C_ascend. Presumably this loads/registers the Ascend custom
# ops -- NOTE(review): confirm against vllm_ascend.utils.enable_custom_op.
enable_custom_op()
|
|
|
|
# Test parameters: each list feeds one pytest.mark.parametrize axis of the
# test in this module.
DTYPES = [torch.int32]
# Only one shape is exercised. (100,), (5, 20), (3, 4, 5) and (3, 4, 8) were
# listed here previously but are disabled -- TODO confirm why before
# re-enabling them.
SHAPES = [(3, 4, 3)]
DEVICES = ["npu:0"]
SEEDS = [0]
|
|
|
|
|
|
def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation used to verify the custom op.

    Every element of ``input_`` is classified against two half-open ranges:

    * ``[org_vocab_start_index, org_vocab_end_index)``: the element is
      shifted down by ``org_vocab_start_index``.
    * ``[added_vocab_start_index, added_vocab_end_index)``: the element is
      shifted so that ``added_vocab_start_index`` maps to
      ``(org_vocab_end_index - org_vocab_start_index) + num_org_vocab_padding``.

    Elements outside both ranges are replaced by 0.

    Args:
        input_: Tensor of ids to remap; its shape is preserved.
        org_vocab_start_index: Inclusive lower bound of the first range.
        org_vocab_end_index: Exclusive upper bound of the first range.
        num_org_vocab_padding: Extra offset inserted between the two
            remapped ranges.
        added_vocab_start_index: Inclusive lower bound of the second range.
        added_vocab_end_index: Exclusive upper bound of the second range.

    Returns:
        ``(masked_input, mask)``. ``mask`` is a bool tensor that is True
        where the element fell outside both ranges. ``masked_input`` takes
        whatever dtype PyTorch type promotion yields for the arithmetic
        below, which can be wider than ``input_.dtype`` (an int32 input
        gives an int64 result), so cast before a same-dtype comparison.
    """
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)

    # Amount to subtract from an "added" id so that it lands directly after
    # the remapped original range plus its padding.
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding

    # Per-element offset: multiplying by the bool masks selects the offset of
    # whichever range the element belongs to (0 if it belongs to neither).
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)

    vocab_mask = org_vocab_mask | added_vocab_mask
    # Multiplying by the combined mask zeroes out-of-range elements.
    masked_input = vocab_mask * (input_ - valid_offset)
    return masked_input, ~vocab_mask
|
|
|
|
|
|
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    """Compare the ``_C_ascend.get_masked_input_and_mask`` op with the reference.

    A random integer tensor with values in ``[0, 1000)`` is run through both
    ``get_masked_input_and_mask_ref`` and the custom op with the same range
    parameters, and both returned tensors are required to match.

    Args:
        shape: Shape of the random input tensor.
        dtype: Dtype of the input tensor; the reference result is cast to it
            before comparison.
        device: Device the tensors are created on.
        seed: Seed passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)

    # Range parameters shared by the reference and the custom op.
    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }
    op_args = (
        test_case["org_start"],
        test_case["org_end"],
        test_case["padding"],
        test_case["added_start"],
        test_case["added_end"],
    )

    try:
        # Scope the default device to this test instead of calling
        # torch.set_default_device(), which would stay in effect for every
        # test that runs afterwards in the same process.
        with torch.device(device):
            # Generate random input tensor
            input_tensor = torch.randint(0, 1000, shape, dtype=dtype)

            # Get reference result
            ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
                input_tensor, *op_args)

            # Get custom op result
            print("input_tensor:", input_tensor)
            custom_masked_input, custom_mask = (
                torch.ops._C_ascend.get_masked_input_and_mask(
                    input_tensor, *op_args))

            # The reference result is cast to the input dtype so both sides
            # of the comparison share one dtype.
            ref_masked_input = ref_masked_input.to(dtype)

            # Dump both sides: a string `msg` replaces assert_close's
            # generated mismatch report, so these prints carry the details.
            print("custom_masked_input:", custom_masked_input)
            print("ref_masked_input:", ref_masked_input)
            print("custom_mask:", custom_mask)
            print("ref_mask:", ref_mask)

            # Compare results
            torch.testing.assert_close(
                custom_masked_input,
                ref_masked_input,
                rtol=1e-5,
                atol=1e-5,
                msg=f"Masked input mismatch for case: {test_case}")
            torch.testing.assert_close(
                custom_mask,
                ref_mask,
                rtol=1e-5,
                atol=1e-5,
                msg=f"Mask mismatch for case: {test_case}")
    finally:
        # Run the cleanup even when an assertion above fails, so a failing
        # case does not skip it.
        gc.collect()
        torch.npu.empty_cache()
        torch.npu.reset_peak_memory_stats()
|