diff --git a/CMakeLists.txt b/CMakeLists.txt
index a2c3ad2..8d06c75 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -96,5 +96,3 @@ target_link_libraries(
 target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
 
 install(TARGETS vllm_ascend_C vllm_ascend_kernels DESTINATION ${VLLM_ASCEND_INSTALL_PATH})
-
-
diff --git a/benchmarks/ops/ben_vocabparallelembedding.py b/benchmarks/ops/ben_vocabparallelembedding.py
new file mode 100644
index 0000000..e91cfed
--- /dev/null
+++ b/benchmarks/ops/ben_vocabparallelembedding.py
@@ -0,0 +1,144 @@
+from typing import Tuple
+
+import numpy as np
+import pytest
+import torch
+import torch_npu  # noqa: F401
+import vllm  # noqa: F401
+
+import vllm_ascend.platform  # noqa: F401
+
+
+def benchmark_npu(fn, num_iterations=100, num_warmup_iterations=50):
+    """
+    Benchmark function for NPU operations.
+
+    Args:
+        fn: Function to benchmark
+        num_iterations: Number of timing iterations
+        num_warmup_iterations: Number of warmup iterations
+
+    Returns:
+        float: Minimum elapsed time in seconds
+    """
+    start = torch.npu.Event(enable_timing=True)
+    end = torch.npu.Event(enable_timing=True)
+    times = np.zeros(num_iterations + num_warmup_iterations)
+
+    # Run iterations
+    for i in range(num_warmup_iterations + num_iterations):
+        with torch.no_grad():
+            start.record()
+            fn()  # Execute the function
+            end.record()
+        torch.npu.synchronize()
+        times[i] = start.elapsed_time(end)
+
+    # Remove warmup iterations and convert to seconds
+    times = times[num_warmup_iterations:]
+    elapsed_time = np.amin(times) / 1000
+    return elapsed_time
+
+
+def get_masked_input_and_mask_ref(
+        input_: torch.Tensor, org_vocab_start_index: int,
+        org_vocab_end_index: int, num_org_vocab_padding: int,
+        added_vocab_start_index: int,
+        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Reference implementation for verification"""
+    org_vocab_mask = (input_ >= org_vocab_start_index) & (
+        input_ < org_vocab_end_index)
+    added_vocab_mask = (input_ >= added_vocab_start_index) & (
+        input_ < added_vocab_end_index)
+    added_offset = added_vocab_start_index - (
+        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
+    valid_offset = (org_vocab_start_index *
+                    org_vocab_mask) + (added_offset * added_vocab_mask)
+    vocab_mask = org_vocab_mask | added_vocab_mask
+    masked_input = vocab_mask * (input_ - valid_offset)
+    return masked_input, ~vocab_mask
+
+
+DTYPES = [torch.int32]
+SHAPES = [(3, 4, 5)]
+DEVICES = [f"npu:{0}"]
+SEEDS = [0]
+
+
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_get_masked_input_and_mask(
+    shape: Tuple[int, ...],
+    dtype: torch.dtype,
+    device: str,
+    seed: int,
+) -> None:
+    # Set random seed and device
+    torch.manual_seed(seed)
+    torch.set_default_device(device)
+
+    # Generate random input tensor
+    input_tensor = torch.randint(0, 1000, shape, dtype=dtype)
+
+    # Test parameters
+    test_case = {
+        "org_start": 100,
+        "org_end": 200,
+        "padding": 0,
+        "added_start": 300,
+        "added_end": 400,
+    }
+
+    # Define reference function
+    def ref_fn():
+        return get_masked_input_and_mask_ref(input_tensor,
+                                             test_case["org_start"],
+                                             test_case["org_end"],
+                                             test_case["padding"],
+                                             test_case["added_start"],
+                                             test_case["added_end"])
+
+    # Define custom function
+    def custom_fn():
+        return torch.ops._C.get_masked_input_and_mask(input_tensor,
+                                                      test_case["org_start"],
+                                                      test_case["org_end"],
+                                                      test_case["padding"],
+                                                      test_case["added_start"],
+                                                      test_case["added_end"])
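+
+    # Both closures return (masked_input, mask); they are compared for
+    # correctness first, then timed with benchmark_npu(), which reports the
+    # minimum over the timed iterations, in seconds.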
test_case["padding"], + test_case["added_start"], + test_case["added_end"]) + + # Get results for correctness testing + ref_masked_input, ref_mask = ref_fn() + custom_masked_input, custom_mask = custom_fn() + + # Benchmark both implementations + ref_time = benchmark_npu(ref_fn) + custom_time = benchmark_npu(custom_fn) + + # Print performance results + print("\nPerformance Results:") + print(f"Reference implementation: {ref_time*1000:.3f} ms") + print(f"Custom implementation: {custom_time*1000:.3f} ms") + print(f"Speedup: {ref_time/custom_time:.2f}x") + + # Compare results for correctness + ref_masked_input = ref_masked_input.to(dtype) + print("\nResults comparison:") + print("custom_masked_input:", custom_masked_input) + print("ref_masked_input:", ref_masked_input) + print("custom_mask:", custom_mask) + print("ref_mask:", ref_mask) + torch.testing.assert_close( + custom_masked_input, + ref_masked_input, + rtol=1e-5, + atol=1e-5, + msg=f"Masked input mismatch for case: {test_case}") + torch.testing.assert_close(custom_mask, + ref_mask, + rtol=1e-5, + atol=1e-5, + msg=f"Mask mismatch for case: {test_case}") diff --git a/csrc/kernels/get_masked_input_and_mask_kernel.cpp b/csrc/kernels/get_masked_input_and_mask_kernel.cpp new file mode 100644 index 0000000..47ce826 --- /dev/null +++ b/csrc/kernels/get_masked_input_and_mask_kernel.cpp @@ -0,0 +1,345 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + */ + +#include "kernel_operator.h" +#include "kernel_tensor_impl.h" +#include "kernel_type.h" +#include "types.h" +#include "utils.h" +using vllm_ascend::AccType; + +template +class GetMaskedInputAndMask { +public: + __aicore__ inline GetMaskedInputAndMask() {} + + __aicore__ inline ~GetMaskedInputAndMask() { + pipe.Reset(); + } + + + __aicore__ inline void Init( + __gm__ scalar_t* input, + __gm__ scalar_t* masked_input, + __gm__ bool* mask_out, + const int64_t org_vocab_start_index, + const int64_t org_vocab_end_index, + const int64_t num_org_vocab_padding, + const int64_t added_vocab_start_index, + const int64_t added_vocab_end_index, + const int64_t size) + { + // Initialize basic parameters + input_ = input; + masked_input_ = masked_input; + mask_out_ = mask_out; + org_vocab_start_index_ = org_vocab_start_index; + org_vocab_end_index_ = org_vocab_end_index; + size_ = ((size + 31) / 32) * 32; + added_offset_ = added_vocab_start_index - + (org_vocab_end_index - org_vocab_start_index) - + num_org_vocab_padding; + added_vocab_start_index_ = added_vocab_start_index; + added_vocab_end_index_ = added_vocab_end_index; + + // Initialize global tensors + inputGlobal.SetGlobalBuffer(input); + maskedOutputGlobal.SetGlobalBuffer(masked_input); + maskOutGlobal.SetGlobalBuffer(mask_out); + + // Initialize queues + pipe.InitBuffer(inQueue, 1, size_ * sizeof(scalar_t)); + pipe.InitBuffer(outQueue, 1, size_ * sizeof(scalar_t)); + pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool)); + + // Initialize calculation buffers + pipe.InitBuffer(calc_buf_1, size_ * sizeof(float)); + pipe.InitBuffer(calc_buf_2, size_ * sizeof(float)); + + // Initialize result queues + pipe.InitBuffer(result_ge_que, BUFFER_NUM, size_ * sizeof(float)); + pipe.InitBuffer(result_le_que, BUFFER_NUM, size_ * sizeof(float)); + pipe.InitBuffer(result_org_mask_que, BUFFER_NUM, size_ * sizeof(float)); + pipe.InitBuffer(result_add_mask_que, BUFFER_NUM, size_ * sizeof(float)); + + // Initialize temporary buffers + pipe.InitBuffer(start_buf, size_ * sizeof(float)); + pipe.InitBuffer(end_buf, size_ * sizeof(float)); + 
+
+    __aicore__ inline void Process()
+    {
+        CopyIn();
+        Compute();
+        CopyOut();
+    }
+
+private:
+    __aicore__ inline void CopyIn()
+    {
+        AscendC::LocalTensor<scalar_t> inputLocal = inQueue.AllocTensor<scalar_t>();
+        AscendC::DataCopy(inputLocal, inputGlobal, size_);
+        inQueue.EnQue(inputLocal);
+    }
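+
+    // CompareWithValue builds an element-wise indicator out of pure
+    // arithmetic (no direct vector compare is used here):
+    //   d = relu(input - compare_value)            (Max + Sub + Abs)
+    //   s = min(d, 2^-126) * 2^50 * 2^50 * 2^26    (Mins + three Muls)
+    // Any positive d saturates to s == 1.0f (2^-126 * 2^126 == 1), while
+    // d == 0 stays 0; Adds(-1) followed by Abs flips this, so the result is
+    // 1.0f where input <= compare_value and 0.0f elsewhere.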
+    __aicore__ inline void CompareWithValue(
+        AscendC::LocalTensor<int8_t>& result,
+        const AscendC::LocalTensor<float>& input,
+        const AscendC::LocalTensor<float>& compare_value,
+        bool is_greater_equal) {
+
+        AscendC::LocalTensor<float> compute_buf = calc_buf_1.Get<float>();
+        if (is_greater_equal) {
+            AscendC::Max(compute_buf, input, compare_value, size_);
+            AscendC::Sub(compute_buf, compare_value, compute_buf, size_);
+        } else {
+            AscendC::Max(compute_buf, input, compare_value, size_);
+            AscendC::Sub(compute_buf, compute_buf, compare_value, size_);
+        }
+
+        AscendC::Abs(compute_buf, compute_buf, size_);
+        AscendC::Mins(compute_buf, compute_buf, MIN_ACCURACY_FP32, size_);
+        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
+        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
+        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_2_FP32, size_);
+        AscendC::Adds(compute_buf, compute_buf, NEGATIVE_ONE_FP32, size_);
+        AscendC::Abs(compute_buf, compute_buf, size_);
+
+        AscendC::LocalTensor<half> compute_buf_fp16 = calc_buf_2.Get<half>();
+        AscendC::Cast(compute_buf_fp16, compute_buf, AscendC::RoundMode::CAST_NONE, size_);
+        AscendC::Cast(result, compute_buf_fp16, AscendC::RoundMode::CAST_NONE, size_);
+    }
+
+    __aicore__ inline void ComputeRangeMask(
+        AscendC::LocalTensor<int8_t>& range_mask,
+        const AscendC::LocalTensor<float>& input,
+        const float start_value,
+        const float end_value) {
+
+        // Use already initialized buffers
+        AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
+        AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();
+
+        AscendC::Duplicate(start_value_tensor, start_value, size_);
+        AscendC::Duplicate(end_value_tensor, end_value, size_);
+
+        AscendC::LocalTensor<int8_t> ge_result = result_ge_que.AllocTensor<int8_t>();
+        AscendC::LocalTensor<int8_t> lt_result = result_le_que.AllocTensor<int8_t>();
+
+        CompareWithValue(ge_result, start_value_tensor, input, true);
+        CompareWithValue(lt_result, input, end_value_tensor, false);
+
+        AscendC::And(range_mask, ge_result, lt_result, size_);
+    }
+
+    __aicore__ inline void Compute() {
+        AscendC::LocalTensor<scalar_t> inputLocal = inQueue.DeQue<scalar_t>();
+        AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.AllocTensor<scalar_t>();
+        AscendC::LocalTensor<bool> maskLocal = maskQueue.AllocTensor<bool>();
+
+        AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
+        AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);
+
+        // Calculate mask for org_vocab range
+        // org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
+        AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
+        ComputeRangeMask(orgVocabMask,
+                         inputFloat,
+                         static_cast<float>(org_vocab_start_index_),
+                         static_cast<float>(org_vocab_end_index_));
+
+        // Calculate mask for added_vocab range
+        // added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
+        AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
+        ComputeRangeMask(addedVocabMask,
+                         inputFloat,
+                         static_cast<float>(added_vocab_start_index_),
+                         static_cast<float>(added_vocab_end_index_));
+
+        // Calculate validOffset
+        // valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask)
+        AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
+        AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();
+
+        AscendC::Duplicate(constOrgStartIndex, float(org_vocab_start_index_), size_);
+
+        AscendC::LocalTensor<half> orgVocabMask_fp16;
+        AscendC::LocalTensor<float> orgVocabMask_fp32;
+        AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
+        AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
+
+        AscendC::Mul(validOffset,
+                     constOrgStartIndex,
+                     orgVocabMask_fp32,
+                     size_);
+
+        AscendC::LocalTensor<float> addedOffset;
+        AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
+        AscendC::Duplicate(addedOffsetTensor, float(added_offset_), size_);
+
+        AscendC::LocalTensor<half> addedVocabMask_fp16;
+        AscendC::LocalTensor<float> addedVocabMask_fp32;
+        AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
+        AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
+
+        AscendC::Mul(addedOffset,
+                     addedOffsetTensor,
+                     addedVocabMask_fp32,
+                     size_);
+
+        AscendC::Add(validOffset, validOffset, addedOffset, size_);
+
+        // vocab_mask = org_vocab_mask | added_vocab_mask
+        AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();
+
+        AscendC::Or(vocabMask,
+                    orgVocabMask,
+                    addedVocabMask,
+                    size_);
+
+        AscendC::Sub(inputFloat, inputFloat, validOffset, size_);
+
+        // input_ = vocab_mask * (input_ - valid_offset)
+        AscendC::LocalTensor<half> vocabMask_fp16;
+        AscendC::LocalTensor<float> vocabMask_fp32;
+        AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
+        AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
+
+        AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);
+
+        AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
+        outQueue.EnQue(maskedLocal);
+
+        // ~vocab_mask
+        AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
+        AscendC::Duplicate(ones_tensor, (float)1, size_);
+        AscendC::LocalTensor<float> maskLocal_fp32;
+
+        AscendC::Sub(maskLocal_fp32,
+                     ones_tensor,
+                     vocabMask_fp32,
+                     size_);
+
+        AscendC::LocalTensor<half> maskLocal_fp16;
+        AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
+        AscendC::Cast(maskLocal, maskLocal_fp16, AscendC::RoundMode::CAST_NONE, size_);
+        maskQueue.EnQue(maskLocal);
+        inQueue.FreeTensor(inputLocal);
+    }
+
+    __aicore__ inline void CopyOut()
+    {
+        AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.DeQue<scalar_t>();
+        AscendC::LocalTensor<bool> maskLocal = maskQueue.DeQue<bool>();
+
+        AscendC::DataCopy(maskedOutputGlobal, maskedLocal, size_);
+        AscendC::DataCopy(maskOutGlobal, maskLocal, size_);
+
+        outQueue.FreeTensor(maskedLocal);
+        maskQueue.FreeTensor(maskLocal);
+    }
+
+private:
+    static constexpr int32_t BUFFER_NUM = 2;
+    AscendC::TPipe pipe;
+    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueue;
+    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue, maskQueue;
+    AscendC::GlobalTensor<scalar_t> inputGlobal, maskedOutputGlobal;
+    AscendC::GlobalTensor<bool> maskOutGlobal;
+    AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_1;
+    AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_2;
+    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_ge_que;
+    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_le_que;
+    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_org_mask_que;
+    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_add_mask_que;
+
+    // Temporary buffers
+    AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
+    AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;
+
+    // Temporary buffers continued
+    AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
+    AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
+    AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
+    AscendC::TBuf<AscendC::TPosition::VECCALC> ones_buf_;
+
+    __gm__ scalar_t *input_, *masked_input_;
+    __gm__ bool *mask_out_;
+    int64_t size_;
+    int64_t org_vocab_start_index_, org_vocab_end_index_;
+    int64_t added_vocab_start_index_, added_vocab_end_index_;
+    int64_t added_offset_;
+
+    static constexpr float MIN_ACCURACY_FP32 = 1.1754943508222875e-38;
+    static constexpr float MAX_MUL_1_FP32 = 1125899906842624;
+    static constexpr float MAX_MUL_2_FP32 = 67108864;
+    static constexpr float NEGATIVE_ONE_FP32 = -1.0f;
+};
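+
+// Each AI vector (AIV) core walks the input in a grid-stride loop: block i
+// handles the i-th chunk of size/loop_cnt elements. The pointer arithmetic
+// assumes loop_cnt divides size evenly; the host-side launcher is
+// responsible for choosing loop_cnt accordingly.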
+
+extern "C" __global__ __aicore__ void get_masked_input_and_mask_kernel(
+    __gm__ int32_t* input,
+    __gm__ int32_t* masked_input,
+    __gm__ bool* mask_out,
+    const int64_t org_vocab_start_index,
+    const int64_t org_vocab_end_index,
+    const int64_t num_org_vocab_padding,
+    const int64_t added_vocab_start_index,
+    const int64_t added_vocab_end_index,
+    const int64_t size,
+    const uint32_t loop_cnt,
+    const uint32_t aiv_num)
+{
+    {
+        GetMaskedInputAndMask<int32_t> op{};
+
+        for (int64_t i = AscendC::GetBlockIdx(); i < loop_cnt; i += aiv_num) {
+            op.Init(input + i * size/loop_cnt,
+                    masked_input + i * size/loop_cnt,
+                    mask_out + i * size/loop_cnt,
+                    org_vocab_start_index, org_vocab_end_index,
+                    num_org_vocab_padding, added_vocab_start_index,
+                    added_vocab_end_index, size/loop_cnt);
+
+            op.Process();
+        }
+    }  // op destructor called here
+}
+
+namespace vllm_ascend {
+
+void get_masked_input_and_mask_impl(
+    void* stream,
+    void* input,
+    void* masked_input,
+    void* mask_out,
+    const int64_t org_vocab_start_index,
+    const int64_t org_vocab_end_index,
+    const int64_t num_org_vocab_padding,
+    const int64_t added_vocab_start_index,
+    const int64_t added_vocab_end_index,
+    const int64_t size,
+    const uint32_t loop_cnt,
+    const uint32_t aiv_num)
+{
+    get_masked_input_and_mask_kernel<<<aiv_num, nullptr, stream>>>(
+        static_cast<int32_t*>(input),
+        static_cast<int32_t*>(masked_input),
+        static_cast<bool*>(mask_out),
+        org_vocab_start_index,
+        org_vocab_end_index,
+        num_org_vocab_padding,
+        added_vocab_start_index,
+        added_vocab_end_index,
+        size,
+        loop_cnt,
+        aiv_num);
+}
+
+}  // namespace vllm_ascend
diff --git a/csrc/ops.h b/csrc/ops.h
index b921b2b..b1bc602 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -31,6 +31,20 @@ namespace vllm_ascend {
                           const int headSize, const int64_t numTokens,
                           const uint32_t loopCnt, uint32_t aivNum);
 
+    extern void get_masked_input_and_mask_impl(
+        void* stream,
+        void* input,
+        void* masked_input,
+        void* mask_out,
+        const int64_t org_vocab_start_index,
+        const int64_t org_vocab_end_index,
+        const int64_t num_org_vocab_padding,
+        const int64_t added_vocab_start_index,
+        const int64_t added_vocab_end_index,
+        const int64_t size,
+        const uint32_t loop_cnt,
+        const uint32_t aiv_num);
+
     torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
         if (!tensor.is_privateuseone()) {
             throw std::runtime_error("Tensor must be on NPU device");
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index c415438..001a931 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -99,6 +99,112 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     return {query_dst, key_dst};
 }
 
+std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
+    at::Tensor &input,
+    const int64_t org_vocab_start_index,
+    const int64_t org_vocab_end_index,
+    const int64_t num_org_vocab_padding,
+    const int64_t added_vocab_start_index,
+    const int64_t added_vocab_end_index)
+    /*
+    https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L161-L198
+    Embedding parallelized in the vocabulary dimension.
+
+    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
+    make sure it is divisible by the number of model parallel GPUs.
+
+    In order to support various loading methods, we ensure that LoRA-added
+    embeddings are always at the end of TP-sharded tensors. In other words,
+    we shard base embeddings and LoRA embeddings separately (both padded),
+    and place them in the same tensor.
+    In this example, we will have the original vocab size = 1010,
+    added vocab size = 16 and padding to 64. Therefore, the total
+    vocab size with padding will be 1088 (because we first pad 1010 to
+    1024, add 16, and then pad to 1088).
+    Therefore, the tensor format looks like the following:
+    TP1, rank 0 (no sharding):
+                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
+    corresponding token_id: | 0 | 1 | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1015 |  -1  | ... |  -1  |
+                     index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
+
+    TP2, rank 0:
+                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
+    corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 |  -1 | ... |  -1 |
+                     index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 |  512 | ... |  527 | 528 | ... | 543 |
+    TP2, rank 1:
+                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
+    corresponding token_id: | 512 | 513 | 514 | ... | 1009 |  -1 | ... |  -1 |  -1 | ... |  -1 |  -1 | ... |  -1 |
+                     index: |   0 |   1 |   2 | ... |  497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
+    Parameters:
+        org_vocab_start_index    // base embeddings start
+        org_vocab_end_index      // base embeddings end
+        num_org_vocab_padding    // base embeddings padding
+        added_vocab_start_index  // LoRA embeddings start
+        added_vocab_end_index    // LoRA embeddings end
+    */
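+    /*
+    Worked example (mirrors the parameters used in the unit test below):
+    with org_vocab = [100, 200), added_vocab = [300, 400) and padding 0,
+    added_offset = 300 - (200 - 100) - 0 = 200, so:
+        token 150 -> masked_input 50  (150 - 100),    mask = false
+        token 350 -> masked_input 150 (350 - 200),    mask = false
+        token 250 -> masked_input 0   (out of range), mask = true
+    */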
+{
+    // Input validation
+    TORCH_CHECK(input.dim() >= 1, "input must have at least 1 dimension");
+    TORCH_CHECK(org_vocab_start_index >= 0, "org_vocab_start_index must be non-negative");
+    TORCH_CHECK(org_vocab_end_index >= org_vocab_start_index, "org_vocab_end_index must be greater than or equal to org_vocab_start_index");
+    TORCH_CHECK(num_org_vocab_padding >= 0, "num_org_vocab_padding must be non-negative");
+    TORCH_CHECK(added_vocab_start_index >= org_vocab_end_index, "added_vocab_start_index must be greater than or equal to org_vocab_end_index");
+    TORCH_CHECK(added_vocab_end_index >= added_vocab_start_index, "added_vocab_end_index must be greater than or equal to added_vocab_start_index");
+
+    // Get total number of elements
+    int64_t size = input.numel();
+
+    // Create output tensors
+    at::Tensor masked_input = at::empty_like(input);
+    at::Tensor mask = at::empty_like(input).to(at::kBool);
+
+    // Get data pointers
+    void *input_ptr = input.data_ptr();
+    void *masked_input_ptr = masked_input.data_ptr();
+    void *mask_ptr = mask.data_ptr();
+
+    // Get current stream
+    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+
+    // Get scalar type
+    at::ScalarType scalar_type = input.scalar_type();
+
+    // Create and configure OpCommand
+    at_npu::native::OpCommand cmd;
+    cmd.Name("get_masked_input_and_mask");
+    cmd.SetCustomHandler([scalar_type, size, stream,
+                          input_ptr, masked_input_ptr, mask_ptr,
+                          org_vocab_start_index, org_vocab_end_index,
+                          num_org_vocab_padding, added_vocab_start_index,
+                          added_vocab_end_index]() -> int {
+        // Get platform info
+        fe::PlatFormInfos platform_infos;
+        int device_id = 0;
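+        // Query the number of AI vector (AIV) cores on this device; the
+        // launch below spreads loop_cnt chunks of the flattened input
+        // across those cores.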
+        fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
+        uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
+        uint32_t loop_cnt = (size + aivNum - 1) / aivNum;
+
+        // Call implementation
+        get_masked_input_and_mask_impl(
+            stream,
+            input_ptr,
+            masked_input_ptr,
+            mask_ptr,
+            org_vocab_start_index,
+            org_vocab_end_index,
+            num_org_vocab_padding,
+            added_vocab_start_index,
+            added_vocab_end_index,
+            size,
+            loop_cnt,
+            aivNum);
+
+        return 0;
+    });
+    cmd.Run();
+    return {masked_input, mask};
+}
+
 void verify_tensor(std::string const& name, at::Tensor const& t,
                    int64_t const size_0, int64_t const size_1,
                    c10::ScalarType const type) {
@@ -194,6 +300,16 @@ TORCH_LIBRARY_EXPAND(_C, ops)
         " Tensor! key, int head_size,"
         " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
     ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+
+    ops.def(
+        "get_masked_input_and_mask(Tensor input, "
+        " int org_vocab_start_index, "
+        " int org_vocab_end_index, "
+        " int num_org_vocab_padding, "
+        " int added_vocab_start_index, "
+        " int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
+    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
+
     ops.def(
         "advance_step_flashattn_ascendc(int num_seqs, int num_queries, int block_size,"
         " Tensor! input_tokens, Tensor! sampled_token_ids, Tensor! input_positions,"
diff --git a/tests/ops/test_vocabparallelembedding.py b/tests/ops/test_vocabparallelembedding.py
new file mode 100644
index 0000000..97d6c70
--- /dev/null
+++ b/tests/ops/test_vocabparallelembedding.py
@@ -0,0 +1,91 @@
+from typing import Tuple
+
+import pytest
+import torch
+import torch_npu  # noqa: F401
+
+import vllm_ascend.platform  # noqa: F401
+
+# Test parameters
+DTYPES = [torch.int32]
+# SHAPES = [(100,), (5, 20), (3, 4, 5)]  # Various tensor shapes
+# SHAPES = [(3, 4, 8), (3, 4, 5)]  # Various tensor shapes
+SHAPES = [(3, 4, 3)]
+DEVICES = [f"npu:{0}"]
+SEEDS = [0]
+
+
+def get_masked_input_and_mask_ref(
+        input_: torch.Tensor, org_vocab_start_index: int,
+        org_vocab_end_index: int, num_org_vocab_padding: int,
+        added_vocab_start_index: int,
+        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Reference implementation for verification"""
+    org_vocab_mask = (input_ >= org_vocab_start_index) & (
+        input_ < org_vocab_end_index)
+    added_vocab_mask = (input_ >= added_vocab_start_index) & (
+        input_ < added_vocab_end_index)
+    added_offset = added_vocab_start_index - (
+        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
+    valid_offset = (org_vocab_start_index *
+                    org_vocab_mask) + (added_offset * added_vocab_mask)
+    vocab_mask = org_vocab_mask | added_vocab_mask
+    masked_input = vocab_mask * (input_ - valid_offset)
+    return masked_input, ~vocab_mask
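+
+# For the parameters used below (org vocab [100, 200), added vocab
+# [300, 400), padding 0): in-range base ids shift down by 100, LoRA ids
+# shift down by 300 - 100 = 200, and out-of-range ids map to index 0 with
+# mask=True.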
+
+
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_get_masked_input_and_mask(
+    shape: Tuple[int, ...],
+    dtype: torch.dtype,
+    device: str,
+    seed: int,
+) -> None:
+    # Set random seed
+    torch.manual_seed(seed)
+    torch.set_default_device(device)
+
+    # Generate random input tensor
+    input_tensor = torch.randint(0, 1000, shape, dtype=dtype)
+
+    # Test parameters
+    test_case = {
+        "org_start": 100,
+        "org_end": 200,
+        "padding": 0,
+        "added_start": 300,
+        "added_end": 400,
+    }
+
+    # Get reference result
+    ref_masked_input, ref_mask = get_masked_input_and_mask_ref(
+        input_tensor, test_case["org_start"], test_case["org_end"],
+        test_case["padding"], test_case["added_start"], test_case["added_end"])
+
+    # Get custom op result
+    print("input_tensor:", input_tensor)
+    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
+        input_tensor, test_case["org_start"], test_case["org_end"],
+        test_case["padding"], test_case["added_start"], test_case["added_end"])
+
+    ref_masked_input = ref_masked_input.to(dtype)
+    print("custom_masked_input:", custom_masked_input)
+    print("ref_masked_input:", ref_masked_input)
+    print("custom_mask:", custom_mask)
+    print("ref_mask:", ref_mask)
+    # Compare results
+    torch.testing.assert_close(
+        custom_masked_input,
+        ref_masked_input,
+        rtol=1e-5,
+        atol=1e-5,
+        msg=f"Masked input mismatch for case: {test_case}")
+    torch.testing.assert_close(custom_mask,
+                               ref_mask,
+                               rtol=1e-5,
+                               atol=1e-5,
+                               msg=f"Mask mismatch for case: {test_case}")