[OP] add custom op aclnnMoeInitRoutingCustom (#5251)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->

This pull request introduces a new custom operator
`aclnnMoeInitRoutingCustom` for Mixture-of-Experts models.
It can be replaced by `aclnnMoeInitRoutingV3` once CANN 8.5 becomes
available.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
No.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

---------

Signed-off-by: jiazhengyi <jiazhengyi@huawei.com>
Signed-off-by: Chenxi Qian <chenxi.qian.cq@outlook.com>
Co-authored-by: jiazhengyi <jiazhengyi@huawei.com>
Co-authored-by: Chenxi Qian <chenxi.qian.cq@outlook.com>
This commit is contained in:
jiazhengyi
2025-12-29 19:29:40 +08:00
committed by GitHub
parent 92353c0643
commit d5f72835e6
40 changed files with 10815 additions and 1 deletions

View File

@@ -1118,6 +1118,106 @@ at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, cons
return combined_x;
}
// Host-side wrapper for the custom MoE token-routing op (aclnnMoeInitRoutingCustom).
// Validates shapes, allocates the output tensors, and launches the NPU kernel.
// NOTE(review): intended to be replaced by aclnnMoeInitRoutingV3 once CANN 8.5
// is available (per the PR description).
//
// Inputs:
//   x                 - token features, must be 2-D [bs, h]
//   expert_idx        - top-k expert ids per token, must be 2-D [bs, k]
//   scale / offset    - optional quantization parameters, forwarded to the kernel
//   active_expert_range - [start, end) expert window; empty means [0, expert_num]
// Returns (expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum,
// expanded_scale). expert_tokens_count_or_cumsum is left as an undefined tensor
// when expert_tokens_num_type is outside [0, 2] (original behavior, preserved).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_moe_init_routing_custom(
    const at::Tensor &x, const at::Tensor &expert_idx,
    const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
    int64_t expert_capacity, int64_t expert_num, int64_t drop_pad_mode, int64_t expert_tokens_num_type,
    bool expert_tokens_num_flag, int64_t quant_mode, at::IntArrayRef active_expert_range, int64_t row_idx_type)
{
    constexpr int64_t DIM_X = 2;                      // x must be [bs, h]
    constexpr int64_t DIM_EXPERT_IDX = 2;             // expert_idx must be [bs, k]
    constexpr int64_t LENGTH_ACTIVE_EXPERT_RANGE = 2; // [start, end)
    constexpr int64_t QUANT_MODE_UNQUANT = -1;        // output keeps x's dtype
    constexpr int64_t CUMSUM = 0;                     // expert_tokens_num_type values
    constexpr int64_t COUNT = 1;
    constexpr int64_t KEY_VALUE = 2;

    // Default to the full expert range. The backing storage must outlive the
    // IntArrayRef view: the original code bound active_expert_range to a
    // temporary initializer list, leaving a dangling view that was later
    // indexed and passed to the kernel. A function-scope array fixes that.
    int64_t default_range[LENGTH_ACTIVE_EXPERT_RANGE] = {0, expert_num};
    if (active_expert_range.empty()) {
        active_expert_range = at::IntArrayRef(default_range, LENGTH_ACTIVE_EXPERT_RANGE);
    }

    int64_t x_dim = x.dim();
    TORCH_CHECK(x_dim == DIM_X, "The x should be ", DIM_X,
                "-Dimension, current is ", x_dim, "-Dimension.");
    int64_t expert_idx_dim = expert_idx.dim();
    TORCH_CHECK(expert_idx_dim == DIM_EXPERT_IDX, "The expert_idx should be ", DIM_EXPERT_IDX,
                "-Dimension, current is ", expert_idx_dim, "-Dimension.");
    int64_t active_expert_range_length = active_expert_range.size();
    // Fixed: the original message printed expert_idx_dim here instead of the
    // value actually being checked.
    TORCH_CHECK(active_expert_range_length == LENGTH_ACTIVE_EXPERT_RANGE,
                "The active_expert_range should have ", LENGTH_ACTIVE_EXPERT_RANGE,
                " elements, current is ", active_expert_range_length, ".");

    int64_t expert_length = active_expert_range[1] - active_expert_range[0];
    auto x_size = x.sizes();
    auto expert_idx_size = expert_idx.sizes();
    // int64_t (not int) so bs * k below cannot narrow or overflow; the original
    // cast only one of the three bs * k sites.
    int64_t bs = x_size[0];
    int64_t h = x_size[1];
    int64_t k = expert_idx_size[1];

    // Quantized modes emit int8 (kChar) outputs; unquantized keeps x's dtype.
    const auto expanded_options =
        quant_mode == QUANT_MODE_UNQUANT ? x.options() : x.options().dtype(at::kChar);

    int64_t expanded_scale_len = 0;
    at::Tensor expanded_x;
    if (drop_pad_mode == 1) { // Drop/Pad: fixed [expert_num, expert_capacity, h] layout
        expanded_x = at::empty({expert_num, expert_capacity, h}, expanded_options);
        expanded_scale_len = expert_num * expert_capacity;
    } else if (active_num > 0) { // Active: cap the number of routed tokens
        int64_t num_out_tokens = std::min(bs * k, active_num);
        expanded_x = at::empty({num_out_tokens, h}, expanded_options);
        expanded_scale_len = num_out_tokens;
    } else { // Dropless: keep all bs * k routed tokens
        expanded_x = at::empty({bs * k, h}, expanded_options);
        expanded_scale_len = bs * k;
    }

    at::Tensor expanded_row_idx = at::empty({bs * k}, expert_idx.options());
    at::Tensor expert_tokens_count_or_cumsum;
    if (expert_tokens_num_type >= CUMSUM && expert_tokens_num_type <= COUNT) {
        // count/cumsum layout: one int64 per expert in [start, end)
        expert_tokens_count_or_cumsum = at::empty({expert_length}, x.options().dtype(at::kLong));
    } else if (expert_tokens_num_type == KEY_VALUE) {
        // NOTE(review): original comment says "key_value in [2, end-start]" but
        // the allocation is [expert_num, 2] — confirm against the kernel spec.
        expert_tokens_count_or_cumsum = at::empty({expert_num, 2}, x.options().dtype(at::kLong));
    }
    at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat));

    EXEC_NPU_CMD(aclnnMoeInitRoutingCustom,
                 x,
                 expert_idx,
                 scale,
                 offset,
                 active_num,
                 expert_capacity,
                 expert_num,
                 drop_pad_mode,
                 expert_tokens_num_type,
                 expert_tokens_num_flag,
                 quant_mode,
                 active_expert_range,
                 row_idx_type,
                 expanded_x,
                 expanded_row_idx,
                 expert_tokens_count_or_cumsum,
                 expanded_scale);
    // make_tuple + move avoids the extra refcount bumps std::tie's
    // reference-to-value conversion incurred.
    return std::make_tuple(std::move(expanded_x), std::move(expanded_row_idx),
                           std::move(expert_tokens_count_or_cumsum), std::move(expanded_scale));
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -1257,4 +1357,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"num_ranks) -> Tensor");
ops.impl("combine_prefill", torch::kPrivateUse1,
&vllm_ascend::combine_prefill);
// Register the schema and NPU (PrivateUse1) implementation for the custom MoE
// init-routing op. Keep this schema string in sync with the C++ signature of
// vllm_ascend::npu_moe_init_routing_custom: arguments after `*` are
// keyword-only; active_expert_range defaults to [] and is expanded to
// [0, expert_num] inside the wrapper.
// NOTE(review): the schema declares `int[2] active_expert_range=[]` — an empty
// default for a fixed-length int[2]; confirm the schema parser accepts this.
ops.def(
"npu_moe_init_routing_custom(Tensor x, Tensor expert_idx, *, Tensor? scale=None, Tensor? offset=None, int active_num=-1, "
" int expert_capacity=-1, int expert_num=-1, int drop_pad_mode=0, int expert_tokens_num_type=0, "
" bool expert_tokens_num_flag=False, int quant_mode=0, int[2] active_expert_range=[], "
" int row_idx_type=0) -> (Tensor, Tensor, Tensor, Tensor)"
);
ops.impl("npu_moe_init_routing_custom", torch::kPrivateUse1, &vllm_ascend::npu_moe_init_routing_custom);
}