Files
xc-llm-ascend/tests/e2e/singlecard/ops/test_vocabparallelembedding.py
wangxiyuan fef18b60bc Refactor e2e CI (#2276)
Refactor E2E CI to make it clear and faster
1. remove some uesless e2e test
2. remove some uesless function
3. Make sure all test runs with VLLMRunner to avoid oom error
4. Make sure all ops test end with torch.empty_cache to avoid oom error
5. run the test one by one to avoid resource limit error


- vLLM version: v0.10.1.1
- vLLM main:
a344a5aa0a

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-09-02 09:02:22 +08:00

99 lines
3.2 KiB
Python

import gc
from typing import Tuple
import pytest
import torch
import torch_npu # noqa: F401
import vllm_ascend.platform # noqa: F401
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
# Test parameters
DTYPES = [torch.int32]
#SHAPES = [(100,), (5, 20), (3, 4, 5)] # Various tensor shapes
#SHAPES = [(3, 4, 8), (3, 4, 5)] # Various tensor shapes
SHAPES = [(3, 4, 3)]
DEVICES = [f"npu:{0}"]
SEEDS = [0]
def get_masked_input_and_mask_ref(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Reference implementation for verification"""
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
masked_input = vocab_mask * (input_ - valid_offset)
return masked_input, ~vocab_mask
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
shape: Tuple[int, ...],
dtype: torch.dtype,
device: str,
seed: int,
) -> None:
# Set random seed
torch.manual_seed(seed)
torch.set_default_device(device)
# Generate random input tensor
input_tensor = torch.randint(0, 1000, shape, dtype=dtype)
# Test parameters
test_case = {
"org_start": 100,
"org_end": 200,
"padding": 0,
"added_start": 300,
"added_end": 400,
}
# Get reference result
ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
# Get custom op result
print("input_tensor:", input_tensor)
custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
input_tensor, test_case["org_start"], test_case["org_end"],
test_case["padding"], test_case["added_start"], test_case["added_end"])
ref_masked_input = ref_masked_input.to(dtype)
print("custom_masked_input:", custom_masked_input)
print("ref_masked_input:", ref_masked_input)
print("custom_mask:", custom_mask)
print("ref_mask:", ref_mask)
# Compare results
torch.testing.assert_close(
custom_masked_input,
ref_masked_input,
rtol=1e-5,
atol=1e-5,
msg=f"Masked input mismatch for case: {test_case}")
torch.testing.assert_close(custom_mask,
ref_mask,
rtol=1e-5,
atol=1e-5,
msg=f"Mask mismatch for case: {test_case}")
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()