Files
enginex-ascend-910-vllm/tests/e2e/singlecard/ops/test_vocabparallelembedding.py
2025-10-14 10:38:28 +08:00

99 lines
3.2 KiB
Python

import gc
from typing import Tuple
import pytest
import torch
import torch_npu # noqa: F401
import vllm_ascend.platform # noqa: F401
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
# Test parameters
DTYPES = [torch.int32]
#SHAPES = [(100,), (5, 20), (3, 4, 5)] # Various tensor shapes
#SHAPES = [(3, 4, 8), (3, 4, 5)] # Various tensor shapes
SHAPES = [(3, 4, 3)]
DEVICES = [f"npu:{0}"]
SEEDS = [0]
def get_masked_input_and_mask_ref(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Reference implementation for verification"""
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
masked_input = vocab_mask * (input_ - valid_offset)
return masked_input, ~vocab_mask
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
shape: Tuple[int, ...],
dtype: torch.dtype,
device: str,
seed: int,
) -> None:
# Set random seed
torch.manual_seed(seed)
torch.set_default_device(device)
# Generate random input tensor
input_tensor = torch.randint(0, 1000, shape, dtype=dtype)
# Test parameters
test_case = {
"org_start": 100,
"org_end": 200,
"padding": 0,
"added_start": 300,
"added_end": 400,
}
# Get reference result
ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
# Get custom op result
print("input_tensor:", input_tensor)
custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
ref_masked_input = ref_masked_input.to(dtype)
print("custom_masked_input:", custom_masked_input)
print("ref_masked_input:", ref_masked_input)
print("custom_mask:", custom_mask)
print("ref_mask:", ref_mask)
# Compare results
torch.testing.assert_close(
custom_masked_input,
ref_masked_input,
rtol=1e-5,
atol=1e-5,
msg=f"Masked input mismatch for case: {test_case}")
torch.testing.assert_close(custom_mask,
ref_mask,
rtol=1e-5,
atol=1e-5,
msg=f"Mask mismatch for case: {test_case}")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()