[kernel] add AscendC op: lightning_indexer and sparse_flash_attention (#4625)
### What this PR does / why we need it? Provide high-performance AscendC operators lightning_indexer and sparse_flash_attention to boost the execution performance of the DeepSeek v3.2 model. Meanwhile, adapt the two AscendC operators to vllm-ascend framework. ### Does this PR introduce _any_ user-facing change? No (only underlying operator optimizations, with no user-facing changes) ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: MingYang119 <songmingyang@huawei.com>
This commit is contained in:
@@ -620,6 +620,103 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
||||
|
||||
}
|
||||
|
||||
at::Tensor npu_lightning_indexer(
|
||||
const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
|
||||
const c10::optional<at::Tensor> &actual_seq_lengths_query,
|
||||
const c10::optional<at::Tensor> &actual_seq_lengths_key,
|
||||
const c10::optional<at::Tensor> &block_table, c10::string_view layout_query,
|
||||
c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
|
||||
{
|
||||
// npu tensor max size
|
||||
constexpr int32_t SIZE = 8;
|
||||
constexpr int32_t DIM_0 = 0;
|
||||
constexpr int32_t DIM_1 = 1;
|
||||
constexpr int32_t DIM_2 = 2;
|
||||
constexpr int32_t DIM_3 = 3;
|
||||
|
||||
TORCH_CHECK(query.numel() > 0, "Query is empty.");
|
||||
TORCH_CHECK(key.numel() > 0, "Key is empty.");
|
||||
TORCH_CHECK(weights.numel() > 0, "Weights is empty.");
|
||||
for (size_t i = 0; i < query.sizes().size(); i++) {
|
||||
TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
|
||||
"than 0, but shape[", i, "] is ", query.size(i));
|
||||
}
|
||||
TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);
|
||||
|
||||
at::SmallVector<int64_t, SIZE> output_size;
|
||||
std::string query_layout_str = std::string(layout_query);
|
||||
std::string key_layout_str = std::string(layout_key);
|
||||
if (query_layout_str == "BSND") {
|
||||
output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count};
|
||||
} else {
|
||||
int n_dim_index = 0;
|
||||
n_dim_index = (key_layout_str == "TND") ? DIM_1 : DIM_2;
|
||||
output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count};
|
||||
}
|
||||
at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt));
|
||||
// convert str
|
||||
char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
|
||||
char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
|
||||
EXEC_NPU_CMD(
|
||||
aclnnLightningIndexer,
|
||||
query,
|
||||
key,
|
||||
weights,
|
||||
actual_seq_lengths_query,
|
||||
actual_seq_lengths_key,
|
||||
block_table,
|
||||
query_layout_ptr,
|
||||
key_layout_ptr,
|
||||
sparse_count,
|
||||
sparse_mode,
|
||||
lightning_indexer_output);
|
||||
return lightning_indexer_output;
|
||||
}
|
||||
|
||||
at::Tensor npu_sparse_flash_attention(
|
||||
const at::Tensor &query, const at::Tensor &key, const at::Tensor &value,
|
||||
const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size,
|
||||
const c10::optional<at::Tensor> &block_table,
|
||||
const c10::optional<at::Tensor> &actual_seq_lengths_query,
|
||||
const c10::optional<at::Tensor> &actual_seq_lengths_kv,
|
||||
const c10::optional<at::Tensor> &query_rope,
|
||||
const c10::optional<at::Tensor> &key_rope, c10::string_view layout_query,
|
||||
c10::string_view layout_kv,
|
||||
int64_t sparse_mode)
|
||||
{
|
||||
std::string layout_query_str = std::string(layout_query);
|
||||
std::string layout_kv_str = std::string(layout_kv);
|
||||
|
||||
for (size_t i = 0; i < query.sizes().size(); i++) {
|
||||
TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
|
||||
"than 0, but shape[", i, "] is ", query.size(i));
|
||||
}
|
||||
// construct the output tensor
|
||||
at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype()));
|
||||
// convert str
|
||||
char *layout_query_ptr = const_cast<char *>(layout_query_str.c_str());
|
||||
char *layout_kv_ptr = const_cast<char *>(layout_kv_str.c_str());
|
||||
|
||||
EXEC_NPU_CMD(
|
||||
aclnnSparseFlashAttention,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
sparse_indices,
|
||||
block_table,
|
||||
actual_seq_lengths_query,
|
||||
actual_seq_lengths_kv,
|
||||
query_rope,
|
||||
key_rope,
|
||||
scale_value,
|
||||
sparse_block_size,
|
||||
layout_query_ptr,
|
||||
layout_kv_ptr,
|
||||
sparse_mode,
|
||||
output);
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -695,4 +792,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" (Tensor output, Tensor output_scale, Tensor output_offset)"
|
||||
);
|
||||
ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list);
|
||||
|
||||
ops.def(
|
||||
"npu_lightning_indexer(Tensor query, Tensor key, Tensor weights, *,"
|
||||
" Tensor? actual_seq_lengths_query=None, Tensor? actual_seq_lengths_key=None,"
|
||||
" Tensor? block_table=None, str layout_query='BSND', str layout_key='BSND',"
|
||||
" int sparse_count=2048, int sparse_mode=3) -> Tensor"
|
||||
);
|
||||
ops.impl("npu_lightning_indexer", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer);
|
||||
|
||||
ops.def(
|
||||
"npu_sparse_flash_attention(Tensor query, Tensor key, Tensor value,"
|
||||
" Tensor sparse_indices, float scale_value, int sparse_block_size, *,"
|
||||
" Tensor? block_table=None, Tensor? actual_seq_lengths_query=None,"
|
||||
" Tensor? actual_seq_lengths_kv=None, Tensor? query_rope=None,"
|
||||
" Tensor? key_rope=None, str layout_query='BSND', str layout_kv='BSND',"
|
||||
" int sparse_mode=3) -> Tensor"
|
||||
);
|
||||
ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user