[Ops][Refactor] Remove custom rotary_embedding operator (#6523)

### What this PR does / why we need it?
This PR removes the custom `rotary_embedding` operator and its
associated C++ kernel implementation, PyTorch bindings, and tests.

The codebase now falls back to the native
`torch_npu._npu_rotary_embedding` implementation. This removes custom,
platform-specific kernel code in favor of the standard NPU library
routine, which is expected to be at least as well optimized and is
easier to maintain.
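
For context, the fallback path in `_rope_forward_oot` now reduces to roughly
the call shape sketched below (a minimal sketch derived from this diff;
`apply_rope_npu` is a hypothetical wrapper name, and running it requires an
Ascend NPU with `torch_npu` installed):

```python
import torch
import torch_npu  # assumed to be installed on an Ascend NPU host


def apply_rope_npu(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
                   head_size: int, cos_sin_cache: torch.Tensor,
                   is_neox_style: bool = True):
    # torch_npu._npu_rotary_embedding rotates query/key in place over
    # flattened [num_tokens, num_heads * head_size] views.
    query_shape, key_shape = query.shape, key.shape
    query = query.contiguous().view(query.shape[0], -1)
    key = key.contiguous().view(key.shape[0], -1)
    torch_npu._npu_rotary_embedding(
        positions, query, key, head_size, cos_sin_cache, is_neox_style,
    )
    return query.view(query_shape), key.view(key_shape)
```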

### Does this PR introduce _any_ user-facing change?
No. This is an internal refactoring and does not introduce any
user-facing changes.

### How was this patch tested?
The tests for the custom `rotary_embedding` operator have been removed
along with the operator itself. The correctness of the fallback to the
native `torch_npu` implementation is verified by existing CI tests for
attention layers and models that use rotary embeddings.
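
For a manual spot check outside CI, one could compare the fallback against
vLLM's pure-PyTorch reference, mirroring the removed kernel test further down
in this diff (a sketch that assumes an Ascend NPU with `torch_npu` installed;
tolerances follow the removed test):

```python
import torch
import torch_npu  # assumed to be installed on an Ascend NPU host
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

torch.set_default_device("npu")
head_size, num_heads, num_tokens, max_pos = 128, 8, 16, 8192
rope = RotaryEmbedding(head_size, head_size, max_pos, 10000, True, torch.float16)
positions = torch.randint(0, max_pos, (num_tokens, ))
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16)
key = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16)

# Reference: pure-PyTorch rotary embedding from vLLM.
ref_q, ref_k = rope.forward_native(positions, query.clone(), key.clone())

# Fallback path: the torch_npu kernel rotates query/key in place.
q, k = query.clone(), key.clone()
torch_npu._npu_rotary_embedding(positions, q, k, head_size, rope.cos_sin_cache, True)

torch.testing.assert_close(q, ref_q, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(k, ref_k, atol=1e-3, rtol=1e-3)
```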

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan
2026-02-07 09:24:05 +08:00
committed by GitHub
parent 06aa6036f6
commit 6c49f95da2
8 changed files with 59 additions and 1392 deletions

View File

@@ -1,372 +0,0 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include <stdio.h>
#include "types.h"
#include "utils.h"
using vllm_ascend::AccType;
using vllm_ascend::local_mem_copy;
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
// NOTE(ganyi): we use 512B as the load stride for the pipe; we need to find another way to
// retrieve this size from the runtime for broader SoC support
#if (__CCE_AICORE__ >= 220)
static int constexpr loadSize = 512;
#else
static int constexpr loadSize = 1024 * 4;
#endif
using dst_t = scalar_t;
using acc_t = typename AccType<scalar_t>::type;
// only half tensors have a cast instruction to int8, so hardcode acc_dst_t as half
using local_scalar_t = AscendC::LocalTensor<scalar_t>;
using local_acc_t = AscendC::LocalTensor<acc_t>;
using local_dst_t = AscendC::LocalTensor<dst_t>;
public:
__aicore__ inline RotaryEmbedding()
{
}
// Allocate buffers for the input and output queues and the temp buffer used during the kernel compute process;
// this init happens only once per vector core.
__aicore__ inline void init(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
const int rotDim, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int64_t queryStride, const int64_t keyStride,
const int numHeads, const int numKvHeads, const int headSize, AscendC::TPipe *pipe)
{
pipe_ = pipe;
rotDim_ = rotDim;
// query stride and key stride are used to handle strided tensors that are not contiguous along the num_tokens dim
queryStride_ = queryStride;
keyStride_ = keyStride;
dstQueryStride_ = dstQueryStride;
dstKeyStride_ = dstKeyStride;
numHeads_ = numHeads;
numKvHeads_ = numKvHeads;
headSize_ = headSize;
embedDim_ = rotDim / 2;
pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */);
pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
// 2 temporary calculation buffer
calcTmpBufferOffset_ = 0;
// 1 upcast buffer for bf16 (headSize)
upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2;
// 1 upcast temp buffer for bf16 (2 * embed_dim)
upcastTempBufferOffset_ = upcastInputBufferOffset_ + sizeof(acc_t) * headSize_;
// 2 sin cos upcast buffer for bf16
cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_;
// 2. bf16 path: needs 2 cos sin upcast buffer size
// 3. fp16 path: needs 2 temporary calculation buffer size
tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t);
// need to consider upcasting bf16 to fp32, so we might need 4 buffers just in case
// 2 temporary buffers, 2 input buffers, 1 cos buffer, 1 sin buffer, 2 scale buffers (headSize), 2 zp
// buffers (headSize int8), 1 dst_temp buffer (headSize, int32)
pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */);
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
pipe_->InitBuffer(copyBuf_, loadSize);
}
}
__aicore__ inline void update_mem_offset(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
const int rotDim, const int64_t dstQueryStride, const int64_t dstKeyStride,
const int64_t queryStride, const int64_t keyStride, const int numHeads,
const int numKvHeads, const int headSize, const int64_t idx)
{
int64_t pos = positions[idx];
cosSin_.SetGlobalBuffer(cosSinCache + pos * rotDim_, rotDim_);
query_.SetGlobalBuffer(query + queryStride * idx, headSize * numHeads_);
key_.SetGlobalBuffer(key + keyStride * idx, headSize * numKvHeads_);
queryDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(queryDst) + dstQueryStride * idx,
headSize * numHeads_);
keyDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(keyDst) + dstKeyStride * idx, headSize * numKvHeads_);
}
// compute per head for neox on bf16
template <typename acc_t_, typename std::enable_if<!std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
__aicore__ inline void
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
{
// slice dst
local_dst_t dstX = dst;
local_dst_t dstY = dst[embedDim_];
// slice src
local_scalar_t srcX = src;
local_scalar_t srcY = src[embedDim_];
// slice temp buffer
local_acc_t calcTmpBufferX = calcTmpBuffer;
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
// slice upcast input buffer
local_acc_t upcastBufferX = upcastInputBuffer;
local_acc_t upcastBufferY = upcastBufferX[embedDim_];
// dst x calc
Cast(upcastInputBuffer, src, AscendC::RoundMode::CAST_NONE, headSize_);
Mul(calcTmpBufferX, upcastBufferX, cos, embedDim_);
Mul(calcTmpBufferY, upcastBufferY, sin, embedDim_);
Sub(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
Cast(dstX, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
// dst y calc
Mul(calcTmpBufferX, upcastBufferX, sin, embedDim_);
Mul(calcTmpBufferY, upcastBufferY, cos, embedDim_);
Add(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
Cast(dstY, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
}
// compute per head output for neox
template <typename acc_t_, typename std::enable_if<std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
__aicore__ inline void
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
{
// slice dst buffer
local_dst_t dstX = dst;
local_dst_t dstY = dst[embedDim_];
// slice src buffer
local_scalar_t srcX = src;
local_scalar_t srcY = src[embedDim_];
// slice temp buffer
local_acc_t calcTmpBufferX = calcTmpBuffer;
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
// dst x calc
Mul(calcTmpBufferX, srcX, cos, embedDim_);
Mul(calcTmpBufferY, srcY, sin, embedDim_);
Sub(dstX, calcTmpBufferX, calcTmpBufferY, embedDim_);
// dst y calc
Mul(calcTmpBufferX, srcX, sin, embedDim_);
Mul(calcTmpBufferY, srcY, cos, embedDim_);
Add(dstY, calcTmpBufferX, calcTmpBufferY, embedDim_);
}
__aicore__ inline void compute_qk(AscendC::GlobalTensor<scalar_t> srcG, AscendC::GlobalTensor<dst_t> dstG,
local_acc_t localCos, local_acc_t localSin, local_acc_t upcastInputBuffer,
local_acc_t calcTmpBuffer, int loopCnt, int tailHeads, int loadStride,
int headNumPerLoad)
{
for (int loopNum = 0; loopNum < loopCnt; ++loopNum) {
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
local_dst_t dst = outQue_.AllocTensor<dst_t>();
AscendC::DataCopy(src, srcG[loopNum * loadStride], loadStride);
inQue_.EnQue(src);
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
int elem_num = loadStride / sizeof(scalar_t);
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
} else {
local_mem_copy(dst, srcDeque, loadStride);
}
for (int i = 0; i < headNumPerLoad; ++i) {
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
calcTmpBuffer);
}
outQue_.EnQue(dst);
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
AscendC::DataCopy(dstG[loopNum * loadStride], dstDeque, loadStride);
outQue_.FreeTensor(dstDeque);
inQue_.FreeTensor(srcDeque);
}
// process tail
{
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
local_dst_t dst = outQue_.AllocTensor<dst_t>();
AscendC::DataCopy(src, srcG[loopCnt * loadStride], tailHeads * headSize_);
inQue_.EnQue(src);
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
int elem_num = tailHeads * headSize_ / sizeof(scalar_t);
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
} else {
local_mem_copy(dst, srcDeque, tailHeads * headSize_);
}
for (int i = 0; i < tailHeads; ++i) {
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
calcTmpBuffer);
}
outQue_.EnQue(dst);
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
AscendC::DataCopy(dstG[loopCnt * loadStride], dstDeque, tailHeads * headSize_);
outQue_.FreeTensor(dstDeque);
inQue_.FreeTensor(srcDeque);
}
}
__aicore__ inline void compute_function()
{
local_scalar_t cosSinLocal = inQueSinCos_.AllocTensor<scalar_t>();
AscendC::DataCopy(cosSinLocal, cosSin_, embedDim_ * 2);
inQueSinCos_.EnQue(cosSinLocal);
local_scalar_t localSinCosDeque = inQueSinCos_.DeQue<scalar_t>();
local_scalar_t localCos = localSinCosDeque;
local_scalar_t localSin = localSinCosDeque[embedDim_];
local_acc_t calcTmpBuffer;
local_acc_t upcastInputBuffer;
local_acc_t upcastTempBuffer;
local_acc_t cosSinUpcastBuffer;
local_acc_t scaleBuffer;
local_acc_t offsetBuffer;
calcTmpBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, calcTmpBufferOffset_);
upcastInputBuffer = calcBuf_.GetWithOffset<acc_t>(headSize_, upcastInputBufferOffset_);
upcastTempBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, upcastTempBufferOffset_);
cosSinUpcastBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, cosSinUpcastBufferOffset_);
local_acc_t cosAccBuffer;
local_acc_t sinAccBuffer;
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
Cast(cosSinUpcastBuffer, localSinCosDeque, AscendC::RoundMode::CAST_NONE, 2 * embedDim_);
cosAccBuffer = cosSinUpcastBuffer;
sinAccBuffer = cosSinUpcastBuffer[embedDim_];
} else {
cosAccBuffer = localCos;
sinAccBuffer = localSin;
}
constexpr const int loadSizeByElem = loadSize / sizeof(scalar_t);
int64_t headNumPerLoad = loadSizeByElem / headSize_;
int64_t loopCnt = numHeads_ / headNumPerLoad;
int64_t tailHeads = numHeads_ - loopCnt * headNumPerLoad;
int64_t loadStride = headNumPerLoad * headSize_;
int64_t loopCntKv = numKvHeads_ / headNumPerLoad;
int64_t tailHeadsKv = numKvHeads_ - loopCntKv * headNumPerLoad;
compute_qk(query_, queryDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer,
calcTmpBuffer, loopCnt, tailHeads, loadStride, headNumPerLoad);
compute_qk(key_, keyDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, calcTmpBuffer,
loopCntKv, tailHeadsKv, loadStride, headNumPerLoad);
inQueSinCos_.FreeTensor(localSinCosDeque);
}
private:
AscendC::TPipe *pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQue_, inQueSinCos_;
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQue_;
AscendC::TBuf<AscendC::TPosition::VECCALC> calcBuf_;
AscendC::TBuf<AscendC::TPosition::VECCALC> copyBuf_;
AscendC::GlobalTensor<dst_t> queryDst_;
AscendC::GlobalTensor<dst_t> keyDst_;
AscendC::GlobalTensor<scalar_t> query_;
AscendC::GlobalTensor<scalar_t> key_;
AscendC::GlobalTensor<scalar_t> cosSin_;
int rotDim_;
int embedDim_;
int64_t queryStride_;
int64_t keyStride_;
int64_t dstQueryStride_;
int64_t dstKeyStride_;
int numHeads_;
int numKvHeads_;
int headSize_;
int calcTmpBufferOffset_;
int upcastInputBufferOffset_;
int upcastTempBufferOffset_;
int cosSinUpcastBufferOffset_;
int tempBufferSize_;
};
// Note: Need to use macros to instantiate all the target functions here, because the current build system does not support template calls in cpp.
// We use C-style symbols here for kernel compilation; a cpp-style kernel entry may lead to compilation failure.
#define ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, NEOX) \
extern "C" __global__ __aicore__ void rope_custom_##NEOX##_##TYPE( \
__gm__ int64_t* positions, __gm__ void* queryDst, __gm__ void* keyDst, __gm__ TYPE* query, __gm__ TYPE* key, \
__gm__ TYPE* cosSinCache, const int rotDim, const int64_t queryStride, const int64_t keyStride, \
const int64_t dstQueryStride, const int64_t dstKeyStride, const int numHeads, const int numKvHeads, \
const int headSize, const int64_t numTokens, const int loopNum, const int coreNum) \
{ \
AscendC::TPipe pipe; \
RotaryEmbedding<TYPE, NEOX> op{}; \
op.init(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
queryStride, keyStride, numHeads, numKvHeads, headSize, &pipe); \
for (int64_t i = AscendC::GetBlockIdx(); i < numTokens; i += coreNum) { \
op.update_mem_offset(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
queryStride, keyStride, numHeads, numKvHeads, headSize, i); \
op.compute_function(); \
} \
}
#define ROPE_CUSTOM_KERNEL_DECLARE(TYPE) \
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, true); \
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, false);
// Declare all the kernel entry here
ROPE_CUSTOM_KERNEL_DECLARE(half)
#if (__CCE_AICORE__ >= 220)
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
#endif
namespace vllm_ascend {
#define ROTARY_EMBEDDING_KERNEL_CALL(TYPE) \
if (isNeox) \
rope_custom_true_##TYPE<<<blockDim, nullptr, stream>>>( \
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); \
else \
rope_custom_false_##TYPE<<<blockDim, nullptr, stream>>>( \
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);
// maximum block count the runtime allows when launching an AscendC kernel;
// we use this to constrain the maximum block dim
static const int64_t maxParallelSize = 65535;
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum)
{
int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize;
if (type == AscendType::FP16) {
ROTARY_EMBEDDING_KERNEL_CALL(half);
}
#if (__CCE_AICORE__ >= 220)
else if (type == AscendType::BF16) {
ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t);
}
#endif
else {
return;
}
}
} // namespace vllm_ascend

View File

@@ -24,13 +24,6 @@
#include "torch_npu/csrc/aten/common/from_blob.h"
namespace vllm_ascend {
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum);
extern void get_masked_input_and_mask_impl(
void* stream,
void* input,

View File

@@ -105,75 +105,6 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
}
}
std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
int32_t deviceId = 0;
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
int rot_dim = cos_sin_cache.size(1);
int seq_dim_idx = positions_ndim - 1;
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
void *query_dst_ptr = query_dst.data_ptr();
void *key_dst_ptr = key_dst.data_ptr();
void *query_ptr = query.data_ptr();
void *key_ptr = key.data_ptr();
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);
int64_t dst_query_stride = query_dst.stride(0);
int64_t dst_key_stride = key_dst.stride(0);
at::ScalarType scalar_type = query.scalar_type();
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("rotary_embedding");
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
auto dtype_num = get_dtype_from_torch(scalar_type);
int device_id = 0;
int64_t aiv_num = 0;
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
uint32_t loop_cnt = (num_tokens + aiv_num - 1) / aiv_num;
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num);
return 0;
});
cmd.Run();
return {query_dst, key_dst};
}
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &wdqkv,
const c10::optional<at::Tensor> &descale0, const at::Tensor &gamma1, const c10::optional<at::Tensor> &beta1, const at::Tensor &wuq,
@@ -1368,14 +1299,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
// Rotary embedding
// Apply GPT-NeoX style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
ops.def(
"get_masked_input_and_mask(Tensor input, "
" int org_vocab_start_index, "

View File

@@ -36,24 +36,6 @@
namespace vllm_ascend {
namespace meta {
const int64_t INT4_NUMS_IN_INT32 = 8;
std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
at::Tensor &positions,
at::Tensor &query,
at::Tensor &key,
int64_t head_size,
at::Tensor &cos_sin_cache,
bool is_neox) {
auto num_tokens = positions.sym_numel();
auto query_hidden_size = query.sym_numel() / num_tokens;
auto key_hidden_size = key.sym_numel() / num_tokens;
auto num_heads = query_hidden_size / head_size;
auto num_kv_heads = key_hidden_size / head_size;
at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());
return {query_dst, key_dst};
}
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
at::Tensor &input,
@@ -457,8 +439,6 @@ namespace {
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
// Rotary embedding meta implementation
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
// Masked input and mask meta implementation
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
// Bgmv expand

View File

@@ -1,351 +0,0 @@
# Copyright 2023 The vLLM team.
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
import gc
from typing import Optional, Tuple, Union
import pytest
import torch
import torch.nn as nn
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
# Only the neox-style (True) scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
HEAD_SIZES = [64, 64, 96, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 4096] # Arbitrary values for testing
NUM_TOKENS = [10, 21]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Default comparison tolerances
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
# test with leading dimension and merge seqlen and batch_size as num_tokens
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_quant_with_leading_dim(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.set_default_device(device)
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
rope = rope.to(dtype=dtype)
num_tokens = batch_size * seq_len
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
qkv_tensor = torch.randn(num_tokens,
num_heads * head_size * 3,
dtype=dtype)
query, key, _ = qkv_tensor.split(
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
dim=-1,
)
ref_query, ref_key = rope.forward_native(positions, query, key)
query, key = torch.ops._C_ascend.rotary_embedding(
positions,
query,
key,
rope.head_size,
rope.cos_sin_cache,
rope.is_neox_style,
)
# Compare the results.
torch.testing.assert_close(query.view(ref_query.size()),
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key.view(ref_key.size()),
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
class ModelwithRotaryEmbedding(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
self.rope = RotaryEmbedding(
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
self.o_proj = nn.Linear(num_heads * head_size, hidden_size)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(3, dim=-1)
query, key = torch.ops._C_ascend.rotary_embedding(
positions,
q,
k,
self.rope.head_size,
self.rope.cos_sin_cache,
self.rope.is_neox_style,
)
query = query.view(q.shape)
key = key.view(k.shape)
o = self.o_proj(query)
return o
# The first graph replay seems to have accuracy issues when pytest is run directly on the ops folder;
# add a warmup graph replay as a workaround
ACL_GRPAH_FIRST_RUN = True
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_capture_rotary_embedding_in_aclgraph(
is_neox_style: bool,
num_tokens: int,
num_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
max_position_embeddings: int = 8192,
base: int = 10000,
):
"""Test if the rotary embedding can be captured in aclgraph."""
torch.manual_seed(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
model = ModelwithRotaryEmbedding(
hidden_size=num_heads * head_size,
num_heads=num_heads,
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
# Validate if the rotary_embedding custom kernel is indeed inside the graph by
# string match
graph = str(gm.graph)
assert "_C_ascend.rotary_embedding" in graph
return gm
static_positions = torch.randint(0, max_position_embeddings,
(num_tokens, ))
static_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
compiled_model = torch.compile(model, backend=custom_op_checking_backend)
stream = torch.npu.Stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
# warmup the fx graph before capture
for i in range(3):
static_output = compiled_model(static_positions,
static_hidden_states,
offsets=None)
stream.wait_stream(torch.npu.current_stream())
aclgraph = torch.npu.NPUGraph()
with torch.npu.graph(aclgraph):
# Capture the model in aclgraph.
static_output = compiled_model(static_positions, static_hidden_states)
# Fill the static buffers with new random inputs and replay the captured graph.
random_filled_positions = torch.randint(0,
max_position_embeddings,
(num_tokens, ),
device="npu")
random_filled_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
static_positions.copy_(random_filled_positions)
static_hidden_states.copy_(random_filled_hidden_states)
aclgraph.replay()
global ACL_GRPAH_FIRST_RUN
if ACL_GRPAH_FIRST_RUN:
ACL_GRPAH_FIRST_RUN = False
return
output_reference = model(static_positions, static_hidden_states)
torch.testing.assert_close(static_output,
output_reference,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()

View File

@@ -1,470 +0,0 @@
import math
import unittest
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from transformers.configuration_utils import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)
from vllm.platforms import CpuArchEnum
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
from vllm_ascend.utils import AscendDeviceType
MODEL = "Qwen3-0.6B"
MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
MAX_NUM_BATCHED_TOKEND = 10000
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
def test_custom_rotary_embedding_enabled(self):
# Test when all conditions are True
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = _custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=False):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
class TestAscendRotaryEmbedding(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.config_patcher = patch('vllm.config.vllm.get_current_vllm_config')
self.mock_get_config = self.config_patcher.start()
mock_config = MagicMock()
mock_config.compilation_config.custom_ops = ["all"]
self.mock_get_config.return_value = mock_config
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.cos_sin_cache = torch.randn(3, 1, 32)
self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
self.max_position, self.rope_theta,
self.is_neox_style, torch.float16)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = self.is_neox_style
@patch('torch.ops._C_ascend')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled,
mock_soc_version, mock__c):
mock__c.rotary_embedding.return_value = self.query, self.key
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_text_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock__c.rotary_embedding.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_text_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_with_offsets(self):
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_text_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
self.layer.forward(self.positions, self.query, self.key,
offsets)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
# Test neox_style override
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_text_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(
self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_rotary_dim_less_than_head_size(
self, mock_npu_rotary, mock_custom_enabled):
# test case when rotary_dim < head_size
org_rotary_dim = self.layer.rotary_dim
self.layer.rotary_dim = self.layer.head_size // 2
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_text_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
# restore rotary_dim
self.layer.rotary_dim = org_rotary_dim
class MockRopeModule:
def __init__(self, max_seq_len=2048, is_neox_style=True):
self.max_seq_len = max_seq_len
self.is_neox_style = is_neox_style
self.cos_cached = None
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
def setUp(self):
# Common setup for tests
self.config_patcher = patch('vllm.config.vllm.get_current_vllm_config')
self.mock_get_config = self.config_patcher.start()
mock_config = MagicMock()
mock_config.compilation_config.custom_ops = ["all"]
self.mock_get_config.return_value = mock_config
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.scaling_factor = 1
self.layer = None
def _create_layer(self):
self.layer = DeepseekScalingRotaryEmbedding(
self.head_size, self.rotary_dim, self.max_position,
self.rope_theta, self.is_neox_style, self.scaling_factor,
torch.float16)
return self.layer
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
return_value=(self.query,
self.key)) as mock_rope_forward_oot:
q_pe, k_pe = self.layer.forward(self.positions, self.query,
self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
key = torch.randn(1, 32)
mock_rope_forward_oot.return_value = (self.query, key)
q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
mock_rope_forward_oot.return_value = (self.query, self.key)
q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_basic_case(self, mock_npuplatform):
# Test with standard values
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
expected = (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
self.assertTrue(torch.allclose(result, expected))
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_yarn_get_mscale(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
# test_scale_less_than_or_equal_1
self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)
# test_scale_greater_than_1:
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = self.layer._yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")
class TestAscendMRotaryEmbedding(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.config_patcher = patch('vllm.config.vllm.get_current_vllm_config')
self.mock_get_config = self.config_patcher.start()
mock_config = MagicMock()
mock_config.compilation_config.custom_ops = ["all"]
self.mock_get_config.return_value = mock_config
self.number_tokens = 3
self.num_head = 8
self.num_kvhead = 8
self.head_size = 128
self.max_position_embeddings = 128000
self.is_neox_style = True
self.rope_theta = 1000000.0
self.positions_1d = torch.tensor([1, 2, 3])
self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
self.query = torch.randn(
(self.number_tokens, self.num_head * self.head_size),
dtype=torch.bfloat16)
self.key = torch.randn(
(self.number_tokens, self.num_kvhead * self.head_size),
dtype=torch.bfloat16)
# Qwen2.5-VL mrope section case
self.mrope_section = [16, 24, 24]
self.layer = MRotaryEmbedding(self.head_size,
self.head_size,
self.max_position_embeddings,
base=self.rope_theta,
is_neox_style=self.is_neox_style,
dtype=torch.bfloat16,
mrope_section=self.mrope_section)
self.mock_config = MagicMock()
def _create_vllm_config(self):
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL_VL,
tokenizer=MODEL_VL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_text_config = PretrainedConfig()
vllm_config.model_config = model_config
return vllm_config
@patch('torch_npu.npu_mrope')
@patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_forward_oot_1d_positions(self, mock_cpu_arc, mock_npu_mrope):
mock_cpu_arc.return_value = CpuArchEnum.ARM
mock_npu_mrope.return_value = (torch.zeros_like(self.query),
torch.zeros_like(self.key))
vllm_config = self._create_vllm_config()
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward_oot(
self.positions_1d, self.query, self.key)
mock_npu_mrope.assert_called_once()
self.assertFalse(torch.isnan(result_q).any().item())
self.assertFalse(torch.isnan(result_k).any().item())
self.assertEqual(result_q.shape, self.query.shape)
@patch('torch_npu.npu_mrope')
@patch('vllm_ascend.platform.NPUPlatform.get_cpu_architecture')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_forward_oot_2d_positions(self, mock_cpu_arc, mock_npu_mrope):
mock_cpu_arc.return_value = CpuArchEnum.ARM
mock_npu_mrope.return_value = (torch.zeros_like(self.query),
torch.zeros_like(self.key))
vllm_config = self._create_vllm_config()
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward_oot(
self.positions_2d, self.query, self.key)
mock_npu_mrope.assert_called_once()
self.assertFalse(torch.isnan(result_q).any().item())
self.assertFalse(torch.isnan(result_k).any().item())
self.assertEqual(result_q.shape, self.query.shape)

View File

@@ -31,7 +31,7 @@ from torch.library import Library
# 3. The registration utility will check if a meta implementation already exists for your op,
# and only register if necessary. This avoids duplicate registrations.
#
# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
# 4. Example meta implementations are provided below for get_masked_input_and_mask.
#
# 5. When developing new custom ops, always provide a meta implementation to enable tracing,
# export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
@@ -52,25 +52,6 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
lib.impl(op_name, fn, "Meta")
def rotary_embedding_meta(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
):
num_tokens = positions.numel()
query_hidden_size = query.numel() // num_tokens
key_hidden_size = key.numel() // num_tokens
num_heads = query_hidden_size // head_size
num_kv_heads = key_hidden_size // head_size
query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
return query_dst, key_dst
def get_masked_input_and_mask_meta(
input: torch.Tensor,
org_vocab_start_index: int,
@@ -105,7 +86,6 @@ def sgmv_expand_meta(
return y_out
register_meta_if_necessary("_C_ascend", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta)
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)

View File

@@ -33,7 +33,7 @@ if HAS_TRITON:
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import AscendDeviceType, enable_custom_op, get_ascend_device_type, has_rope, is_vl_model
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
# Currently, rope ops used on NPU require detached cos && sin as inputs.
# However, RotaryEmbedding in vLLM uses cos_sin_cache as a single variable.
@@ -144,10 +144,6 @@ def get_cos_and_sin_slice():
return _cos_slice, _sin_slice
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op()
def _rope_forward_oot(
self,
positions: torch.Tensor,
@@ -162,9 +158,62 @@ def _rope_forward_oot(
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
cos, sin = get_cos_and_sin_slice()
# adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled(query, is_neox_style, self.head_size):
query, key = torch.ops._C_ascend.rotary_embedding(
if offsets is not None:
raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
if (
is_neox_style
and self.head_size == 128
and self.cos_sin_cache.shape[-1] == 128
and cos is not None
and sin is not None
):
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim to equal 128 and neox_style to be True
query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
# Although this function modifies in-place, please retain the function's return value.
# Otherwise, the graph fusion operation may fail.
query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
elif self.rotary_dim < self.head_size:
if HAS_TRITON:
cos = cos.view(-1, self.rotary_dim)
sin = sin.view(-1, self.rotary_dim)
q = query.contiguous().view(query.shape[0], -1, self.head_size)
k = key.contiguous().view(key.shape[0], -1, self.head_size)
query, key = torch.ops.vllm.rope_forward_triton(
q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
)
return query.view(query_shape), key.view(key_shape)
else:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., : self.rotary_dim]
q_pass = query[..., self.rotary_dim :]
k_rot = key[..., : self.rotary_dim]
k_pass = key[..., self.rotary_dim :]
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
# only the rotary part is processed here,
# the dimension should be rotary_dim
torch_npu._npu_rotary_embedding(
positions,
q_rot,
k_rot,
self.rotary_dim,
self.cos_sin_cache,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
@@ -172,72 +221,7 @@ def _rope_forward_oot(
self.cos_sin_cache,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError("Batched rotary embedding is currently not supported on NPU.")
else:
if (
is_neox_style
and self.head_size == 128
and self.cos_sin_cache.shape[-1] == 128
and cos is not None
and sin is not None
):
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim to equal 128 and neox_style to be True
query = query.contiguous().view(1, query.shape[0], -1, self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
# Although this function modifies in-place, please retain the function's return value.
# Otherwise, the graph fusion operation may fail.
query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
elif self.rotary_dim < self.head_size:
if HAS_TRITON:
cos = cos.view(-1, self.rotary_dim)
sin = sin.view(-1, self.rotary_dim)
q = query.contiguous().view(query.shape[0], -1, self.head_size)
k = key.contiguous().view(key.shape[0], -1, self.head_size)
query, key = torch.ops.vllm.rope_forward_triton(
q, k, cos, sin, rope_dim=self.rotary_dim, is_neox_style=True
)
return query.view(query_shape), key.view(key_shape)
else:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., : self.rotary_dim]
q_pass = query[..., self.rotary_dim :]
k_rot = key[..., : self.rotary_dim]
k_pass = key[..., self.rotary_dim :]
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
# only the rotary part is processed here,
# the dimension should be rotary_dim
torch_npu._npu_rotary_embedding(
positions,
q_rot,
k_rot,
self.rotary_dim,
self.cos_sin_cache,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
return query.view(query_shape), key.view(key_shape)
class AscendRotaryEmbedding(RotaryEmbedding):