[MOE] commit GMM custom operator (#7010)
### What this PR does / why we need it?
Adds a GMM (grouped matrix multiply) custom operator optimized for small-batch scenarios.
### How was this patch tested?
Submitted the GMM custom operator so it can subsequently be integrated into the MoE
process.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: chenxi-hh <chen464822955@163.com>
Signed-off-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
This commit is contained in:
@@ -597,6 +597,38 @@ void transpose_kv_cache_by_block(
|
||||
|
||||
}
|
||||
|
||||
// It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
|
||||
std::vector<at::Tensor> moe_grouped_matmul(
|
||||
at::Tensor x,
|
||||
at::Tensor weight,
|
||||
const at::Tensor& group_list,
|
||||
int64_t split_item,
|
||||
int64_t group_type,
|
||||
int64_t group_list_type
|
||||
)
|
||||
{
|
||||
bool transpose_weight = false;
|
||||
bool weight_nz = true;
|
||||
|
||||
at::TensorList x_list = at::TensorList(x);
|
||||
at::TensorList weight_list = at::TensorList(weight);
|
||||
std::vector<at::Tensor> y;
|
||||
c10::TensorOptions options = x_list[0].options().dtype(x[0].scalar_type());
|
||||
auto m = x_list[0].sizes()[0];
|
||||
auto n = weight_list[0].sizes()[1];
|
||||
if (!transpose_weight) {
|
||||
n = weight_list[0].sizes()[2];
|
||||
}
|
||||
at::Tensor y_0 = at::empty(at::IntArrayRef{m, n}, options);
|
||||
y.emplace_back(y_0);
|
||||
at::TensorList result = at::TensorList(y);
|
||||
|
||||
EXEC_NPU_CMD(aclnnMoeGroupedMatmulWeightNz,
|
||||
x_list, weight_list, group_list, transpose_weight, result);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -779,4 +811,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
|
||||
);
|
||||
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
|
||||
// Schema for the MoE grouped-matmul custom op. The split_item, group_type and
// group_list_type arguments are part of the public signature but are currently
// unused by the Ascend implementation (see vllm_ascend::moe_grouped_matmul).
ops.def(
    "moe_grouped_matmul("
    "Tensor x,"
    "Tensor weight,"
    "Tensor group_list,"
    "int split_item,"
    "int group_type,"
    "int group_list_type)"
    "-> Tensor[]"
);
// Bind the schema to the Ascend (PrivateUse1 dispatch key) implementation.
ops.impl("moe_grouped_matmul", torch::kPrivateUse1,&vllm_ascend::moe_grouped_matmul);
}
|
||||
|
||||
Reference in New Issue
Block a user