2025-04-03 14:52:34 +08:00
|
|
|
/*
|
|
|
|
|
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
|
*
|
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
*
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include <torch/extension.h>
|
|
|
|
|
#include <torch/library.h>
|
|
|
|
|
#include <torch/version.h>
|
|
|
|
|
#include <torch_npu/csrc/core/npu/NPUStream.h>
|
|
|
|
|
#include <torch_npu/csrc/framework/OpCommand.h>
|
|
|
|
|
#include <torch_npu/csrc/npu/Module.h>
|
|
|
|
|
#include "acl/acl.h"
|
|
|
|
|
#include "ops.h"
|
|
|
|
|
#include "utils.h"
|
2025-10-12 07:39:45 +08:00
|
|
|
#include "mla_preprocess/op_host/mla_preprocess.h"
|
2025-12-08 19:22:14 +08:00
|
|
|
#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"
|
2025-04-03 14:52:34 +08:00
|
|
|
|
|
|
|
|
namespace vllm_ascend {
|
|
|
|
|
|
2025-08-11 15:59:42 +08:00
|
|
|
// Map a PyTorch scalar type onto the AscendType tag understood by the kernel launchers.
//   Float    -> FP32
//   BFloat16 -> BF16
//   anything else (Half included) -> FP16
// NOTE(review): dtypes the kernels do not support are not rejected here; they silently map to
// FP16, so callers are expected to validate dtypes before calling — confirm at each call site.
AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
    switch (scalarType) {
        case at::ScalarType::Float:
            return AscendType::FP32;
        case at::ScalarType::BFloat16:
            return AscendType::BF16;
        default:
            return AscendType::FP16;
    }
}
|
|
|
|
|
|
2025-04-18 08:56:05 +08:00
|
|
|
std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
2025-04-03 14:52:34 +08:00
|
|
|
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
|
|
|
|
|
{
|
|
|
|
|
int32_t deviceId = 0;
|
|
|
|
|
int64_t num_tokens = positions.numel();
|
|
|
|
|
int positions_ndim = positions.dim();
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
positions_ndim == 1 || positions_ndim == 2,
|
|
|
|
|
"positions must have shape [num_tokens] or [batch_size, seq_len]");
|
|
|
|
|
if (positions_ndim == 1) {
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
|
|
|
|
|
"query, key and positions must have the same number of tokens");
|
|
|
|
|
}
|
|
|
|
|
if (positions_ndim == 2) {
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
query.size(0) == positions.size(0) &&
|
|
|
|
|
key.size(0) == positions.size(0) &&
|
|
|
|
|
query.size(1) == positions.size(1) &&
|
|
|
|
|
key.size(1) == positions.size(1),
|
|
|
|
|
"query, key and positions must have the same batch_size and seq_len");
|
|
|
|
|
}
|
2025-04-18 08:56:05 +08:00
|
|
|
TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
|
2025-04-03 14:52:34 +08:00
|
|
|
int query_hidden_size = query.numel() / num_tokens;
|
|
|
|
|
int key_hidden_size = key.numel() / num_tokens;
|
|
|
|
|
TORCH_CHECK(query_hidden_size % head_size == 0);
|
|
|
|
|
TORCH_CHECK(key_hidden_size % head_size == 0);
|
2025-04-18 08:56:05 +08:00
|
|
|
TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
|
2025-04-03 14:52:34 +08:00
|
|
|
|
|
|
|
|
// Make sure query and key have consistent number of heads
|
|
|
|
|
int num_heads = query_hidden_size / head_size;
|
|
|
|
|
int num_kv_heads = key_hidden_size / head_size;
|
|
|
|
|
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
2025-04-18 08:56:05 +08:00
|
|
|
at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
|
|
|
|
|
at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
|
2025-04-03 14:52:34 +08:00
|
|
|
|
|
|
|
|
int rot_dim = cos_sin_cache.size(1);
|
2025-04-18 08:56:05 +08:00
|
|
|
int seq_dim_idx = positions_ndim - 1;
|
2025-04-03 14:52:34 +08:00
|
|
|
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
|
2025-04-18 08:56:05 +08:00
|
|
|
void *query_dst_ptr = query_dst.data_ptr();
|
|
|
|
|
void *key_dst_ptr = key_dst.data_ptr();
|
2025-04-03 14:52:34 +08:00
|
|
|
void *query_ptr = query.data_ptr();
|
|
|
|
|
void *key_ptr = key.data_ptr();
|
|
|
|
|
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
|
2025-04-18 08:56:05 +08:00
|
|
|
int64_t query_stride = query.stride(seq_dim_idx);
|
|
|
|
|
int64_t key_stride = key.stride(seq_dim_idx);
|
|
|
|
|
int64_t dst_query_stride = query_dst.stride(0);
|
|
|
|
|
int64_t dst_key_stride = key_dst.stride(0);
|
2025-04-03 14:52:34 +08:00
|
|
|
at::ScalarType scalar_type = query.scalar_type();
|
|
|
|
|
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
|
|
|
|
at_npu::native::OpCommand cmd;
|
|
|
|
|
cmd.Name("rotary_embedding");
|
2025-04-18 08:56:05 +08:00
|
|
|
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr,
|
2025-04-03 14:52:34 +08:00
|
|
|
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
|
2025-04-18 08:56:05 +08:00
|
|
|
dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int {
|
2025-04-03 14:52:34 +08:00
|
|
|
auto dtype_num = get_dtype_from_torch(scalar_type);
|
|
|
|
|
int device_id = 0;
|
2025-07-24 10:00:19 +08:00
|
|
|
int64_t aiv_num = 0;
|
|
|
|
|
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
|
|
|
|
|
uint32_t loop_cnt = (num_tokens + aiv_num - 1) / aiv_num;
|
2025-04-18 08:56:05 +08:00
|
|
|
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr,
|
|
|
|
|
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride,
|
2025-07-24 10:00:19 +08:00
|
|
|
dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num);
|
2025-04-03 14:52:34 +08:00
|
|
|
return 0;
|
|
|
|
|
});
|
|
|
|
|
cmd.Run();
|
2025-04-18 08:56:05 +08:00
|
|
|
return {query_dst, key_dst};
|
2025-04-03 14:52:34 +08:00
|
|
|
}
|
2025-05-20 09:31:30 +08:00
|
|
|
|
2025-10-12 07:39:45 +08:00
|
|
|
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
2025-10-21 19:20:13 +08:00
|
|
|
const at::Tensor &hiddenState, const at::Tensor &wdqkv,
|
2025-10-12 07:39:45 +08:00
|
|
|
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
|
|
|
|
|
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
|
|
|
|
|
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
|
|
|
|
|
const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
|
|
|
|
|
const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
|
|
|
|
|
const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
|
|
|
|
|
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
|
|
|
|
|
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
|
|
|
|
|
{
|
|
|
|
|
at::Tensor CtkvScale =
|
|
|
|
|
ctkv_scale.has_value()
|
|
|
|
|
? ctkv_scale.value()
|
|
|
|
|
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
|
|
|
|
|
at::Tensor QnopeScale =
|
|
|
|
|
q_nope_scale.has_value()
|
|
|
|
|
? q_nope_scale.value()
|
|
|
|
|
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
|
|
|
|
|
|
|
|
|
|
auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
|
|
|
|
|
hiddenState,
|
|
|
|
|
wuk,
|
|
|
|
|
cache_mode,
|
|
|
|
|
quant_mode
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
void *hidden_state_ptr = hiddenState.data_ptr();
|
|
|
|
|
void *quant_scale0_ptr = quant_scale0.data_ptr();
|
|
|
|
|
void *quant_offset0_ptr = quant_offset0.data_ptr();
|
|
|
|
|
void *wdqkv_ptr = wdqkv.data_ptr();
|
|
|
|
|
void *bias0_ptr = bias0.data_ptr();
|
|
|
|
|
void *gamma1_ptr = gamma1.data_ptr();
|
|
|
|
|
void *beta1_ptr = beta1.data_ptr();
|
|
|
|
|
void *quant_scale1_ptr = quant_scale1.data_ptr();
|
|
|
|
|
void *quant_offset1_ptr = quant_offset1.data_ptr();
|
|
|
|
|
void *gamma2_ptr = gamma2.data_ptr();
|
|
|
|
|
void *sin_ptr = sin.data_ptr();
|
|
|
|
|
void *cos_ptr = cos.data_ptr();
|
|
|
|
|
void *kv_cache_ptr = kv_cache.data_ptr();
|
|
|
|
|
void *slotmapping_ptr = slotmapping.data_ptr();
|
|
|
|
|
void *wuq_ptr = wuq.data_ptr();
|
|
|
|
|
void *bias1_ptr = bias1.data_ptr();
|
|
|
|
|
void *wuk_ptr = wuk.data_ptr();
|
|
|
|
|
void *descale0_ptr = descale0.data_ptr();
|
|
|
|
|
void *descale1_ptr = descale1.data_ptr();
|
|
|
|
|
void *ctkv_scale_ptr = CtkvScale.data_ptr();
|
|
|
|
|
void *qnope_scale_ptr = QnopeScale.data_ptr();
|
|
|
|
|
void *q_out0_ptr = q_out0.data_ptr();
|
|
|
|
|
void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
|
|
|
|
|
void *q_out1_ptr = q_out1.data_ptr();
|
|
|
|
|
void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
|
|
|
|
|
void *workspace_ptr = workspace_tensor.data_ptr();
|
|
|
|
|
void *tiling_ptr = tiling.data_ptr();
|
|
|
|
|
|
|
|
|
|
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
|
|
|
|
at_npu::native::OpCommand cmd;
|
|
|
|
|
cmd.Name("mla_preprocess");
|
|
|
|
|
|
2025-10-21 19:20:13 +08:00
|
|
|
cmd.SetCustomHandler([stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
2025-10-12 07:39:45 +08:00
|
|
|
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
|
|
|
|
|
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
|
|
|
|
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
|
|
|
|
tiling_ptr, block_dim]() -> int {
|
2025-10-21 19:20:13 +08:00
|
|
|
mla_preprocess_impl(stream, hidden_state_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
|
2025-10-12 07:39:45 +08:00
|
|
|
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
|
|
|
|
|
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
|
|
|
|
|
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
|
|
|
|
|
tiling_ptr, block_dim);
|
|
|
|
|
return 0;
|
|
|
|
|
});
|
|
|
|
|
cmd.Run();
|
|
|
|
|
return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1);
|
|
|
|
|
}
|
|
|
|
|
|
2025-06-12 10:44:33 +08:00
|
|
|
/*
 * Ascend kernel counterpart of vLLM's vocab-parallel embedding input masking:
 * https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L161-L198
 *
 * Embedding parallelized in the vocabulary dimension.
 *
 * Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
 * make sure it is divisible by the number of model parallel GPUs.
 *
 * In order to support various loading methods, we ensure that LoRA-added
 * embeddings are always at the end of TP-sharded tensors. In other words,
 * we shard base embeddings and LoRA embeddings separately (both padded),
 * and place them in the same tensor.
 * In this example, we will have the original vocab size = 1010,
 * added vocab size = 16 and padding to 64. Therefore, the total
 * vocab size with padding will be 1088 (because we first pad 1010 to
 * 1024, add 16, and then pad to 1088).
 * Therefore, the tensor format looks like the following:
 *
 * TP1, rank 0 (no sharding):
 *                          |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
 * corresponding token_id:  | 0 | 1 | ... | 1009 | -1   | ... | -1   | 1010 | ... | 1025 | -1   | ... | -1   |
 * index:                   | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
 *
 * TP2, rank 0:
 *                          |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
 * corresponding token_id:  | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1  | ... | -1  |
 * index:                   | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512  | ... | 527  | 528 | ... | 543 |
 *
 * TP2, rank 1:
 *                          |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
 * corresponding token_id:  | 512 | 513 | 514 | ... | 1009 | -1  | ... | -1  | -1  | ... | -1  | -1  | ... | -1  |
 * index:                   | 0   | 1   | 2   | ... | 497  | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
 *
 * Parameters:
 *   input                    token ids; read element-wise through its raw data pointer
 *   org_vocab_start_index    base embeddings start
 *   org_vocab_end_index      base embeddings end
 *   num_org_vocab_padding    base embeddings padding
 *   added_vocab_start_index  LoRA embeddings start
 *   added_vocab_end_index    LoRA embeddings end
 *
 * Returns (masked_input, mask): masked_input has input's shape and dtype, mask is a bool tensor
 * of the same shape. Presumably the kernel mirrors vLLM's get_masked_input_and_mask semantics
 * (see link above) — confirm against get_masked_input_and_mask_impl.
 * NOTE(review): the kernel indexes `size` elements linearly from the raw pointers, so input is
 * assumed to be dense (no gaps between elements).
 */
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    at::Tensor &input,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index)
{
    // Input validation
    TORCH_CHECK(input.dim() >= 1, "input must have at least 1 dimension");
    TORCH_CHECK(org_vocab_start_index >= 0, "org_vocab_start_index must be non-negative");
    TORCH_CHECK(org_vocab_end_index >= org_vocab_start_index,
                "org_vocab_end_index must be greater than or equal to org_vocab_start_index");
    TORCH_CHECK(num_org_vocab_padding >= 0, "num_org_vocab_padding must be non-negative");
    TORCH_CHECK(added_vocab_start_index >= org_vocab_end_index,
                "added_vocab_start_index must be greater than or equal to org_vocab_end_index");
    TORCH_CHECK(added_vocab_end_index >= added_vocab_start_index,
                "added_vocab_end_index must be greater than or equal to added_vocab_start_index");

    // Get total number of elements
    int64_t size = input.numel();

    // Create output tensors. The mask is allocated as bool directly: the previous
    // empty_like(input).to(kBool) allocated an extra buffer and ran a conversion over
    // uninitialised memory whose values were never meaningful.
    at::Tensor masked_input = at::empty_like(input);
    at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));

    // Nothing to compute for an empty input; skip the kernel launch.
    if (size == 0) {
        return {masked_input, mask};
    }

    // Get data pointers
    void *input_ptr = input.data_ptr();
    void *masked_input_ptr = masked_input.data_ptr();
    void *mask_ptr = mask.data_ptr();

    // Get current stream
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

    // Create and configure OpCommand
    at_npu::native::OpCommand cmd;
    cmd.Name("get_masked_input_and_mask");
    cmd.SetCustomHandler([size, stream,
                          input_ptr, masked_input_ptr, mask_ptr,
                          org_vocab_start_index, org_vocab_end_index,
                          num_org_vocab_padding, added_vocab_start_index,
                          added_vocab_end_index]() -> int {
        // NOTE(review): the vector-core count is always queried for device 0, not the current device.
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        // Elements handled per vector core (ceil division).
        uint32_t loop_cnt = (size + aiv_num - 1) / aiv_num;

        // Call implementation
        get_masked_input_and_mask_impl(
            stream,
            input_ptr,
            masked_input_ptr,
            mask_ptr,
            org_vocab_start_index,
            org_vocab_end_index,
            num_org_vocab_padding,
            added_vocab_start_index,
            added_vocab_end_index,
            size,
            loop_cnt,
            aiv_num);

        return 0;
    });
    cmd.Run();
    return {masked_input, mask};
}
|
2025-07-29 19:27:50 +08:00
|
|
|
|
|
|
|
|
// LoRA "shrink" with one adapter index per token (batched gather matrix-vector product),
// launched as the custom `bgmv_shrink` Ascend kernel.
//
// x       : [batch_size, hidden_in], fp16 or bf16.
// weight  : [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in].
// indices : [batch_size], adapter index per token.
// y       : [batch_size, hidden_out]; hidden_out (forwarded as lora_rank) must be < hidden_in.
// scale   : forwarded to the kernel as float.
//
// NOTE(review): whether y is overwritten or accumulated into, and how out-of-range indices are
// handled, is decided by bgmv_shrink_impl and is not visible here.
void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, double scale)
{
    at::ScalarType scalar_type = x.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be same");
    TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");

    int batch_size = x.size(0);
    // An empty batch has no rows to produce. Skip the launch instead of asking the kernel to
    // split zero tokens across the vector cores (num_tokens_per_core would be 0).
    if (batch_size == 0) {
        return;
    }

    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* indices_ptr = indices.data_ptr();
    int indices_size = indices.size(0);
    void* y_ptr = y.data_ptr();
    int input_hidden_token = x.size(1);
    uint32_t lora_rank = y.size(1);
    float scale_f = static_cast<float>(scale);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_shrink");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, batch_size,
                          input_hidden_token, lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        // NOTE(review): the vector-core count is always queried for device 0, not the current device.
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        TORCH_CHECK(aiv_num > 0, "vector core number should be positive");
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // The condition must be the expression itself: a quoted string literal is always truthy.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, batch_size,
                         num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
        return 0;
    });
    cmd.Run();
}
|
|
|
|
|
|
|
|
|
|
// LoRA "expand" with one adapter index per token, launched as the custom `bgmv_expand` kernel.
//
// x       : [batch_size, lora_rank]; lora_rank must not exceed slice_size.
// weight  : [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in].
// indices : [batch_size], adapter index per token.
// y       : [batch_size, output_full_dim], fp16 or bf16; columns
//           [slice_offset, slice_offset + slice_size) are the target slice.
//
// Returns y_out, which shares storage with y (both pointers are handed to the kernel).
// NOTE(review): whether the slice is overwritten or accumulated into is decided by
// bgmv_expand_impl and is not visible here.
at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
                       int64_t slice_offset, int64_t slice_size)
{
    at::ScalarType scalar_type = y.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be same");
    TORCH_CHECK(x.size(1) <= slice_size, "hidden in should not be greater than slice_size");
    TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
    TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
                "slice_size + slice_offset should not exceed the second dimension of y");

    at::Tensor y_out = y;
    int batch_size = x.size(0);
    // An empty batch has nothing to expand; return y unchanged instead of launching a kernel
    // with zero tokens per core.
    if (batch_size == 0) {
        return y_out;
    }

    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* indices_ptr = indices.data_ptr();
    int indices_size = indices.size(0);
    void* y_ptr = y.data_ptr();
    void* y_out_ptr = y_out.data_ptr();
    int lora_rank = x.size(1);
    int output_full_dim = y.size(1);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_expand");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr,
                          batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        // NOTE(review): the vector-core count is always queried for device 0, not the current device.
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        TORCH_CHECK(aiv_num > 0, "vector core number should be positive");
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // The condition must be the expression itself: a quoted string literal is always truthy.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr, batch_size,
                         num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
    cmd.Run();
    return y_out;
}
|
2025-08-19 09:09:11 +08:00
|
|
|
|
|
|
|
|
// LoRA "shrink" for segmented batches, launched as the custom `sgmv_shrink` kernel.
//
// x            : [batch_size, hidden_in], fp16 or bf16.
// weight       : [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in].
// lora_indices : adapter index per segment; seq_len: segment lengths. Presumably one entry per
//                sequence — neither tensor's shape is validated here (TODO confirm).
// y            : [batch_size, hidden_out]; hidden_out (forwarded as lora_rank) must be < hidden_in.
// scale        : forwarded to the kernel as float.
void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
                 at::Tensor &y, double scale)
{
    at::ScalarType scalar_type = x.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");

    int batch_size = x.size(0);
    // An empty batch has no rows to produce. Skip the launch instead of asking the kernel to
    // split zero tokens across the vector cores (num_tokens_per_core would be 0).
    if (batch_size == 0) {
        return;
    }

    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* lora_indices_ptr = lora_indices.data_ptr();
    void* seq_len_ptr = seq_len.data_ptr();
    int lora_indices_size = lora_indices.size(0);
    int seq_len_size = seq_len.size(0);
    void* y_ptr = y.data_ptr();
    int input_hidden_token = x.size(1);
    uint32_t lora_rank = y.size(1);
    float scale_f = static_cast<float>(scale);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_shrink");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
                          seq_len_ptr, seq_len_size, y_ptr,
                          batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        // NOTE(review): the vector-core count is always queried for device 0, not the current device.
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        TORCH_CHECK(aiv_num > 0, "vector core number should be positive");
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // The condition must be the expression itself: a quoted string literal is always truthy.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr,
                         seq_len_size, y_ptr, batch_size,
                         num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
        return 0;
    });
    cmd.Run();
}
|
|
|
|
|
|
|
|
|
|
// LoRA "expand" for segmented batches, launched as the custom `sgmv_expand` kernel.
//
// x            : [batch_size, lora_rank]; lora_rank must not exceed slice_size.
// weight       : [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in].
// lora_indices : adapter index per segment; seq_len: segment lengths. Presumably one entry per
//                sequence — neither tensor's shape is validated here (TODO confirm).
// y            : [batch_size, output_full_dim], fp16 or bf16; columns
//                [slice_offset, slice_offset + slice_size) are the target slice.
//
// Returns y_out, which shares storage with y (both pointers are handed to the kernel).
at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
                       at::Tensor &y, int64_t slice_offset, int64_t slice_size)
{
    at::ScalarType scalar_type = y.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(x.size(1) <= slice_size, "hidden in should not be greater than slice_size");
    TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
    TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
                "slice_size + slice_offset should not exceed the second dimension of y");

    at::Tensor y_out = y;
    int batch_size = x.size(0);
    // An empty batch has nothing to expand; return y unchanged instead of launching a kernel
    // with zero tokens per core.
    if (batch_size == 0) {
        return y_out;
    }

    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* lora_indices_ptr = lora_indices.data_ptr();
    void* seq_len_ptr = seq_len.data_ptr();
    int lora_indices_size = lora_indices.size(0);
    int seq_len_size = seq_len.size(0);
    void* y_ptr = y.data_ptr();
    void* y_out_ptr = y_out.data_ptr();
    int lora_rank = x.size(1);
    int output_full_dim = y.size(1);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_expand");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr,
                          seq_len_size, y_ptr, y_out_ptr,
                          batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        // NOTE(review): the vector-core count is always queried for device 0, not the current device.
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        TORCH_CHECK(aiv_num > 0, "vector core number should be positive");
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // The condition must be the expression itself: a quoted string literal is always truthy.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr,
                         seq_len_size, y_ptr, y_out_ptr,
                         batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
    cmd.Run();
    return y_out;
}
|
2025-12-08 19:22:14 +08:00
|
|
|
|
|
|
|
|
// Host wrapper for the `batch_matmul_transpose` kernel (op registration notes it follows
// sgl-kernel-npu).
//
// bmm_trans::batch_matmul_transpose_tiling computes a tiling tensor and a launch block count from
// the three operands and the optional format/quant modes. The kernel is then launched on the
// current NPU stream with raw data pointers to tensor_a, tensor_b and tensor_c (tensor_c is taken
// by non-const reference, presumably as the output buffer — confirm against the kernel).
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                            c10::optional<c10::string_view> format_mode,
                            c10::optional<c10::string_view> quant_mode)
{
    auto [tiling, tiling_block_dim] =
        bmm_trans::batch_matmul_transpose_tiling(tensor_a, tensor_b, tensor_c, format_mode, quant_mode);
    // Plain copy so the launch lambda captures an ordinary variable, not a structured binding.
    auto block_dim = tiling_block_dim;

    void *lhs_ptr = tensor_a.data_ptr();
    void *rhs_ptr = tensor_b.data_ptr();
    void *out_ptr = tensor_c.data_ptr();
    void *tiling_ptr = tiling.data_ptr();

    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("batch_matmul_transpose");
    cmd.SetCustomHandler([=]() -> int {
        batch_matmul_transpose_impl(stream, lhs_ptr, rhs_ptr, out_ptr, tiling_ptr, block_dim);
        return 0;
    });
    cmd.Run();
}
|
|
|
|
|
|
2025-04-03 14:52:34 +08:00
|
|
|
} // namespace vllm_ascend
|
|
|
|
|
|
2025-09-13 11:58:52 +08:00
|
|
|
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
{
    // vLLM-Ascend custom ops.
    //
    // register_npu_op declares an operator schema on `ops` and then binds its NPU (PrivateUse1)
    // implementation: one def() followed by one impl(), issued in the same order as the calls below.
    auto register_npu_op = [&ops](const char *op_name, const char *op_schema, auto op_kernel) {
        ops.def(op_schema);
        ops.impl(op_name, torch::kPrivateUse1, op_kernel);
    };

    // Added together with aclgraph (NPU graph capture) support; presumably returns a non-owning
    // reference to `input` — confirm against its definition in ops.h.
    register_npu_op("weak_ref_tensor", "weak_ref_tensor(Tensor input) -> Tensor",
                    &vllm_ascend::weak_ref_tensor);

    // Rotary embedding
    // Apply GPT-NeoX style rotary embedding to query and key.
    register_npu_op("rotary_embedding",
                    "rotary_embedding(Tensor positions, Tensor! query,"
                    " Tensor! key, int head_size,"
                    " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)",
                    &vllm_ascend::rotary_embedding);

    // Vocab-parallel embedding input masking.
    register_npu_op("get_masked_input_and_mask",
                    "get_masked_input_and_mask(Tensor input, "
                    " int org_vocab_start_index, "
                    " int org_vocab_end_index, "
                    " int num_org_vocab_padding, "
                    " int added_vocab_start_index, "
                    " int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)",
                    &vllm_ascend::get_masked_input_and_mask);

    // LoRA kernels (per-token adapter index).
    register_npu_op("bgmv_shrink",
                    "bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()",
                    &vllm_ascend::bgmv_shrink);
    register_npu_op("bgmv_expand",
                    "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
                    " int slice_offset, int slice_size) -> Tensor",
                    &vllm_ascend::bgmv_expand);

    // LoRA kernels (segmented batches).
    register_npu_op("sgmv_shrink",
                    "sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()",
                    &vllm_ascend::sgmv_shrink);
    register_npu_op("sgmv_expand",
                    "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
                    " int slice_offset, int slice_size) -> Tensor",
                    &vllm_ascend::sgmv_expand);

    // Fused MLA preprocessing.
    register_npu_op("mla_preprocess",
                    "mla_preprocess(Tensor hiddenState, Tensor wdqkv,"
                    " Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
                    " Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
                    " Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
                    " Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
                    " Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
                    " str? quant_mode, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
                    " Tensor! kv_cache_out1) -> (Tensor q_out0, Tensor kv_cache_out0,"
                    " Tensor q_out1, Tensor kv_cache_out1)",
                    &vllm_ascend::mla_preprocess);

    // batch_matmul ops refer to sgl-kernel-npu
    register_npu_op("batch_matmul_transpose",
                    "batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()",
                    &vllm_ascend::batch_matmul_transpose);
}
|