[OPS] add bmm_transpose ops (#3990)

### What this PR does / why we need it?
Add a new fused op to custom_op that combines torch.bmm() and
transpose to achieve better performance. This op is used in mla_v1 to
replace the separate bmm and transpose calls.
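
For context, a minimal sketch of the pattern being fused and how the new op could replace it (shapes, dtype, device, and the transposed dimensions are illustrative assumptions; only the op name and out-parameter signature come from this PR):

```python
import torch

# Illustrative shapes only; the real call sites are in mla_v1.
a = torch.randn(8, 128, 64, device="npu", dtype=torch.float16)   # (batch, m, k)
b = torch.randn(8, 64, 256, device="npu", dtype=torch.float16)   # (batch, k, n)

# Unfused: two kernels plus an extra memory pass for the transpose.
out_ref = torch.bmm(a, b).transpose(0, 1).contiguous()

# Fused: one kernel writes the transposed layout directly into a
# preallocated output tensor (assumed here to swap dims 0 and 1).
out = torch.empty(128, 8, 256, device="npu", dtype=torch.float16)
torch.ops._C_ascend.batch_matmul_transpose(a, b, out)
```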

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?


- vLLM version: v0.11.2

---------

Signed-off-by: hust17yixuan <303660421@qq.com>

@@ -27,6 +27,7 @@
#include "ops.h"
#include "utils.h"
#include "mla_preprocess/op_host/mla_preprocess.h"
#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"
#include "aclnn_torch_adapter/op_api_common.h"
#include <c10/core/Device.h>
@@ -587,6 +588,38 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
}
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                            c10::optional<c10::string_view> format_mode,
                            c10::optional<c10::string_view> quant_mode)
{
    // Compute tiling metadata and the AI-core block count on the host.
    auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling(
        tensor_a,
        tensor_b,
        tensor_c,
        format_mode,
        quant_mode
    );
    void *gm_a = tensor_a.data_ptr();
    void *gm_b = tensor_b.data_ptr();
    void *gm_c = tensor_c.data_ptr();
    void *gm_tiling_data = tiling_tensor.data_ptr();
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("batch_matmul_transpose");
    // Launch the kernel on the current NPU stream via a custom handler.
    cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data,
                          block_dim]() -> int {
        batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data,
                                    block_dim);
        return 0;
    });
    cmd.Run();
    return;
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -641,6 +674,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
);
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
// batch_matmul_transpose op, adapted from sgl-kernel-npu
ops.def(
"batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()");
ops.impl("batch_matmul_transpose", torch::kPrivateUse1, &vllm_ascend::batch_matmul_transpose);
ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()");
ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);