### What this PR does / why we need it?
Replace the function that gets the AI Vector core count with a
glibc-ABI-free implementation. Once this PR is merged, there should be
no glibc ABI problems when bumping the torch version to 2.7.1.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.9.2
- vLLM main: f59ec35b7f
Signed-off-by: leo-pony <nengjunma@outlook.com>
95 lines
3.1 KiB
Python
95 lines
3.1 KiB
Python
from typing import Tuple
|
|
|
|
import pytest
|
|
import torch
|
|
import torch_npu # noqa: F401
|
|
|
|
import vllm_ascend.platform # noqa: F401
|
|
from vllm_ascend.utils import enable_custom_op
|
|
|
|
# Called at import time, so it has already run when the test below reaches
# torch.ops._C.get_masked_input_and_mask.
# NOTE(review): presumably this loads/registers vllm_ascend's custom kernels
# under torch.ops._C -- confirm against vllm_ascend.utils.
enable_custom_op()
|
|
|
|
# Parametrization grid consumed by the pytest.mark.parametrize decorators on
# the test in this module.
DTYPES = [torch.int32]
# Only one shape is exercised. Shapes previously listed here but disabled:
# (100,), (5, 20), (3, 4, 5), (3, 4, 8).
SHAPES = [(3, 4, 3)]
DEVICES = ["npu:0"]
SEEDS = [0]
|
|
|
|
|
|
def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the expected op output with plain tensor arithmetic.

    Ids in ``[org_vocab_start_index, org_vocab_end_index)`` are shifted down
    by ``org_vocab_start_index``. Ids in
    ``[added_vocab_start_index, added_vocab_end_index)`` are shifted down so
    that they land right after the original range plus
    ``num_org_vocab_padding`` slots. Every other id maps to 0.

    Returns:
        A pair ``(remapped_ids, out_of_vocab)`` where ``out_of_vocab`` is a
        boolean tensor that is True wherever the id fell in neither range.
    """
    # Membership of each id in the two half-open ranges.
    in_org = input_.ge(org_vocab_start_index) & input_.lt(org_vocab_end_index)
    in_added = input_.ge(added_vocab_start_index) & input_.lt(
        added_vocab_end_index)

    # Amount to subtract from an "added" id so it follows the original span
    # and its padding.
    org_span = org_vocab_end_index - org_vocab_start_index
    added_shift = added_vocab_start_index - org_span - num_org_vocab_padding

    # Per-element shift: the masks select which (if any) offset applies.
    shift = in_org * org_vocab_start_index + in_added * added_shift

    in_vocab = in_org | in_added
    # Multiplying by the mask zeroes every id outside both ranges.
    remapped = in_vocab * (input_ - shift)
    return remapped, ~in_vocab
|
|
|
|
|
|
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    """Compare ``torch.ops._C.get_masked_input_and_mask`` with the reference.

    A random tensor of ids in ``[0, 1000)`` is created on ``device`` and fed,
    with one fixed set of vocab boundaries, to both
    ``get_masked_input_and_mask_ref`` and the custom op. Both outputs (the
    remapped ids and the boolean mask) must match exactly.

    Args:
        shape: Shape of the random id tensor.
        dtype: Integer dtype of the id tensor.
        device: Device string the tensor is allocated on.
        seed: Value passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)

    # Allocate directly on the target device. torch.set_default_device() would
    # change process-wide state and leak into every test that runs afterwards
    # in the same session.
    input_tensor = torch.randint(0, 1000, shape, dtype=dtype, device=device)

    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }
    # Both implementations take the same positional arguments.
    op_args = (
        input_tensor,
        test_case["org_start"],
        test_case["org_end"],
        test_case["padding"],
        test_case["added_start"],
        test_case["added_end"],
    )

    ref_masked_input, ref_mask = get_masked_input_and_mask_ref(*op_args)
    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
        *op_args)

    # The reference promotes to a wider integer type through its mask
    # arithmetic; cast back so assert_close's dtype check compares like with
    # like.
    ref_masked_input = ref_masked_input.to(dtype)

    # Outputs are integer / boolean, so require exact equality. A callable
    # `msg` keeps the mismatch details assert_close generates instead of
    # replacing them with a bare string.
    torch.testing.assert_close(
        custom_masked_input,
        ref_masked_input,
        rtol=0.0,
        atol=0.0,
        msg=lambda detail:
        f"Masked input mismatch for case: {test_case}\n{detail}",
    )
    torch.testing.assert_close(
        custom_mask,
        ref_mask,
        rtol=0.0,
        atol=0.0,
        msg=lambda detail: f"Mask mismatch for case: {test_case}\n{detail}",
    )
|