/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 */

#include "kernel_operator.h"
#include "kernel_tensor_impl.h"
#include "kernel_type.h"
#include "types.h"
#include "utils.h"

using vllm_ascend::AccType;

// Per element of `input` this kernel computes:
//   in_org   = org_vocab_start_index   <= input[i] < org_vocab_end_index
//   in_added = added_vocab_start_index <= input[i] < added_vocab_end_index
//   offset   = in_org * org_vocab_start_index + in_added * added_offset
//   masked_input[i] = (input[i] - offset) * (in_org | in_added)
//   mask_out[i]     = !(in_org | in_added)   // true where token is OUT of both ranges
//
// Comparisons are emulated with float arithmetic (see CompareWithValue) because
// the vector unit on some archs has no direct integer compare-to-mask op.
template <typename scalar_t>
class GetMaskedInputAndMask {
public:
    __aicore__ inline GetMaskedInputAndMask() {}

    __aicore__ inline ~GetMaskedInputAndMask() {
        pipe.Reset();
    }

    // Binds the global buffers and carves all unified-buffer queues/scratch.
    // `size` is rounded UP to a 32-element multiple (DataCopy granularity), so
    // the padded tail past `size` is both read and written — the caller must
    // guarantee that memory is accessible.
    __aicore__ inline void Init(
        __gm__ scalar_t* input,
        __gm__ scalar_t* masked_input,
        __gm__ bool* mask_out,
        const int64_t org_vocab_start_index,
        const int64_t org_vocab_end_index,
        const int64_t num_org_vocab_padding,
        const int64_t added_vocab_start_index,
        const int64_t added_vocab_end_index,
        const int64_t size)
    {
        // Initialize basic parameters
        input_ = input;
        masked_input_ = masked_input;
        mask_out_ = mask_out;
        org_vocab_start_index_ = org_vocab_start_index;
        org_vocab_end_index_ = org_vocab_end_index;
        size_ = ((size + 31) / 32) * 32;
        // Offset applied to tokens that fall in the "added vocab" range: they are
        // remapped to sit directly after the (padded) original vocab partition.
        added_offset_ = added_vocab_start_index -
                        (org_vocab_end_index - org_vocab_start_index) -
                        num_org_vocab_padding;
        added_vocab_start_index_ = added_vocab_start_index;
        added_vocab_end_index_ = added_vocab_end_index;

        // Initialize global tensors
        inputGlobal.SetGlobalBuffer(input);
        maskedOutputGlobal.SetGlobalBuffer(masked_input);
        maskOutGlobal.SetGlobalBuffer(mask_out);

        // Initialize queues
        pipe.InitBuffer(inQueue, 1, size_ * sizeof(scalar_t));
        pipe.InitBuffer(outQueue, 1, size_ * sizeof(scalar_t));
        pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));

        // Initialize calculation buffers
        // NOTE: calc_buf_1 and calc_buf_2 are also used for int16 casting on older archs.
        pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
        pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));

        // Initialize result queues
        pipe.InitBuffer(result_ge_que, BUFFER_NUM, size_ * sizeof(float));
        pipe.InitBuffer(result_le_que, BUFFER_NUM, size_ * sizeof(float));
        pipe.InitBuffer(result_org_mask_que, BUFFER_NUM, size_ * sizeof(float));
        pipe.InitBuffer(result_add_mask_que, BUFFER_NUM, size_ * sizeof(float));

        // Initialize temporary buffers
        pipe.InitBuffer(start_buf, size_ * sizeof(float));   // also reused as half scratch on older archs
        pipe.InitBuffer(end_buf, size_ * sizeof(float));
        pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float));
        pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
        pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
        pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
    }

    // Standard three-stage AscendC pipeline.
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<scalar_t> inputLocal = inQueue.AllocTensor<scalar_t>();
        AscendC::DataCopy(inputLocal, inputGlobal, size_);
        inQueue.EnQue(inputLocal);
    }

    // Emits an int8 0/1 indicator into `result`:
    //   is_greater_equal == true :  result = (input >= compare_value)   elementwise
    //   is_greater_equal == false:  result = (input <= compare_value)   elementwise
    // Trick: the Max/Sub pair yields 0 exactly when the condition holds (and a
    // value of magnitude >= 1 otherwise, since inputs are integral-valued).
    // Mins clamps |diff| to FLT_MIN (2^-126), then the three Muls scale by
    // 2^50 * 2^50 * 2^26 = 2^126, producing exactly 0.0 or 1.0; Adds(-1) + Abs
    // flips that into 1.0 when the condition holds, 0.0 otherwise.
    __aicore__ inline void CompareWithValue(
        AscendC::LocalTensor<int8_t>& result,
        const AscendC::LocalTensor<float>& input,
        const AscendC::LocalTensor<float>& compare_value,
        bool is_greater_equal)
    {
        AscendC::LocalTensor<float> compute_buf = calc_buf_1.Get<float>();
        if (is_greater_equal) {
            // diff = compare_value - max(input, compare_value)  (== 0 iff input >= compare_value)
            AscendC::Max(compute_buf, input, compare_value, size_);
            AscendC::Sub(compute_buf, compare_value, compute_buf, size_);
        } else {
            // diff = max(input, compare_value) - compare_value  (== 0 iff input <= compare_value)
            AscendC::Max(compute_buf, input, compare_value, size_);
            AscendC::Sub(compute_buf, compute_buf, compare_value, size_);
        }

        AscendC::Abs(compute_buf, compute_buf, size_);
        AscendC::Mins(compute_buf, compute_buf, MIN_ACCURACY_FP32, size_);
        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
        AscendC::Muls(compute_buf, compute_buf, MAX_MUL_2_FP32, size_);
        AscendC::Adds(compute_buf, compute_buf, NEGATIVE_ONE_FP32, size_);
        AscendC::Abs(compute_buf, compute_buf, size_);

        // float -> half -> int8: no direct float->int8 cast on the vector unit.
        AscendC::LocalTensor<half> compute_buf_fp16 = calc_buf_2.Get<half>();
        AscendC::Cast(compute_buf_fp16, compute_buf, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(result, compute_buf_fp16, AscendC::RoundMode::CAST_NONE, size_);
    }

    // range_mask = (start_value <= input < end_value), as int8 0/1.
    // NOTE: [start, end) is realized as (start <= input) & (input <= end) via
    // CompareWithValue; with integral token ids and an exclusive `end_value`
    // bound this matches the half-open interval.
    __aicore__ inline void ComputeRangeMask(
        AscendC::LocalTensor<int8_t>& range_mask,
        const AscendC::LocalTensor<float>& input,
        const float start_value,
        const float end_value)
    {
        AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
        AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();

        AscendC::Duplicate(start_value_tensor, start_value, size_);
        AscendC::Duplicate(end_value_tensor, end_value, size_);

        AscendC::LocalTensor<int8_t> ge_result = result_ge_que.AllocTensor<int8_t>();
        AscendC::LocalTensor<int8_t> lt_result = result_le_que.AllocTensor<int8_t>();

        CompareWithValue(ge_result, start_value_tensor, input, true);
        CompareWithValue(lt_result, input, end_value_tensor, false);

#if (__CCE_AICORE__ >= 220)
        AscendC::And(range_mask, ge_result, lt_result, size_);
#else
        {
            // WORKAROUND for older arch:
            // no direct int8 And and no direct int8<->int16 cast, so go
            // int8 -> half -> int16, And in int16, and cast back.
            AscendC::LocalTensor<int16_t> ge_result_i16 = calc_buf_1.Get<int16_t>();
            AscendC::LocalTensor<int16_t> lt_result_i16 = calc_buf_2.Get<int16_t>();
            AscendC::LocalTensor<int16_t> range_mask_i16 = ge_result_i16;

            // BUGFIX: the half scratch previously lived in inputFloat_buf, but
            // `input` aliases inputFloat_buf and is still needed by the caller
            // (Compute() reuses inputFloat after this returns). start_buf's
            // contents (start_value_tensor) are already consumed above, so it
            // is the safe scratch buffer here.
            AscendC::LocalTensor<half> tmp_half = start_buf.Get<half>();

            // 1. Cast inputs: int8_t -> half -> int16_t
            AscendC::Cast(tmp_half, ge_result, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(ge_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(tmp_half, lt_result, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(lt_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);

            // 2. Perform And on int16_t tensors
            AscendC::And(range_mask_i16, ge_result_i16, lt_result_i16, size_);

            // 3. Cast result back: int16_t -> half -> int8_t
            AscendC::Cast(tmp_half, range_mask_i16, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(range_mask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
        }
#endif
    }

    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<scalar_t> inputLocal = inQueue.DeQue<scalar_t>();
        AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.AllocTensor<scalar_t>();
        AscendC::LocalTensor<bool> maskLocal = maskQueue.AllocTensor<bool>();

        // All arithmetic is done in fp32; token ids are assumed to be exactly
        // representable in fp32 (|id| < 2^24) — TODO confirm for large vocabs.
        AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
        AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);

        // Indicator masks for the two vocab partitions.
        AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
        ComputeRangeMask(orgVocabMask,
                         inputFloat,
                         static_cast<float>(org_vocab_start_index_),
                         static_cast<float>(org_vocab_end_index_));

        AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
        ComputeRangeMask(addedVocabMask,
                         inputFloat,
                         static_cast<float>(added_vocab_start_index_),
                         static_cast<float>(added_vocab_end_index_));

        // validOffset = org_start * in_org + added_offset * in_added
        AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
        AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();

        AscendC::Duplicate(constOrgStartIndex, float(org_vocab_start_index_), size_);

        // NOTE(review): these fp16/fp32 temporaries are default-constructed
        // LocalTensors with no explicit TBuf backing — confirm the toolchain
        // binds them to valid unified-buffer storage.
        AscendC::LocalTensor<half> orgVocabMask_fp16;
        AscendC::LocalTensor<float> orgVocabMask_fp32;
        AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

        AscendC::Mul(validOffset, constOrgStartIndex, orgVocabMask_fp32, size_);

        AscendC::LocalTensor<float> addedOffset;
        AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
        AscendC::Duplicate(addedOffsetTensor, float(added_offset_), size_);

        AscendC::LocalTensor<half> addedVocabMask_fp16;
        AscendC::LocalTensor<float> addedVocabMask_fp32;
        AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

        AscendC::Mul(addedOffset, addedOffsetTensor, addedVocabMask_fp32, size_);
        AscendC::Add(validOffset, validOffset, addedOffset, size_);

        // vocabMask = in_org | in_added
        AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();

#if (__CCE_AICORE__ >= 220)
        AscendC::Or(vocabMask, orgVocabMask, addedVocabMask, size_);
#else
        {
            // WORKAROUND for older arch: no direct int8 Or / int8<->int16 cast;
            // route through half and int16 as in ComputeRangeMask.
            AscendC::LocalTensor<int16_t> orgVocabMask_i16 = calc_buf_1.Get<int16_t>();
            AscendC::LocalTensor<int16_t> addedVocabMask_i16 = calc_buf_2.Get<int16_t>();
            AscendC::LocalTensor<int16_t> vocabMask_i16 = orgVocabMask_i16;

            // BUGFIX: the half scratch previously lived in inputFloat_buf
            // (commented "free now"), but inputFloat is read by the Sub()
            // immediately after this block — using it here corrupted the
            // masked output on older archs. constOrgStartIndex in start_buf
            // has already been consumed, so start_buf is the free scratch.
            AscendC::LocalTensor<half> tmp_half = start_buf.Get<half>();

            // 1. Cast inputs: int8_t -> half -> int16_t
            AscendC::Cast(tmp_half, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(orgVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(tmp_half, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(addedVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);

            // 2. Perform Or on int16_t tensors
            AscendC::Or(vocabMask_i16, orgVocabMask_i16, addedVocabMask_i16, size_);

            // 3. Cast result back: int16_t -> half -> int8_t
            AscendC::Cast(tmp_half, vocabMask_i16, AscendC::RoundMode::CAST_NONE, size_);
            AscendC::Cast(vocabMask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
        }
#endif

        // masked_input = (input - validOffset) * vocabMask
        AscendC::Sub(inputFloat, inputFloat, validOffset, size_);

        AscendC::LocalTensor<half> vocabMask_fp16;
        AscendC::LocalTensor<float> vocabMask_fp32;
        AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
        AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

        AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);

        AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
        outQueue.EnQue(maskedLocal);

        // mask_out = 1 - vocabMask  (true where the token belongs to NEITHER range)
        AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
        AscendC::Duplicate(ones_tensor, (float)1, size_);
        AscendC::LocalTensor<float> maskLocal_fp32;

        AscendC::Sub(maskLocal_fp32, ones_tensor, vocabMask_fp32, size_);

        AscendC::LocalTensor<half> maskLocal_fp16;
        AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
        // NOTE(review): half -> bool cast; if the toolchain only supports
        // half -> int8_t, a ReinterpretCast<int8_t>() view of maskLocal may be
        // required here — verify against the CANN version in use.
        AscendC::Cast(maskLocal, maskLocal_fp16, AscendC::RoundMode::CAST_NONE, size_);
        maskQueue.EnQue(maskLocal);

        inQueue.FreeTensor(inputLocal);
    }

    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.DeQue<scalar_t>();
        AscendC::LocalTensor<bool> maskLocal = maskQueue.DeQue<bool>();

        AscendC::DataCopy(maskedOutputGlobal, maskedLocal, size_);
        AscendC::DataCopy(maskOutGlobal, maskLocal, size_);

        outQueue.FreeTensor(maskedLocal);
        maskQueue.FreeTensor(maskLocal);
    }

private:
    static constexpr int32_t BUFFER_NUM = 2;

    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue, maskQueue;
    AscendC::GlobalTensor<scalar_t> inputGlobal, maskedOutputGlobal;
    AscendC::GlobalTensor<bool> maskOutGlobal;
    AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_1;
    AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_2;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_ge_que;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_le_que;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_org_mask_que;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> result_add_mask_que;

    // Temporary buffers
    AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
    AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;
    AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
    AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
    AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
    AscendC::TBuf<AscendC::TPosition::VECCALC> ones_buf_;

    __gm__ scalar_t *input_, *masked_input_;
    __gm__ bool *mask_out_;
    int64_t size_;
    int64_t org_vocab_start_index_, org_vocab_end_index_;
    int64_t added_vocab_start_index_, added_vocab_end_index_;
    int64_t added_offset_;

    // Constants for the float-arithmetic comparison trick:
    // FLT_MIN (2^-126), and multipliers 2^50 * 2^50 * 2^26 = 2^126, so that
    // min(|diff|, FLT_MIN) * 2^126 is exactly 0.0 or 1.0.
    static constexpr float MIN_ACCURACY_FP32 = 1.1754943508222875e-38;
    static constexpr float MAX_MUL_1_FP32 = 1125899906842624;
    static constexpr float MAX_MUL_2_FP32 = 67108864;
    static constexpr float NEGATIVE_ONE_FP32 = -1.0f;
};

// Device entry point: each AI vector core strides over `loop_cnt` equal chunks
// of `size/loop_cnt` elements. `size` must be divisible by `loop_cnt`
// (integer division is used for both the pointer offset and the chunk size).
extern "C" __global__ __aicore__ void get_masked_input_and_mask_kernel(
    __gm__ int32_t* input,
    __gm__ int32_t* masked_input,
    __gm__ bool* mask_out,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index,
    const int64_t size,
    const uint32_t loop_cnt,
    const uint32_t aiv_num)
{
    {
        GetMaskedInputAndMask<int32_t> op{};

        for (int64_t i = AscendC::GetBlockIdx(); i < loop_cnt; i += aiv_num) {
            op.Init(input + i * size / loop_cnt,
                    masked_input + i * size / loop_cnt,
                    mask_out + i * size / loop_cnt,
                    org_vocab_start_index, org_vocab_end_index,
                    num_org_vocab_padding, added_vocab_start_index,
                    added_vocab_end_index, size / loop_cnt);

            op.Process();
        }
    } // op destructor called here
}

namespace vllm_ascend {

// Host-side launcher: enqueues the kernel on `stream` with one block per AI
// vector core. Pointers are untyped at the ABI boundary and cast back to the
// kernel's int32 token-id / bool mask types.
void get_masked_input_and_mask_impl(
    void* stream,
    void* input,
    void* masked_input,
    void* mask_out,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index,
    const int64_t size,
    const uint32_t loop_cnt,
    const uint32_t aiv_num)
{
    get_masked_input_and_mask_kernel<<<aiv_num, nullptr, stream>>>(
        static_cast<int32_t*>(input),
        static_cast<int32_t*>(masked_input),
        static_cast<bool*>(mask_out),
        org_vocab_start_index,
        org_vocab_end_index,
        num_org_vocab_padding,
        added_vocab_start_index,
        added_vocab_end_index,
        size,
        loop_cnt,
        aiv_num);
}

} // namespace vllm_ascend