[OPS] add bmm_transpose ops (#3990)
### What this PR does / why we need it? Add a new fused op to custom_op that combines torch.bmm() and transpose to achieve better performance. This op is used in mla_v1 to replace the separate bmm and transpose calls. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.11.2 --------- Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
@@ -27,6 +27,7 @@
|
||||
#include "ops.h"
|
||||
#include "utils.h"
|
||||
#include "mla_preprocess/op_host/mla_preprocess.h"
|
||||
#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"
|
||||
#include "aclnn_torch_adapter/op_api_common.h"
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
@@ -587,6 +588,38 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> grouped_matmul_swiglu_quant_weigh
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(output, output_scale, output_offset);
|
||||
}
|
||||
|
||||
// Launches the batch_matmul_transpose custom kernel on the current NPU stream.
//
// Parameters:
//   tensor_a, tensor_b  tensors whose data pointers are handed to the kernel as
//                       gm_a / gm_b (presumably the two matmul operands -- confirm
//                       against the kernel implementation).
//   tensor_c            tensor whose data pointer is handed to the kernel as gm_c
//                       (presumably the output written in place -- confirm against
//                       the kernel implementation).
//   format_mode         optional mode string, forwarded unchanged to the tiling helper.
//   quant_mode          optional mode string, forwarded unchanged to the tiling helper.
//
// Returns nothing; the kernel launch is submitted through an OpCommand custom handler.
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                            c10::optional<c10::string_view> format_mode,
                            c10::optional<c10::string_view> quant_mode)
{
    // The tiling helper yields the tensor that holds the tiling data and the block dimension.
    auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling(
        tensor_a, tensor_b, tensor_c, format_mode, quant_mode);

    // A structured binding cannot be captured by a lambda before C++20 (GCC only
    // accepts it as an extension), so copy the value into an ordinary local first.
    const auto launch_block_dim = block_dim;

    // Raw buffer addresses are captured by value so the handler does not reference
    // the tensor objects themselves.
    void *gm_a = tensor_a.data_ptr();
    void *gm_b = tensor_b.data_ptr();
    void *gm_c = tensor_c.data_ptr();
    void *gm_tiling_data = tiling_tensor.data_ptr();

    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

    at_npu::native::OpCommand cmd;
    cmd.Name("batch_matmul_transpose");
    cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data, launch_block_dim]() -> int {
        batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data, launch_block_dim);
        return 0;
    });
    cmd.Run();
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -641,6 +674,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
);
|
||||
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
|
||||
|
||||
// batch_matmul_transpose op; the implementation refers to sgl-kernel-npu
|
||||
ops.def(
|
||||
"batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()");
|
||||
ops.impl("batch_matmul_transpose", torch::kPrivateUse1, &vllm_ascend::batch_matmul_transpose);
|
||||
|
||||
ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()");
|
||||
ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user