[Ops][Refactor] Remove custom rotary_embedding operator (#6523)
### What this PR does / why we need it?
This PR removes the custom `rotary_embedding` operator and its associated C++ kernel implementation, PyTorch bindings, and tests. The codebase now falls back to the native `torch_npu._npu_rotary_embedding` implementation. This simplifies the codebase by removing custom, platform-specific kernel code and relying on the standard NPU library implementation, which is presumably more optimized and easier to maintain.

### Does this PR introduce _any_ user-facing change?
No. This is an internal refactoring and does not introduce any user-facing changes.

### How was this patch tested?
The tests for the custom `rotary_embedding` operator have been removed along with the operator itself. The correctness of the fallback to the native `torch_npu` implementation is verified by existing CI tests for attention layers and models that use rotary embeddings.

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
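For context, the arithmetic the deleted NeoX-path kernel performed (see `neox_compute` in the removed file below) can be sketched in plain PyTorch. This is an illustrative reference only, not the fallback's actual code path: `rope_neox_ref` is a hypothetical name, and the sketch assumes the layout the removed binding used (1-D `positions`, and `cos_sin_cache` storing the cos half followed by the sin half for each position).

```python
import torch

def rope_neox_ref(positions: torch.Tensor,      # [num_tokens], int64
                  query: torch.Tensor,          # [num_tokens, num_heads * head_size]
                  key: torch.Tensor,            # [num_tokens, num_kv_heads * head_size]
                  head_size: int,
                  cos_sin_cache: torch.Tensor): # [max_position, rot_dim]
    rot_dim = cos_sin_cache.size(1)
    embed_dim = rot_dim // 2
    # Gather the per-token cos/sin rows, as the kernel did via pos * rot_dim.
    cos, sin = cos_sin_cache[positions].split(embed_dim, dim=-1)
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads

    def apply(t: torch.Tensor) -> torch.Tensor:
        # Fresh output tensor, like the kernel's query_dst/key_dst; elements
        # beyond rot_dim (if head_size > rot_dim) pass through unchanged.
        t = t.view(t.size(0), -1, head_size).clone()
        x = t[..., :embed_dim].clone()
        y = t[..., embed_dim:rot_dim].clone()
        # Same element-wise formulas as neox_compute:
        t[..., :embed_dim] = x * cos - y * sin        # x' = x*cos - y*sin
        t[..., embed_dim:rot_dim] = x * sin + y * cos # y' = x*sin + y*cos
        return t

    return apply(query), apply(key)
```

In production the fallback simply calls `torch_npu._npu_rotary_embedding`; this sketch only documents the math the custom kernel implemented, so its behavior stays verifiable against the native op.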
@@ -1,372 +0,0 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel_operator.h"
#include <stdio.h>
#include "types.h"
#include "utils.h"

using vllm_ascend::AccType;
using vllm_ascend::local_mem_copy;
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
    // NOTE(ganyi): we use 512B as the load stride for the pipe; we still need another way to
    // retrieve this size from the runtime to support more SoCs
#if (__CCE_AICORE__ >= 220)
    static int constexpr loadSize = 512;
#else
    static int constexpr loadSize = 1024 * 4;
#endif
    using dst_t = scalar_t;
    using acc_t = typename AccType<scalar_t>::type;
    // only half tensors have a cast instruction to int8, so hardcode acc_dst_t as half
    using local_scalar_t = AscendC::LocalTensor<scalar_t>;
    using local_acc_t = AscendC::LocalTensor<acc_t>;
    using local_dst_t = AscendC::LocalTensor<dst_t>;

public:
    __aicore__ inline RotaryEmbedding()
    {
    }

    // Allocate buffers for the input and output queues and the temp buffer used during kernel compute;
    // this init process happens only once per vector core.
    __aicore__ inline void init(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
                                __gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
                                const int rotDim, const int64_t dstQueryStride,
                                const int64_t dstKeyStride, const int64_t queryStride, const int64_t keyStride,
                                const int numHeads, const int numKvHeads, const int headSize, AscendC::TPipe *pipe)
    {
        pipe_ = pipe;
        rotDim_ = rotDim;
        // query stride and key stride are used to handle strided tensors that are not contiguous
        // on the num_tokens dim
        queryStride_ = queryStride;
        keyStride_ = keyStride;
        dstQueryStride_ = dstQueryStride;
        dstKeyStride_ = dstKeyStride;
        numHeads_ = numHeads;
        numKvHeads_ = numKvHeads;
        headSize_ = headSize;
        embedDim_ = rotDim / 2;

        pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
        pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */);
        pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
        // 2 temporary calculation buffers (2 * embed_dim)
        calcTmpBufferOffset_ = 0;
        // 1 upcast buffer for bf16 (headSize)
        upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2;
        // 1 upcast temp buffer for bf16 (2 * embed_dim)
        upcastTempBufferOffset_ = upcastInputBufferOffset_ + sizeof(acc_t) * headSize_;
        // 2 sin/cos upcast buffers for bf16
        cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_;
        // bf16 path: needs 2 cos/sin upcast buffers
        // fp16 path: needs 2 temporary calculation buffers
        tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t);
        // need to consider upcasting bf16 to fp32, so we might need 4 buffers just in case:
        // 2 temporary buffers, 2 input buffers, 1 cos buffer, 1 sin buffer, 2 scale buffers (headSize),
        // 2 zp buffers (headSize, int8), 1 dst_temp buffer (headSize, int32)
        pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */);
        if constexpr (!std::is_same_v<scalar_t, acc_t>) {
            pipe_->InitBuffer(copyBuf_, loadSize);
        }
    }
    __aicore__ inline void update_mem_offset(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
                                             __gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
                                             const int rotDim, const int64_t dstQueryStride, const int64_t dstKeyStride,
                                             const int64_t queryStride, const int64_t keyStride, const int numHeads,
                                             const int numKvHeads, const int headSize, const int64_t idx)
    {
        int64_t pos = positions[idx];
        cosSin_.SetGlobalBuffer(cosSinCache + pos * rotDim_, rotDim_);
        query_.SetGlobalBuffer(query + queryStride * idx, headSize * numHeads_);
        key_.SetGlobalBuffer(key + keyStride * idx, headSize * numKvHeads_);
        queryDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(queryDst) + dstQueryStride * idx,
                                  headSize * numHeads_);
        keyDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(keyDst) + dstKeyStride * idx, headSize * numKvHeads_);
    }

    // compute per head for neox on bf16
    template <typename acc_t_, typename std::enable_if<!std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
    __aicore__ inline void
    neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
                 AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
    {
        // slice dst
        local_dst_t dstX = dst;
        local_dst_t dstY = dst[embedDim_];

        // slice src
        local_scalar_t srcX = src;
        local_scalar_t srcY = src[embedDim_];

        // slice temp buffer
        local_acc_t calcTmpBufferX = calcTmpBuffer;
        local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];

        // slice upcast input buffer
        local_acc_t upcastBufferX = upcastInputBuffer;
        local_acc_t upcastBufferY = upcastBufferX[embedDim_];

        // dst x calc: x' = x * cos - y * sin
        Cast(upcastInputBuffer, src, AscendC::RoundMode::CAST_NONE, headSize_);
        Mul(calcTmpBufferX, upcastBufferX, cos, embedDim_);
        Mul(calcTmpBufferY, upcastBufferY, sin, embedDim_);
        Sub(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
        Cast(dstX, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);

        // dst y calc: y' = x * sin + y * cos
        Mul(calcTmpBufferX, upcastBufferX, sin, embedDim_);
        Mul(calcTmpBufferY, upcastBufferY, cos, embedDim_);
        Add(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
        Cast(dstY, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
    }
    // compute per head output for neox (same-type path, no upcast needed)
    template <typename acc_t_, typename std::enable_if<std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
    __aicore__ inline void
    neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
                 AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
    {
        // slice dst buffer
        local_dst_t dstX = dst;
        local_dst_t dstY = dst[embedDim_];
        // slice src buffer
        local_scalar_t srcX = src;
        local_scalar_t srcY = src[embedDim_];
        // slice temp buffer
        local_acc_t calcTmpBufferX = calcTmpBuffer;
        local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];

        // dst x calc: x' = x * cos - y * sin
        Mul(calcTmpBufferX, srcX, cos, embedDim_);
        Mul(calcTmpBufferY, srcY, sin, embedDim_);
        Sub(dstX, calcTmpBufferX, calcTmpBufferY, embedDim_);

        // dst y calc: y' = x * sin + y * cos
        Mul(calcTmpBufferX, srcX, sin, embedDim_);
        Mul(calcTmpBufferY, srcY, cos, embedDim_);
        Add(dstY, calcTmpBufferX, calcTmpBufferY, embedDim_);
    }
    __aicore__ inline void compute_qk(AscendC::GlobalTensor<scalar_t> srcG, AscendC::GlobalTensor<dst_t> dstG,
                                      local_acc_t localCos, local_acc_t localSin, local_acc_t upcastInputBuffer,
                                      local_acc_t calcTmpBuffer, int loopCnt, int tailHeads, int loadStride,
                                      int headNumPerLoad)
    {
        for (int loopNum = 0; loopNum < loopCnt; ++loopNum) {
            local_scalar_t src = inQue_.AllocTensor<scalar_t>();
            local_dst_t dst = outQue_.AllocTensor<dst_t>();
            AscendC::DataCopy(src, srcG[loopNum * loadStride], loadStride);
            inQue_.EnQue(src);

            local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
            if constexpr (!std::is_same_v<scalar_t, acc_t>) {
                int elem_num = loadStride / sizeof(scalar_t);
                AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
                Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
                Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
            } else {
                local_mem_copy(dst, srcDeque, loadStride);
            }
            for (int i = 0; i < headNumPerLoad; ++i) {
                neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
                             calcTmpBuffer);
            }
            outQue_.EnQue(dst);
            local_dst_t dstDeque = outQue_.DeQue<dst_t>();
            AscendC::DataCopy(dstG[loopNum * loadStride], dstDeque, loadStride);
            outQue_.FreeTensor(dstDeque);
            inQue_.FreeTensor(srcDeque);
        }
        // process tail
        {
            local_scalar_t src = inQue_.AllocTensor<scalar_t>();
            local_dst_t dst = outQue_.AllocTensor<dst_t>();

            AscendC::DataCopy(src, srcG[loopCnt * loadStride], tailHeads * headSize_);
            inQue_.EnQue(src);
            local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();

            if constexpr (!std::is_same_v<scalar_t, acc_t>) {
                int elem_num = tailHeads * headSize_ / sizeof(scalar_t);
                AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
                Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
                Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
            } else {
                local_mem_copy(dst, srcDeque, tailHeads * headSize_);
            }

            for (int i = 0; i < tailHeads; ++i) {
                neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
                             calcTmpBuffer);
            }
            outQue_.EnQue(dst);
            local_dst_t dstDeque = outQue_.DeQue<dst_t>();
            AscendC::DataCopy(dstG[loopCnt * loadStride], dstDeque, tailHeads * headSize_);
            outQue_.FreeTensor(dstDeque);
            inQue_.FreeTensor(srcDeque);
        }
    }
    __aicore__ inline void compute_function()
    {
        local_scalar_t cosSinLocal = inQueSinCos_.AllocTensor<scalar_t>();

        AscendC::DataCopy(cosSinLocal, cosSin_, embedDim_ * 2);

        inQueSinCos_.EnQue(cosSinLocal);
        local_scalar_t localSinCosDeque = inQueSinCos_.DeQue<scalar_t>();
        local_scalar_t localCos = localSinCosDeque;
        local_scalar_t localSin = localSinCosDeque[embedDim_];

        local_acc_t calcTmpBuffer;
        local_acc_t upcastInputBuffer;
        local_acc_t upcastTempBuffer;
        local_acc_t cosSinUpcastBuffer;
        local_acc_t scaleBuffer;
        local_acc_t offsetBuffer;
        calcTmpBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, calcTmpBufferOffset_);
        upcastInputBuffer = calcBuf_.GetWithOffset<acc_t>(headSize_, upcastInputBufferOffset_);
        upcastTempBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, upcastTempBufferOffset_);
        cosSinUpcastBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, cosSinUpcastBufferOffset_);

        local_acc_t cosAccBuffer;
        local_acc_t sinAccBuffer;

        if constexpr (!std::is_same_v<scalar_t, acc_t>) {
            Cast(cosSinUpcastBuffer, localSinCosDeque, AscendC::RoundMode::CAST_NONE, 2 * embedDim_);
            cosAccBuffer = cosSinUpcastBuffer;
            sinAccBuffer = cosSinUpcastBuffer[embedDim_];
        } else {
            cosAccBuffer = localCos;
            sinAccBuffer = localSin;
        }

        constexpr const int loadSizeByElem = loadSize / sizeof(scalar_t);
        int64_t headNumPerLoad = loadSizeByElem / headSize_;
        int64_t loopCnt = numHeads_ / headNumPerLoad;
        int64_t tailHeads = numHeads_ - loopCnt * headNumPerLoad;
        int64_t loadStride = headNumPerLoad * headSize_;
        int64_t loopCntKv = numKvHeads_ / headNumPerLoad;
        int64_t tailHeadsKv = numKvHeads_ - loopCntKv * headNumPerLoad;
        compute_qk(query_, queryDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer,
                   calcTmpBuffer, loopCnt, tailHeads, loadStride, headNumPerLoad);

        compute_qk(key_, keyDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, calcTmpBuffer,
                   loopCntKv, tailHeadsKv, loadStride, headNumPerLoad);

        inQueSinCos_.FreeTensor(localSinCosDeque);
    }
private:
    AscendC::TPipe *pipe_;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQue_, inQueSinCos_;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQue_;
    AscendC::TBuf<AscendC::TPosition::VECCALC> calcBuf_;
    AscendC::TBuf<AscendC::TPosition::VECCALC> copyBuf_;
    AscendC::GlobalTensor<dst_t> queryDst_;
    AscendC::GlobalTensor<dst_t> keyDst_;
    AscendC::GlobalTensor<scalar_t> query_;
    AscendC::GlobalTensor<scalar_t> key_;
    AscendC::GlobalTensor<scalar_t> cosSin_;
    int rotDim_;
    int embedDim_;
    int64_t queryStride_;
    int64_t keyStride_;
    int64_t dstQueryStride_;
    int64_t dstKeyStride_;
    int numHeads_;
    int numKvHeads_;
    int headSize_;
    int calcTmpBufferOffset_;
    int upcastInputBufferOffset_;
    int upcastTempBufferOffset_;
    int cosSinUpcastBufferOffset_;
    int tempBufferSize_;
};
// NOTE: we need a macro to instantiate all the target functions here, because the current build
// system does not support template calls in cpp.
// We use C-style symbols for kernel compilation; a C++-style kernel entry may lead to compilation failure.
#define ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, NEOX)                                                                   \
    extern "C" __global__ __aicore__ void rope_custom_##NEOX##_##TYPE(                                               \
        __gm__ int64_t* positions, __gm__ void* queryDst, __gm__ void* keyDst, __gm__ TYPE* query, __gm__ TYPE* key, \
        __gm__ TYPE* cosSinCache, const int rotDim, const int64_t queryStride, const int64_t keyStride,              \
        const int64_t dstQueryStride, const int64_t dstKeyStride, const int numHeads, const int numKvHeads,          \
        const int headSize, const int64_t numTokens, const int loopNum, const int coreNum)                           \
    {                                                                                                                 \
        AscendC::TPipe pipe;                                                                                          \
        RotaryEmbedding<TYPE, NEOX> op{};                                                                             \
        op.init(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride,           \
                queryStride, keyStride, numHeads, numKvHeads, headSize, &pipe);                                       \
        for (int64_t i = AscendC::GetBlockIdx(); i < numTokens; i += coreNum) {                                       \
            op.update_mem_offset(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride,        \
                                 dstKeyStride, queryStride, keyStride, numHeads, numKvHeads, headSize, i);            \
            op.compute_function();                                                                                    \
        }                                                                                                             \
    }

#define ROPE_CUSTOM_KERNEL_DECLARE(TYPE)                                                                              \
    ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, true);                                                                      \
    ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, false);

// Declare all the kernel entries here
ROPE_CUSTOM_KERNEL_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {

#define ROTARY_EMBEDDING_KERNEL_CALL(TYPE)                                                                            \
    if (isNeox)                                                                                                       \
        rope_custom_true_##TYPE<<<blockDim, nullptr, stream>>>(                                                       \
            positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key),              \
            reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride,      \
            numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);                                            \
    else                                                                                                              \
        rope_custom_false_##TYPE<<<blockDim, nullptr, stream>>>(                                                      \
            positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key),              \
            reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride,      \
            numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);

// Maximum block count the runtime allows when launching an AscendC kernel;
// we use this to cap the launch block dim.
static const int64_t maxParallelSize = 65535;

extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
                                  void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
                                  const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
                                  const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
                                  const int headSize, const int64_t numTokens, const uint32_t loopCnt,
                                  uint32_t aivNum)
{
    int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize;
    if (type == AscendType::FP16) {
        ROTARY_EMBEDDING_KERNEL_CALL(half);
    }
#if (__CCE_AICORE__ >= 220)
    else if (type == AscendType::BF16) {
        ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t);
    }
#endif
    else {
        return;
    }
}

} // namespace vllm_ascend
@@ -24,13 +24,6 @@
#include "torch_npu/csrc/aten/common/from_blob.h"

namespace vllm_ascend {
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
                                  void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
                                  const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
                                  const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
                                  const int headSize, const int64_t numTokens, const uint32_t loopCnt,
                                  uint32_t aivNum);

extern void get_masked_input_and_mask_impl(
    void* stream,
    void* input,
@@ -105,75 +105,6 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
    }
}

std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
                                                    int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
    int32_t deviceId = 0;
    int64_t num_tokens = positions.numel();
    int positions_ndim = positions.dim();
    TORCH_CHECK(
        positions_ndim == 1 || positions_ndim == 2,
        "positions must have shape [num_tokens] or [batch_size, seq_len]");
    if (positions_ndim == 1) {
        TORCH_CHECK(
            query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
            "query, key and positions must have the same number of tokens");
    }
    if (positions_ndim == 2) {
        TORCH_CHECK(
            query.size(0) == positions.size(0) &&
                key.size(0) == positions.size(0) &&
                query.size(1) == positions.size(1) &&
                key.size(1) == positions.size(1),
            "query, key and positions must have the same batch_size and seq_len");
    }
    TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
    int query_hidden_size = query.numel() / num_tokens;
    int key_hidden_size = key.numel() / num_tokens;
    TORCH_CHECK(query_hidden_size % head_size == 0);
    TORCH_CHECK(key_hidden_size % head_size == 0);
    TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");

    // Make sure query and key have a consistent number of heads
    int num_heads = query_hidden_size / head_size;
    int num_kv_heads = key_hidden_size / head_size;
    TORCH_CHECK(num_heads % num_kv_heads == 0);
    at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());

    int rot_dim = cos_sin_cache.size(1);
    int seq_dim_idx = positions_ndim - 1;
    int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
    void *query_dst_ptr = query_dst.data_ptr();
    void *key_dst_ptr = key_dst.data_ptr();
    void *query_ptr = query.data_ptr();
    void *key_ptr = key.data_ptr();
    void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
    int64_t query_stride = query.stride(seq_dim_idx);
    int64_t key_stride = key.stride(seq_dim_idx);
    int64_t dst_query_stride = query_dst.stride(0);
    int64_t dst_key_stride = key_dst.stride(0);
    at::ScalarType scalar_type = query.scalar_type();
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("rotary_embedding");
    cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
                          query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
                          dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
        auto dtype_num = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        uint32_t loop_cnt = (num_tokens + aiv_num - 1) / aiv_num;
        rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
                              key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
                              dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num);
        return 0;
    });
    cmd.Run();
    return {query_dst, key_dst};
}

std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
    const at::Tensor &hiddenState, const at::Tensor &wdqkv,
    const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
@@ -1368,14 +1299,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);

    // Rotary embedding
    // Apply GPT-NeoX style rotary embedding to query and key.
    ops.def(
        "rotary_embedding(Tensor positions, Tensor! query,"
        " Tensor! key, int head_size,"
        " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);

    ops.def(
        "get_masked_input_and_mask(Tensor input, "
        " int org_vocab_start_index, "
@@ -36,24 +36,6 @@
namespace vllm_ascend {
namespace meta {
const int64_t INT4_NUMS_IN_INT32 = 8;
std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
    at::Tensor &positions,
    at::Tensor &query,
    at::Tensor &key,
    int64_t head_size,
    at::Tensor &cos_sin_cache,
    bool is_neox) {
    auto num_tokens = positions.sym_numel();
    auto query_hidden_size = query.sym_numel() / num_tokens;
    auto key_hidden_size = key.sym_numel() / num_tokens;

    auto num_heads = query_hidden_size / head_size;
    auto num_kv_heads = key_hidden_size / head_size;
    at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
    at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());

    return {query_dst, key_dst};
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
    at::Tensor &input,
@@ -457,8 +439,6 @@ namespace {
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {

    // Rotary embedding meta implementation
    ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
    // Masked input and mask meta implementation
    ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
    // Bgmv expand