add custom ascendc kernel vocabparallelembedding (#796)
This PR adds a custom AscendC kernel implementing VocabParallelEmbedding support in vllm-ascend; the related CMakeLists and setuptools changes are also included in this PR. Test with: pytest -s benchmarks/ops/ben_vocabparallelembedding.py and pytest -s tests/ops/test_vocabparallelembedding.py --------- Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -96,5 +96,3 @@ target_link_libraries(
|
||||
target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
|
||||
|
||||
install(TARGETS vllm_ascend_C vllm_ascend_kernels DESTINATION ${VLLM_ASCEND_INSTALL_PATH})
|
||||
|
||||
|
||||
|
||||
144
benchmarks/ops/ben_vocabparallelembedding.py
Normal file
144
benchmarks/ops/ben_vocabparallelembedding.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
import vllm # noqa: F401
|
||||
|
||||
import vllm_ascend.platform # noqa: F401
|
||||
|
||||
|
||||
def benchmark_npu(fn, num_iterations=100, num_warmup_iterations=50):
    """
    Benchmark function for NPU operations

    Args:
        fn: Function to benchmark
        num_iterations: Number of timing iterations
        num_warmup_iterations: Number of warmup iterations

    Returns:
        float: Minimum elapsed time in seconds
    """
    start_event = torch.npu.Event(enable_timing=True)
    end_event = torch.npu.Event(enable_timing=True)
    total_runs = num_iterations + num_warmup_iterations
    samples = np.zeros(total_runs)

    # Time every run; warmup samples are sliced off below.
    for run_idx in range(total_runs):
        with torch.no_grad():
            start_event.record()
            fn()  # Execute the function
            end_event.record()
            torch.npu.synchronize()
            samples[run_idx] = start_event.elapsed_time(end_event)

    # Drop warmup runs; elapsed_time() reports milliseconds, convert to seconds.
    timed_samples = samples[num_warmup_iterations:]
    return np.amin(timed_samples) / 1000
|
||||
|
||||
|
||||
def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation for verification.

    Maps token ids from the global vocab onto this shard's local embedding
    rows and reports which ids fall outside both vocab ranges.
    """
    # Membership in the base vocab range and the added (LoRA) vocab range.
    in_org = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (input_ <
                                                     added_vocab_end_index)
    # Offset that relocates added-vocab ids to rows right after the
    # (padded) base vocabulary rows.
    added_offset = (added_vocab_start_index -
                    (org_vocab_end_index - org_vocab_start_index) -
                    num_org_vocab_padding)
    # Per-element shift: base ids shift by the range start, added ids by
    # added_offset; out-of-range ids get no shift (and are zeroed below).
    valid_offset = (org_vocab_start_index * in_org) + (added_offset * in_added)
    in_vocab = in_org | in_added
    masked_input = in_vocab * (input_ - valid_offset)
    return masked_input, ~in_vocab
|
||||
|
||||
|
||||
# Parametrization values for the benchmark below.
DTYPES = [torch.int32]   # dtypes exercised
SHAPES = [(3, 4, 5)]     # input tensor shapes
DEVICES = [f"npu:{0}"]   # first NPU device
SEEDS = [0]              # torch.manual_seed values
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    """Check the custom NPU op against the reference and benchmark both."""
    # Deterministic input on the requested device.
    torch.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random input tensor
    input_tensor = torch.randint(0, 1000, shape, dtype=dtype)

    # Vocab-range parameters shared by both implementations.
    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }
    call_args = (input_tensor, test_case["org_start"], test_case["org_end"],
                 test_case["padding"], test_case["added_start"],
                 test_case["added_end"])

    # Reference (pure PyTorch) implementation.
    def ref_fn():
        return get_masked_input_and_mask_ref(*call_args)

    # Custom AscendC kernel, registered on torch.ops._C.
    def custom_fn():
        return torch.ops._C.get_masked_input_and_mask(*call_args)

    # Correctness results first, then timing.
    ref_masked_input, ref_mask = ref_fn()
    custom_masked_input, custom_mask = custom_fn()

    ref_time = benchmark_npu(ref_fn)
    custom_time = benchmark_npu(custom_fn)

    # Report timings.
    print("\nPerformance Results:")
    print(f"Reference implementation: {ref_time*1000:.3f} ms")
    print(f"Custom implementation: {custom_time*1000:.3f} ms")
    print(f"Speedup: {ref_time/custom_time:.2f}x")

    # Align dtypes before comparison (the reference may promote to int64).
    ref_masked_input = ref_masked_input.to(dtype)
    print("\nResults comparison:")
    print("custom_masked_input:", custom_masked_input)
    print("ref_masked_input:", ref_masked_input)
    print("custom_mask:", custom_mask)
    print("ref_mask:", ref_mask)
    torch.testing.assert_close(
        custom_masked_input,
        ref_masked_input,
        rtol=1e-5,
        atol=1e-5,
        msg=f"Masked input mismatch for case: {test_case}")
    torch.testing.assert_close(custom_mask,
                               ref_mask,
                               rtol=1e-5,
                               atol=1e-5,
                               msg=f"Mask mismatch for case: {test_case}")
|
||||
345
csrc/kernels/get_masked_input_and_mask_kernel.cpp
Normal file
345
csrc/kernels/get_masked_input_and_mask_kernel.cpp
Normal file
@@ -0,0 +1,345 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tensor_impl.h"
|
||||
#include "kernel_type.h"
|
||||
#include "types.h"
|
||||
#include "utils.h"
|
||||
using vllm_ascend::AccType;
|
||||
|
||||
// AscendC vector kernel implementing VocabParallelEmbedding input masking:
// token ids inside the base or LoRA-added vocab range are shifted onto this
// shard's local embedding rows; everything else is zeroed and flagged in the
// boolean output mask (true == "not in this shard's vocab").
template<typename scalar_t>
class GetMaskedInputAndMask {
public:
    __aicore__ inline GetMaskedInputAndMask() {}

    __aicore__ inline ~GetMaskedInputAndMask() {
        // Release pipe-owned unified-buffer allocations so Init() can be
        // called again on the same object (see the kernel's per-chunk loop).
        pipe.Reset();
    }

    // Bind global-memory pointers, cache range parameters, and allocate all
    // unified-buffer space for one chunk of `size` elements.
    __aicore__ inline void Init(
        __gm__ scalar_t* input,
        __gm__ scalar_t* masked_input,
        __gm__ bool* mask_out,
        const int64_t org_vocab_start_index,
        const int64_t org_vocab_end_index,
        const int64_t num_org_vocab_padding,
        const int64_t added_vocab_start_index,
        const int64_t added_vocab_end_index,
        const int64_t size)
    {
        // Initialize basic parameters
        input_ = input;
        masked_input_ = masked_input;
        mask_out_ = mask_out;
        org_vocab_start_index_ = org_vocab_start_index;
        org_vocab_end_index_ = org_vocab_end_index;
        // Round the element count up to a multiple of 32.
        // NOTE(review): for unaligned sizes this makes CopyIn/CopyOut touch
        // up to 31 elements past the chunk's `size` — confirm the caller
        // guarantees the extra global memory is valid.
        size_ = ((size + 31) / 32) * 32;
        // Offset that maps added-vocab ids to rows directly after the
        // (padded) base vocabulary (mirrors the Python reference).
        added_offset_ = added_vocab_start_index -
                       (org_vocab_end_index - org_vocab_start_index) -
                       num_org_vocab_padding;
        added_vocab_start_index_ = added_vocab_start_index;
        added_vocab_end_index_ = added_vocab_end_index;

        // Initialize global tensors
        inputGlobal.SetGlobalBuffer(input);
        maskedOutputGlobal.SetGlobalBuffer(masked_input);
        maskOutGlobal.SetGlobalBuffer(mask_out);

        // Initialize queues
        pipe.InitBuffer(inQueue, 1, size_ * sizeof(scalar_t));
        pipe.InitBuffer(outQueue, 1, size_ * sizeof(scalar_t));
        pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));

        // Initialize calculation buffers
        pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
        pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));

        // Initialize result queues
        pipe.InitBuffer(result_ge_que, BUFFER_NUM, size_ * sizeof(float));
        pipe.InitBuffer(result_le_que, BUFFER_NUM, size_ * sizeof(float));
        pipe.InitBuffer(result_org_mask_que, BUFFER_NUM, size_ * sizeof(float));
        pipe.InitBuffer(result_add_mask_que, BUFFER_NUM, size_ * sizeof(float));

        // Initialize temporary buffers
        pipe.InitBuffer(start_buf, size_ * sizeof(float));
        pipe.InitBuffer(end_buf, size_ * sizeof(float));
        pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float));
        pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
        pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
        pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
    }

    // Standard three-stage pipeline: load chunk, compute masks, store results.
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    // Stage 1: copy one chunk of input ids from global memory into the
    // unified buffer.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<scalar_t> inputLocal = inQueue.AllocTensor<scalar_t>();
        AscendC::DataCopy(inputLocal, inputGlobal, size_);
        inQueue.EnQue(inputLocal);
    }

    // Produce an element-wise 0/1 flag in `result` by comparing `input`
    // against `compare_value` with max/sub/abs arithmetic plus denormal
    // clamping: Mins(.., 2^-126) then scaling by 2^50 * 2^50 * 2^26 turns any
    // positive difference into exactly 1.0 and zero stays 0; Adds(-1)/Abs
    // then inverts the flag. Used in place of a native fp32->int8 compare.
    // NOTE(review): both branches reduce to |max(input - compare_value, 0)|,
    // so the final flag is 1 exactly where input <= compare_value regardless
    // of `is_greater_equal`; with ComputeRangeMask's argument order this
    // makes the upper bound inclusive (<=) while the Python reference uses
    // strict < — confirm behavior when an id equals the end index.
    __aicore__ inline void CompareWithValue(
        AscendC::LocalTensor<int8_t>& result,
        const AscendC::LocalTensor<float>& input,
        const AscendC::LocalTensor<float>& compare_value,
        bool is_greater_equal) {

        AscendC::LocalTensor<float> compute_buf = calc_buf_1.Get<float>();
        if (is_greater_equal) {
            AscendC::Max(compute_buf, input, compare_value, size_);
            AscendC::Sub(compute_buf, compare_value, compute_buf, size_);
        } else {
            AscendC::Max(compute_buf, input, compare_value, size_);
            AscendC::Sub(compute_buf, compute_buf, compare_value, size_);
        }

        AscendC::Abs(compute_buf, compute_buf, size_);
        AscendC::Mins(compute_buf, compute_buf, MIN_ACCURACY_FP32, size_);
        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_2_FP32, size_);
        AscendC::Adds(compute_buf, compute_buf, NEGATIVE_ONE_FP32, size_);
        AscendC::Abs(compute_buf, compute_buf, size_);

        // fp32 -> fp16 -> int8: there is no direct fp32->int8 Cast here.
        AscendC::LocalTensor<half> compute_buf_fp16 = calc_buf_2.Get<half>();
        AscendC::Cast(compute_buf_fp16, compute_buf, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(result, compute_buf_fp16, AscendC::RoundMode::CAST_NONE, size_);
    }

    // Build an int8 0/1 mask of elements with start_value <= input (< / <=,
    // see NOTE above) end_value.
    __aicore__ inline void ComputeRangeMask(
        AscendC::LocalTensor<int8_t>& range_mask,
        const AscendC::LocalTensor<float>& input,
        const float start_value,
        const float end_value) {

        // Use already initialized buffers
        AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
        AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();

        AscendC::Duplicate(start_value_tensor, start_value, size_);
        AscendC::Duplicate(end_value_tensor, end_value, size_);

        AscendC::LocalTensor<int8_t> ge_result = result_ge_que.AllocTensor<int8_t>();
        AscendC::LocalTensor<int8_t> lt_result = result_le_que.AllocTensor<int8_t>();

        // ge_result: input >= start; lt_result: input (<=) end.
        CompareWithValue(ge_result, start_value_tensor, input, true);
        CompareWithValue(lt_result, input, end_value_tensor, false);

        AscendC::And(range_mask, ge_result, lt_result, size_);
    }

    // Stage 2: compute the shifted ids and the invalid-id mask, mirroring:
    //   masked_input = vocab_mask * (input - valid_offset); mask = ~vocab_mask
    // NOTE(review): several *_fp16 / *_fp32 LocalTensor temporaries below are
    // declared without buffer backing (no Get()/AllocTensor()) before being
    // used as Cast/Sub outputs, and the AllocTensor'd mask tensors are never
    // freed back to their queues — confirm both are valid for this AscendC
    // version.
    __aicore__ inline void Compute() {
        AscendC::LocalTensor<scalar_t> inputLocal = inQueue.DeQue<scalar_t>();
        AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.AllocTensor<scalar_t>();
        AscendC::LocalTensor<int8_t> maskLocal = maskQueue.AllocTensor<int8_t>();

        // All range arithmetic is done in fp32.
        AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
        AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);

        // Calculate mask for org_vocab range
        // org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
        AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
        ComputeRangeMask(orgVocabMask,
                        inputFloat,
                        static_cast<float>(org_vocab_start_index_),
                        static_cast<float>(org_vocab_end_index_));

        // Calculate mask for added_vocab range
        // added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
        AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
        ComputeRangeMask(addedVocabMask,
                        inputFloat,
                        static_cast<float>(added_vocab_start_index_),
                        static_cast<float>(added_vocab_end_index_));

        // Calculate validOffset
        // valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask)
        AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
        AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();

        AscendC::Duplicate(constOrgStartIndex, float(org_vocab_start_index_), size_);

        // int8 mask -> fp16 -> fp32 so it can be multiplied with offsets.
        AscendC::LocalTensor<half> orgVocabMask_fp16;
        AscendC::LocalTensor<float> orgVocabMask_fp32;
        AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

        AscendC::Mul(validOffset,
                    constOrgStartIndex,
                    orgVocabMask_fp32,
                    size_);

        AscendC::LocalTensor<float> addedOffset;
        AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
        AscendC::Duplicate(addedOffsetTensor, float(added_offset_), size_);

        AscendC::LocalTensor<half> addedVocabMask_fp16;
        AscendC::LocalTensor<float> addedVocabMask_fp32;
        AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

        AscendC::Mul(addedOffset,
                    addedOffsetTensor,
                    addedVocabMask_fp32,
                    size_);

        AscendC::Add(validOffset, validOffset, addedOffset, size_);

        // vocab_mask = org_vocab_mask | added_vocab_mask
        AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();

        AscendC::Or(vocabMask,
                   orgVocabMask,
                   addedVocabMask,
                   size_);

        AscendC::Sub(inputFloat, inputFloat, validOffset, size_);

        // input_ = vocab_mask * (input_ - valid_offset)
        AscendC::LocalTensor<half> vocabMask_fp16;
        AscendC::LocalTensor<float> vocabMask_fp32;
        AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

        // NOTE(review): inputFloat_fp32 is declared but never used.
        AscendC::LocalTensor<float> inputFloat_fp32;
        AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);

        AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
        outQueue.EnQue(maskedLocal);

        // ~vocab_mask, computed as (1 - vocab_mask) in fp32.
        AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
        AscendC::Duplicate(ones_tensor, (float)1, size_);
        AscendC::LocalTensor<float> maskLocal_fp32;

        AscendC::Sub(maskLocal_fp32,
                    ones_tensor,
                    vocabMask_fp32,
                    size_);

        AscendC::LocalTensor<half> maskLocal_fp16;
        AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(maskLocal, maskLocal_fp16, AscendC::RoundMode::CAST_NONE, size_);
        maskQueue.EnQue(maskLocal);
        inQueue.FreeTensor(inputLocal);
    }

    // Stage 3: copy the shifted ids and the bool mask back to global memory.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.DeQue<scalar_t>();
        AscendC::LocalTensor<bool> maskLocal = maskQueue.DeQue<bool>();

        AscendC::DataCopy(maskedOutputGlobal, maskedLocal, size_);
        AscendC::DataCopy(maskOutGlobal, maskLocal, size_);

        outQueue.FreeTensor(maskedLocal);
        maskQueue.FreeTensor(maskLocal);
    }

private:
    static constexpr int32_t BUFFER_NUM = 2;  // depth of the result queues
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue, maskQueue;
    AscendC::GlobalTensor<scalar_t> inputGlobal, maskedOutputGlobal;
    AscendC::GlobalTensor<bool> maskOutGlobal;
    // Scratch space for CompareWithValue's fp32/fp16 arithmetic.
    AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_1;
    AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_2;
    // Queues holding intermediate comparison / range-mask results.
    AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_ge_que;
    AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_le_que;
    AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_org_mask_que;
    AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_add_mask_que;

    // Temporary buffers
    AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
    AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;

    // Temporary buffers continued
    AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
    AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
    AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
    AscendC::TBuf<AscendC::TPosition::VECCALC> ones_buf_;

    // Global-memory pointers and cached scalar parameters (see Init).
    __gm__ scalar_t *input_, *masked_input_;
    __gm__ bool *mask_out_;
    int64_t size_;  // chunk element count, rounded up to a multiple of 32
    int64_t org_vocab_start_index_, org_vocab_end_index_;
    int64_t added_vocab_start_index_, added_vocab_end_index_;
    int64_t added_offset_;

    // Constants for the compare-via-arithmetic trick:
    // 2^-126 (smallest normal fp32) * 2^50 * 2^50 * 2^26 == 1.0 exactly.
    static constexpr float MIN_ACCURACY_FP32 = 1.1754943508222875e-38;
    static constexpr float MAX_MUL_1_FP32 = 1125899906842624;
    static constexpr float MAX_MUL_2_FP32 = 67108864;
    static constexpr float NEGATIVE_ONE_FP32 = -1.0f;
};
|
||||
|
||||
// Kernel entry point: each AI vector core (block) processes every
// `aiv_num`-th chunk of `size/loop_cnt` elements, re-Init()ing the same op
// object per chunk.
// NOTE(review): chunking uses integer division `size/loop_cnt` — this drops
// trailing elements unless size is divisible by loop_cnt; and since Init()
// rounds each chunk up to a multiple of 32, adjacent chunks may overlap for
// unaligned chunk sizes. Confirm the host side guarantees both.
extern "C" __global__ __aicore__ void get_masked_input_and_mask_kernel(
    __gm__ int32_t* input,
    __gm__ int32_t* masked_input,
    __gm__ bool* mask_out,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index,
    const int64_t size,
    const uint32_t loop_cnt,
    const uint32_t aiv_num)
{
    {
        GetMaskedInputAndMask<int32_t> op{};

        // Stride over chunks by the number of vector cores.
        for (int64_t i = AscendC::GetBlockIdx(); i < loop_cnt; i += aiv_num) {
            op.Init(input + i * size/loop_cnt,
                   masked_input + i * size/loop_cnt,
                   mask_out + i * size/loop_cnt,
                   org_vocab_start_index, org_vocab_end_index,
                   num_org_vocab_padding, added_vocab_start_index,
                   added_vocab_end_index, size/loop_cnt);

            op.Process();
        }
    } // op destructor called here
}
|
||||
|
||||
namespace vllm_ascend {

// Host-side launcher: forwards all arguments to the AscendC kernel on the
// given stream, dispatching `aiv_num` blocks. Buffers are assumed to be
// int32 token ids (input / masked_input) and bool (mask_out).
void get_masked_input_and_mask_impl(
    void* stream,
    void* input,
    void* masked_input,
    void* mask_out,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index,
    const int64_t size,
    const uint32_t loop_cnt,
    const uint32_t aiv_num)
{
    // <<<blockDim, l2ctrl, stream>>> — one kernel block per AI vector core.
    get_masked_input_and_mask_kernel<<<aiv_num, nullptr, stream>>>(
        static_cast<int32_t*>(input),
        static_cast<int32_t*>(masked_input),
        static_cast<bool*>(mask_out),
        org_vocab_start_index,
        org_vocab_end_index,
        num_org_vocab_padding,
        added_vocab_start_index,
        added_vocab_end_index,
        size,
        loop_cnt,
        aiv_num);
}

} // namespace vllm_ascend
|
||||
|
||||
14
csrc/ops.h
14
csrc/ops.h
@@ -31,6 +31,20 @@ namespace vllm_ascend {
|
||||
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
|
||||
uint32_t aivNum);
|
||||
|
||||
// Launches the get_masked_input_and_mask AscendC kernel on `stream`
// (implemented in csrc/kernels/get_masked_input_and_mask_kernel.cpp).
extern void get_masked_input_and_mask_impl(
    void* stream,
    void* input,
    void* masked_input,
    void* mask_out,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index,
    const int64_t size,
    const uint32_t loop_cnt,
    const uint32_t aiv_num);
|
||||
|
||||
torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
|
||||
if (!tensor.is_privateuseone()) {
|
||||
throw std::runtime_error("Tensor must be on NPU device");
|
||||
|
||||
@@ -99,6 +99,112 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
|
||||
return {query_dst, key_dst};
|
||||
}
|
||||
|
||||
// Torch-facing wrapper: validates arguments, allocates outputs, and
// schedules the AscendC kernel through an OpCommand custom handler.
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    at::Tensor &input,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index)
    /*
    https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L161-L198
    Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.
    In this example, we will have the original vocab size = 1010,
    added vocab size = 16 and padding to 64. Therefore, the total
    vocab size with padding will be 1088 (because we first pad 1010 to
    1024, add 16, and then pad to 1088).
    Therefore, the tensor format looks like the following:
    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1015 |  -1  | ... |  -1  |
                     index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 1000 | ... | 1015 |  -1  | ... |  -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 527  |  520 | ... | 543 |
    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ...  | -1  |  -1  | ... |  -1  | -1  | ... |  -1  |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 519  | 520 | ... | 543  |
    Parameters:
        org_vocab_start_index   //base embeddings start
        org_vocab_end_index     //base embeddings end
        num_org_vocab_padding   //base embeddings padding
        added_vocab_start_index //LoRA embeddings start
        added_vocab_end_index   //LoRA embeddings end
    */
{
    // Input validation. The range checks are inclusive (>=): empty ranges
    // (start == end) are legal, so the messages must say "greater than or
    // equal to", not "greater than".
    TORCH_CHECK(input.dim() >= 1, "input must have at least 1 dimension");
    TORCH_CHECK(org_vocab_start_index >= 0, "org_vocab_start_index must be non-negative");
    TORCH_CHECK(org_vocab_end_index >= org_vocab_start_index, "org_vocab_end_index must be greater than or equal to org_vocab_start_index");
    TORCH_CHECK(num_org_vocab_padding >= 0, "num_org_vocab_padding must be non-negative");
    TORCH_CHECK(added_vocab_start_index >= org_vocab_end_index, "added_vocab_start_index must be greater than or equal to org_vocab_end_index");
    TORCH_CHECK(added_vocab_end_index >= added_vocab_start_index, "added_vocab_end_index must be greater than or equal to added_vocab_start_index");

    // Get total number of elements
    int64_t size = input.numel();

    // Create output tensors. Allocate the bool mask directly rather than
    // allocating an input-dtype tensor and converting it with .to(kBool),
    // which would materialize a throwaway intermediate tensor.
    at::Tensor masked_input = at::empty_like(input);
    at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));

    // Get data pointers
    void *input_ptr = input.data_ptr();
    void *masked_input_ptr = masked_input.data_ptr();
    void *mask_ptr = mask.data_ptr();

    // Get current stream
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

    // Get scalar type
    at::ScalarType scalar_type = input.scalar_type();

    // Create and configure OpCommand; the handler runs when the command is
    // dispatched, so everything it needs is captured by value.
    at_npu::native::OpCommand cmd;
    cmd.Name("get_masked_input_and_mask");
    cmd.SetCustomHandler([scalar_type, size, stream,
                         input_ptr, masked_input_ptr, mask_ptr,
                         org_vocab_start_index, org_vocab_end_index,
                         num_org_vocab_padding, added_vocab_start_index,
                         added_vocab_end_index]() -> int {
        // Query how many AI vector cores this device has, then split the
        // work so each core gets roughly one chunk.
        fe::PlatFormInfos platform_infos;
        int device_id = 0;
        fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
        uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
        uint32_t loop_cnt = (size + aivNum - 1) / aivNum;

        // Call implementation
        get_masked_input_and_mask_impl(
            stream,
            input_ptr,
            masked_input_ptr,
            mask_ptr,
            org_vocab_start_index,
            org_vocab_end_index,
            num_org_vocab_padding,
            added_vocab_start_index,
            added_vocab_end_index,
            size,
            loop_cnt,
            aivNum);

        return 0;
    });
    cmd.Run();
    return {masked_input, mask};
}
|
||||
|
||||
void verify_tensor(std::string const& name, at::Tensor const& t,
|
||||
int64_t const size_0, int64_t const size_1,
|
||||
c10::ScalarType const type) {
|
||||
@@ -194,6 +300,16 @@ TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
" Tensor! key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
|
||||
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
|
||||
|
||||
ops.def(
|
||||
"get_masked_input_and_mask(Tensor input, "
|
||||
" int org_vocab_start_index, "
|
||||
" int org_vocab_end_index, "
|
||||
" int num_org_vocab_padding, "
|
||||
" int added_vocab_start_index, "
|
||||
" int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
|
||||
ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
|
||||
|
||||
ops.def(
|
||||
"advance_step_flashattn_ascendc(int num_seqs, int num_queries, int block_size,"
|
||||
" Tensor! input_tokens, Tensor! sampled_token_ids, Tensor! input_positions,"
|
||||
|
||||
91
tests/ops/test_vocabparallelembedding.py
Normal file
91
tests/ops/test_vocabparallelembedding.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
import vllm_ascend.platform # noqa: F401
|
||||
|
||||
# Test parameters
DTYPES = [torch.int32]   # dtypes exercised by the test
#SHAPES = [(100,), (5, 20), (3, 4, 5)]  # Various tensor shapes
#SHAPES = [(3, 4, 8), (3, 4, 5)]  # Various tensor shapes
SHAPES = [(3, 4, 3)]     # currently a single small shape
DEVICES = [f"npu:{0}"]   # first NPU device
SEEDS = [0]              # torch.manual_seed values
|
||||
|
||||
|
||||
def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation for verification.

    Returns the input ids remapped onto local embedding rows plus a boolean
    mask that is True where an id belongs to neither vocab range.
    """
    # Which ids fall in the base vocab and which in the added (LoRA) vocab.
    base_hit = (input_ >= org_vocab_start_index) & (input_ <
                                                    org_vocab_end_index)
    added_hit = (input_ >= added_vocab_start_index) & (input_ <
                                                      added_vocab_end_index)
    # Added-vocab rows live immediately after the padded base rows.
    offset_for_added = (added_vocab_start_index
                        - (org_vocab_end_index - org_vocab_start_index)
                        - num_org_vocab_padding)
    # Shift applied per element; zero for out-of-range ids.
    shift = (org_vocab_start_index * base_hit) + (offset_for_added * added_hit)
    any_hit = base_hit | added_hit
    remapped = any_hit * (input_ - shift)
    return remapped, ~any_hit
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_get_masked_input_and_mask(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    seed: int,
) -> None:
    """Compare the custom NPU op against the pure-PyTorch reference."""
    # Deterministic input on the requested device.
    torch.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random input tensor
    input_tensor = torch.randint(0, 1000, shape, dtype=dtype)

    # Vocab-range parameters for this case.
    test_case = {
        "org_start": 100,
        "org_end": 200,
        "padding": 0,
        "added_start": 300,
        "added_end": 400,
    }
    call_args = (input_tensor, test_case["org_start"], test_case["org_end"],
                 test_case["padding"], test_case["added_start"],
                 test_case["added_end"])

    # Reference result first, then the custom kernel's result.
    ref_masked_input, ref_mask = get_masked_input_and_mask_ref(*call_args)

    print("input_tensor:", input_tensor)
    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
        *call_args)

    # Align dtypes before comparison (the reference may promote to int64).
    ref_masked_input = ref_masked_input.to(dtype)
    print("custom_masked_input:", custom_masked_input)
    print("ref_masked_input:", ref_masked_input)
    print("custom_mask:", custom_mask)
    print("ref_mask:", ref_mask)
    # Compare results
    torch.testing.assert_close(
        custom_masked_input,
        ref_masked_input,
        rtol=1e-5,
        atol=1e-5,
        msg=f"Masked input mismatch for case: {test_case}")
    torch.testing.assert_close(custom_mask,
                               ref_mask,
                               rtol=1e-5,
                               atol=1e-5,
                               msg=f"Mask mismatch for case: {test_case}")
|
||||
Reference in New Issue
Block a user