[Cherry-pick]bmm_transpose to v011dev (#3995)
### What this PR does / why we need it? Add a custom op to acclerater the deepseek model. The fusion ops combine the bmm and transpose together, which is applied to mla module. Cherry-pick from this commtid c68ddc11ce53334fc9a17bad58342148cbf14e86 ### Does this PR introduce _any_ user-facing change? No --------- Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
@@ -24,6 +24,7 @@
|
||||
#include "ops.h"
|
||||
#include "utils.h"
|
||||
#include "mla_preprocess/op_host/mla_preprocess.h"
|
||||
#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h"
|
||||
|
||||
namespace vllm_ascend {
|
||||
|
||||
@@ -458,6 +459,39 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
|
||||
cmd.Run();
|
||||
return y_out;
|
||||
}
|
||||
|
||||
void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
|
||||
c10::optional<c10::string_view> format_mode,
|
||||
c10::optional<c10::string_view> quant_mode)
|
||||
{
|
||||
auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling(
|
||||
tensor_a,
|
||||
tensor_b,
|
||||
tensor_c,
|
||||
format_mode,
|
||||
quant_mode
|
||||
);
|
||||
|
||||
void *gm_a = tensor_a.data_ptr();
|
||||
void *gm_b = tensor_b.data_ptr();
|
||||
void *gm_c = tensor_c.data_ptr();
|
||||
void *gm_tiling_data = tiling_tensor.data_ptr();
|
||||
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("batch_matmul_transpose");
|
||||
|
||||
cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data,
|
||||
block_dim]() -> int {
|
||||
batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data,
|
||||
block_dim);
|
||||
return 0;
|
||||
});
|
||||
cmd.Run();
|
||||
return;
|
||||
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -511,4 +545,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" Tensor q_out1, Tensor kv_cache_out1)"
|
||||
);
|
||||
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
|
||||
//batch_matmul ops refer to sgl-kernel-npu
|
||||
ops.def(
|
||||
"batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()");
|
||||
ops.impl("batch_matmul_transpose", torch::kPrivateUse1, &vllm_ascend::batch_matmul_transpose);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user