[Cherry-pick]bmm_transpose to v011dev (#3995)

### What this PR does / why we need it?
Add a custom op to accelerate the DeepSeek model. The fused op combines
bmm and transpose together and is applied to the MLA module.
Cherry-picked from commit c68ddc11ce53334fc9a17bad58342148cbf14e86

### Does this PR introduce _any_ user-facing change?
No

---------

Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
Wang Yixuan
2025-12-08 19:22:14 +08:00
committed by GitHub
parent 6391f0625f
commit d412565ec9
15 changed files with 1736 additions and 13 deletions

View File

@@ -24,6 +24,7 @@
#include "ops.h"
#include "utils.h"
#include "mla_preprocess/op_host/mla_preprocess.h"
#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"
namespace vllm_ascend {
@@ -458,6 +459,39 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
cmd.Run();
return y_out;
}
// Fused batch-matmul + transpose dispatch: computes tiling on the host, then
// launches the custom kernel on the current NPU stream via an OpCommand handler.
// format_mode / quant_mode are optional string selectors forwarded to the
// tiling routine (semantics defined by bmm_trans; see its header).
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                            c10::optional<c10::string_view> format_mode,
                            c10::optional<c10::string_view> quant_mode)
{
    // Tiling metadata tensor plus the number of AI-core blocks for this launch.
    auto [tiling, num_blocks] = bmm_trans::batch_matmul_transpose_tiling(
        tensor_a, tensor_b, tensor_c, format_mode, quant_mode);

    // Raw device addresses handed to the kernel entry point.
    void *dev_a = tensor_a.data_ptr();
    void *dev_b = tensor_b.data_ptr();
    void *dev_c = tensor_c.data_ptr();
    void *dev_tiling = tiling.data_ptr();

    aclrtStream npu_stream = c10_npu::getCurrentNPUStream().stream();

    // Enqueue the kernel launch through an OpCommand custom handler so it is
    // ordered with the rest of the stream's work.
    at_npu::native::OpCommand cmd;
    cmd.Name("batch_matmul_transpose");
    cmd.SetCustomHandler(
        [npu_stream, dev_a, dev_b, dev_c, dev_tiling, num_blocks]() -> int {
            batch_matmul_transpose_impl(npu_stream, dev_a, dev_b, dev_c, dev_tiling, num_blocks);
            return 0;
        });
    cmd.Run();
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -511,4 +545,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
" Tensor q_out1, Tensor kv_cache_out1)"
);
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
//batch_matmul ops refer to sgl-kernel-npu
ops.def(
"batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()");
ops.impl("batch_matmul_transpose", torch::kPrivateUse1, &vllm_ascend::batch_matmul_transpose);
}