v0.10.1rc1
This commit is contained in:
158
benchmarks/ops/ben_vocabparallelembedding.py
Normal file
158
benchmarks/ops/ben_vocabparallelembedding.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
import vllm # noqa: F401
|
||||
|
||||
import vllm_ascend.platform # noqa: F401
|
||||
|
||||
|
||||
def benchmark_npu(fn, num_iterations=100, num_warmup_iterations=50):
|
||||
"""
|
||||
Benchmark function for NPU operations
|
||||
|
||||
Args:
|
||||
fn: Function to benchmark
|
||||
num_iterations: Number of timing iterations
|
||||
num_warmup_iterations: Number of warmup iterations
|
||||
|
||||
Returns:
|
||||
float: Minimum elapsed time in seconds
|
||||
"""
|
||||
start = torch.npu.Event(enable_timing=True)
|
||||
end = torch.npu.Event(enable_timing=True)
|
||||
times = np.zeros(num_iterations + num_warmup_iterations)
|
||||
|
||||
# Run iterations
|
||||
for i in range(num_warmup_iterations + num_iterations):
|
||||
with torch.no_grad():
|
||||
start.record()
|
||||
fn() # Execute the function
|
||||
end.record()
|
||||
torch.npu.synchronize()
|
||||
times[i] = start.elapsed_time(end)
|
||||
|
||||
# Remove warmup iterations and convert to seconds
|
||||
times = times[num_warmup_iterations:]
|
||||
elapsed_time = np.amin(times) / 1000
|
||||
return elapsed_time
|
||||
|
||||
|
||||
def get_masked_input_and_mask_ref(
|
||||
input_: torch.Tensor,
|
||||
org_vocab_start_index: int,
|
||||
org_vocab_end_index: int,
|
||||
num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Reference implementation for verification"""
|
||||
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
|
||||
added_vocab_mask = (input_ >= added_vocab_start_index) & (
|
||||
input_ < added_vocab_end_index
|
||||
)
|
||||
added_offset = (
|
||||
added_vocab_start_index
|
||||
- (org_vocab_end_index - org_vocab_start_index)
|
||||
- num_org_vocab_padding
|
||||
)
|
||||
valid_offset = (org_vocab_start_index * org_vocab_mask) + (
|
||||
added_offset * added_vocab_mask
|
||||
)
|
||||
vocab_mask = org_vocab_mask | added_vocab_mask
|
||||
masked_input = vocab_mask * (input_ - valid_offset)
|
||||
return masked_input, ~vocab_mask
|
||||
|
||||
|
||||
DTYPES = [torch.int32]
|
||||
SHAPES = [(3, 4, 5)]
|
||||
DEVICES = [f"npu:{0}"]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_get_masked_input_and_mask(
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
seed: int,
|
||||
) -> None:
|
||||
# Set random seed and device
|
||||
torch.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Generate random input tensor
|
||||
input_tensor = torch.randint(0, 1000, shape, dtype=dtype)
|
||||
|
||||
# Test parameters
|
||||
test_case = {
|
||||
"org_start": 100,
|
||||
"org_end": 200,
|
||||
"padding": 0,
|
||||
"added_start": 300,
|
||||
"added_end": 400,
|
||||
}
|
||||
|
||||
# Define reference function
|
||||
def ref_fn():
|
||||
return get_masked_input_and_mask_ref(
|
||||
input_tensor,
|
||||
test_case["org_start"],
|
||||
test_case["org_end"],
|
||||
test_case["padding"],
|
||||
test_case["added_start"],
|
||||
test_case["added_end"],
|
||||
)
|
||||
|
||||
# Define custom function
|
||||
def custom_fn():
|
||||
return torch.ops._C.get_masked_input_and_mask(
|
||||
input_tensor,
|
||||
test_case["org_start"],
|
||||
test_case["org_end"],
|
||||
test_case["padding"],
|
||||
test_case["added_start"],
|
||||
test_case["added_end"],
|
||||
)
|
||||
|
||||
# Get results for correctness testing
|
||||
ref_masked_input, ref_mask = ref_fn()
|
||||
custom_masked_input, custom_mask = custom_fn()
|
||||
|
||||
# Benchmark both implementations
|
||||
ref_time = benchmark_npu(ref_fn)
|
||||
custom_time = benchmark_npu(custom_fn)
|
||||
|
||||
# Print performance results
|
||||
print("\nPerformance Results:")
|
||||
print(f"Reference implementation: {ref_time * 1000:.3f} ms")
|
||||
print(f"Custom implementation: {custom_time * 1000:.3f} ms")
|
||||
print(f"Speedup: {ref_time / custom_time:.2f}x")
|
||||
|
||||
# Compare results for correctness
|
||||
ref_masked_input = ref_masked_input.to(dtype)
|
||||
print("\nResults comparison:")
|
||||
print("custom_masked_input:", custom_masked_input)
|
||||
print("ref_masked_input:", ref_masked_input)
|
||||
print("custom_mask:", custom_mask)
|
||||
print("ref_mask:", ref_mask)
|
||||
torch.testing.assert_close(
|
||||
custom_masked_input,
|
||||
ref_masked_input,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"Masked input mismatch for case: {test_case}",
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
custom_mask,
|
||||
ref_mask,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"Mask mismatch for case: {test_case}",
|
||||
)
|
||||
Reference in New Issue
Block a user