Files
xc-llm-ascend/tests/e2e/nightly/ops/test_moe_init_routing_custom.py
jiazhengyi d5f72835e6 [OP] add custom op aclnnMoeInitRoutingCustom (#5251)
<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->

This pull request introduces a new custom operator
`aclnnMoeInitRoutingCustom` for Mixture-of-Experts models.
It can be replaced by `aclnnMoeInitRoutingV3` once CANN 8.5 becomes
available.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
No.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

---------

Signed-off-by: jiazhengyi <jiazhengyi@huawei.com>
Signed-off-by: Chenxi Qian <chenxi.qian.cq@outlook.com>
Co-authored-by: jiazhengyi <jiazhengyi@huawei.com>
Co-authored-by: Chenxi Qian <chenxi.qian.cq@outlook.com>
2025-12-29 19:29:40 +08:00

350 lines
14 KiB
Python

import itertools
import random
import numpy as np
import torch
from vllm_ascend.utils import enable_custom_op

# Register the ascend custom ops with torch.ops before any test in this
# file resolves torch.ops._C_ascend.npu_moe_init_routing_custom.
enable_custom_op()
def adapter_capacity(sorted_row_idx, sorted_expert_idx, capacity):
    """Drop tokens that exceed an expert's capacity, in place.

    Walks the expert-sorted arrays and, once a run of entries for one
    expert grows past ``capacity``, overwrites the surplus entries in
    both ``sorted_expert_idx`` and ``sorted_row_idx`` with -1.
    """
    run_length = 0
    current_expert = sorted_expert_idx[0]
    for pos in range(len(sorted_expert_idx)):
        expert = sorted_expert_idx[pos]
        if current_expert != expert:
            # A new expert's run begins; restart the counter.
            run_length = 1
            current_expert = expert
        else:
            run_length += 1
        if run_length > capacity:
            sorted_expert_idx[pos] = -1
            sorted_row_idx[pos] = -1
def moe_init_routing_golden(x, expert_idx, scale, offset, active_num,
                            expert_capacity, expert_num, drop_pad_mode,
                            expert_tokens_num_type, expert_tokens_num_flag,
                            active_expert_range, quant_mode, row_idx_type):
    """CPU/numpy reference ("golden") model for aclnnMoeInitRoutingCustom.

    Gathers token rows in expert-sorted order, optionally drops/pads to a
    fixed per-expert capacity, computes per-expert token counts, and
    optionally quantizes the gathered activations to int8.

    Args:
        x: (num_rows, h) activation matrix (numpy array).
        expert_idx: (num_rows, k) expert assignments per token/top-k slot.
        scale, offset: optional quantization parameters (numpy arrays).
        active_num: cap on gathered rows in dropless mode (0 = take all).
        expert_capacity: per-expert row budget in drop/pad mode.
        expert_num: total number of experts (sizes drop/pad outputs).
        drop_pad_mode: 0 = dropless gather, 1 = drop/pad to capacity.
        expert_tokens_num_type: 0 = cumulative sum, 1 = plain counts,
            2 = (expert_id, count) pairs.
        expert_tokens_num_flag: whether to compute the counts output at all.
        active_expert_range: [start, end) of experts considered active
            (ignored in drop/pad mode, where the range is [0, expert_num)).
        quant_mode: -1 = no quantization, 0 = static int8, 1 = dynamic int8.
        row_idx_type: 0 = scatter map (src row -> dst slot),
            1 = gather map (dst slot -> src row).

    Returns:
        Tuple (expanded_x, expanded_row_idx, expert_tokens_count,
        expanded_scale); returns None early on invalid expert_num in
        drop/pad mode.
    """
    # Drop/pad mode needs a positive expert count to size its outputs.
    if drop_pad_mode == 1:
        if expert_num <= 0:
            print("expert num can not be 0")
            return
    # In drop/pad mode every expert is active; otherwise honor the range.
    expert_start = active_expert_range[0] if drop_pad_mode == 0 else 0
    expert_end = active_expert_range[1] if drop_pad_mode == 0 else expert_num
    num_rows = x.shape[0]
    h = x.shape[1]
    k = expert_idx.shape[-1]
    expert_idx_in = expert_idx.copy().reshape(-1)
    # Number of (token, expert) pairs that fall inside the active range.
    actual_expert_total_num: int = np.sum((expert_idx_in >= expert_start)
                                          & (expert_idx_in < expert_end))
    # Push below-range experts to the end of the stable sort by remapping
    # them to INT32_MAX (above-range ids already sort after in-range ones).
    expert_idx_in[(expert_idx_in
                   < expert_start)] = np.int32(np.iinfo(np.int32).max)
    sorted_expert_indices = np.argsort(expert_idx_in, axis=-1, kind="stable")
    sorted_expert_idx = expert_idx_in[sorted_expert_indices]
    if row_idx_type == 1:
        # Gather semantics: entry i holds the source flat index of output row i.
        expanded_row_idx = sorted_expert_indices[:actual_expert_total_num]
    else:
        # Scatter semantics: entry src holds its destination slot (-1 = dropped).
        expanded_row_idx = np.ones(num_rows * k).astype(np.int32) * -1
        tmp_indices = np.arange(actual_expert_total_num)
        expanded_row_idx[
            sorted_expert_indices[:actual_expert_total_num]] = tmp_indices
    if not expert_tokens_num_flag:
        # Counts not requested: placeholder output only.
        expert_tokens_count = torch.tensor([0])
    else:
        if drop_pad_mode == 0:
            if expert_tokens_num_type == 1:
                # Per-expert token counts over the active range, padded with
                # zeros for trailing experts that received no tokens.
                expert_tokens_count = np.bincount(
                    sorted_expert_idx[:actual_expert_total_num] - expert_start)
                expert_tokens_count = np.concatenate([
                    expert_tokens_count,
                    np.zeros((expert_end - expert_start) -
                             len(expert_tokens_count)).astype(np.int64)
                ])
            elif expert_tokens_num_type == 0:
                # Same counts as type 1, then converted to a running cumsum.
                expert_tokens_count = np.bincount(
                    sorted_expert_idx[:actual_expert_total_num] - expert_start)
                expert_tokens_count = np.concatenate([
                    expert_tokens_count,
                    np.zeros((expert_end - expert_start) -
                             len(expert_tokens_count)).astype(np.int64)
                ])
                expert_tokens_count = np.cumsum(expert_tokens_count)
            elif expert_tokens_num_type == 2:
                # (expert_id, count) pairs for experts that received tokens.
                expert_id, counts = np.unique(
                    sorted_expert_idx[:actual_expert_total_num],
                    return_counts=True)
                expert_tokens_count = np.column_stack((expert_id, counts))
                # NOTE(review): pads exactly one (0, 0) row when fewer than
                # expert_num experts appear — presumably matching the
                # kernel's output layout; confirm against the op's spec.
                if expert_tokens_count.shape[0] < expert_num:
                    expert_tokens_count = np.concatenate(
                        (expert_tokens_count, [
                            [0, 0],
                        ]), axis=0)
        else:
            # Drop/pad mode: plain per-expert counts over all experts.
            expert_tokens_count = np.bincount(
                sorted_expert_idx[:actual_expert_total_num] - expert_start)
            zeros_array = np.zeros(
                (expert_end - expert_start) - len(expert_tokens_count),
                dtype=np.int64)
            expert_tokens_count = np.concatenate(
                [expert_tokens_count, zeros_array])
        expert_tokens_count = expert_tokens_count.astype(np.int64)
    if drop_pad_mode == 0:
        # Dropless gather: take up to active_num in-range rows, sorted order.
        if active_num == 0:
            active_num = actual_expert_total_num
        else:
            active_num = min(active_num, actual_expert_total_num)
        expanded_scale = None
        # Flat index // k recovers the source token row from the (row, k) grid.
        expanded_x = x[sorted_expert_indices[:active_num] // k, :]
        if scale is not None and quant_mode == -1:
            expanded_scale = scale[sorted_expert_indices[:active_num] // k]
    else:
        # Mark rows beyond each expert's capacity with -1 (in place).
        adapter_capacity(sorted_expert_indices, sorted_expert_idx,
                         expert_capacity)
        # Destination table: a fixed range of capacity slots per expert.
        sort_row_tmp = np.full((expert_num * expert_capacity), -1, dtype=int)
        offset_tmp = 0
        lastExpertId = 0
        for i, val in enumerate(sorted_expert_indices):
            if val != -1:
                # Restart the slot offset when a new expert's run begins.
                if lastExpertId != sorted_expert_idx[i]:
                    offset_tmp = 0
                    lastExpertId = sorted_expert_idx[i]
                sort_row_tmp[sorted_expert_idx[i] * expert_capacity +
                             offset_tmp] = sorted_expert_indices[i]
                offset_tmp = offset_tmp + 1
        # Invert the mapping: source flat index -> destination slot.
        expanded_row_idx = np.full(sorted_expert_indices.shape, -1)
        for i, val in enumerate(sort_row_tmp):
            if val != -1:
                expanded_row_idx[val] = i
        # Mask is 1 for padded (empty) slots, 0 for occupied ones.
        expanded_x_mask = np.full((expert_num * expert_capacity, h),
                                  1,
                                  dtype=int)
        expanded_x = np.full((expert_num * expert_capacity, h),
                             0,
                             dtype=x.dtype)
        for i, val in enumerate(sort_row_tmp):
            if val != -1:
                expanded_x[i] = x[val // k]
                expanded_x_mask[i] = np.full((h, ), 0, dtype=int)
    if quant_mode == -1:
        # No quantization: activations and row map pass through unchanged.
        expanded_x = expanded_x
        expanded_row_idx = expanded_row_idx
        if scale is not None and drop_pad_mode == 1:
            # Gather per-token scales into the capacity-padded layout.
            expanded_scale = np.full((expert_num * expert_capacity, ),
                                     0,
                                     dtype=scale.dtype)
            for i, val in enumerate(sort_row_tmp):
                if val != -1:
                    expanded_scale[i] = scale[val // k]
        if scale is None:
            expanded_scale = None
    if quant_mode == 0:
        # Static quantization: one global scale/offset, rint + clip to int8.
        expanded_scale = None
        expanded_x_fp16 = expanded_x.astype(np.float16)
        if scale is not None:
            scale_val = scale.astype(np.float16)
        else:
            raise ValueError("scale cannot be None when quant_mode is 0")
        if offset is not None:
            offset_val = offset.astype(np.float16)
        else:
            raise ValueError("offset cannot be None when quant_mode is 0")
        scale_rst = expanded_x_fp16 * scale_val[0]
        add_offset = scale_rst + offset_val[0]
        round_data = np.rint(add_offset)
        round_data = np.clip(round_data, -128, 127)
        expanded_x = round_data.astype(np.int8)
    if quant_mode == 1:
        # Dynamic quantization: per-row max-abs scale, then round to int8.
        x_final = expanded_x.astype(np.float32)
        if scale is None:
            x_abs = np.abs(x_final)
            x_max = np.max(x_abs, axis=-1, keepdims=True)
            expanded_scale = x_max / 127
            expanded_x = x_final / expanded_scale
            expanded_x = np.round(expanded_x).astype(np.int8)
        else:
            # Apply the smooth scale before computing the dynamic scale.
            if scale.shape[0] == 1:
                x_final = x_final * scale
            else:
                if drop_pad_mode == 0:
                    # Per-expert scale, indexed by each row's (shifted) expert.
                    x_final = x_final * scale[sorted_expert_idx[:active_num] -
                                              expert_start]
                else:
                    # Slot i belongs to expert i // expert_capacity.
                    for i, val in enumerate(sort_row_tmp):
                        if val != -1:
                            x_final[i] = x_final[i] * scale[i //
                                                            expert_capacity]
            x_abs = np.abs(x_final)
            x_max = np.max(x_abs, axis=-1, keepdims=True)
            expanded_scale = x_max / 127
            expanded_x = x_final / expanded_scale
            expanded_x = np.round(expanded_x).astype(np.int8)
        # NOTE(review): looks like already-int8 inputs are expected to emit
        # no scale — confirm against the kernel's contract.
        if x.dtype == np.int8:
            expanded_scale = None
    if drop_pad_mode == 1:
        # Zero the padded slots and reshape to (expert_num, capacity, h).
        expanded_x = np.ma.array(expanded_x, mask=expanded_x_mask).filled(0)
        expanded_x = expanded_x.reshape(expert_num, expert_capacity, h)
    return expanded_x, expanded_row_idx, expert_tokens_count, expanded_scale
def npu_pta(x, expert_idx, scale, offset, active_num, expert_capacity,
            expert_num, drop_pad_mode, expert_tokens_num_type,
            expert_tokens_num_flag, quant_mode, active_expert_range,
            row_idx_type):
    """Invoke the custom NPU kernel and return its four outputs.

    Thin pass-through around ``torch.ops._C_ascend.npu_moe_init_routing_custom``
    so the test driver can call golden and device paths symmetrically.
    """
    op_kwargs = {
        "scale": scale,
        "offset": offset,
        "active_num": active_num,
        "expert_capacity": expert_capacity,
        "expert_num": expert_num,
        "drop_pad_mode": drop_pad_mode,
        "expert_tokens_num_type": expert_tokens_num_type,
        "expert_tokens_num_flag": expert_tokens_num_flag,
        "quant_mode": quant_mode,
        "active_expert_range": active_expert_range,
        "row_idx_type": row_idx_type,
    }
    outputs = torch.ops._C_ascend.npu_moe_init_routing_custom(
        x, expert_idx, **op_kwargs)
    expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = outputs
    return expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale
def cmp_out_golden(x_golden, x_out, dtype):
    """Compare a device output tensor against its golden reference.

    Only the leading ``len(x_golden)`` elements of ``x_out`` are checked.
    ``dtype`` selects the tolerance: 'int8' permits an absolute error of 1
    (quantization rounding); anything else uses tight float tolerances.
    Returns a numpy boolean: True when every element is within tolerance.
    """
    actual = x_out.cpu().numpy()[:len(x_golden)]
    if dtype == 'int8':
        close = np.isclose(actual, x_golden, atol=1)
    else:
        close = np.isclose(actual, x_golden, rtol=1e-05, atol=1e-05)
    return np.all(close)
def test_moe_npu(x, expert_idx, scale, offset, active_num, expert_capacity,
                 expert_num, drop_pad_mode, expert_tokens_num_type,
                 expert_tokens_num_flag, quant_mode, active_expert_range,
                 row_idx_type):
    """Run one parameter combination on both paths and compare the outputs.

    Evaluates the numpy golden model and the NPU kernel on identical inputs
    and returns True iff every produced output matches within tolerance.
    """
    # Device-side copies for the kernel; host numpy views for the golden model.
    x_npu = x.npu()
    expert_idx_npu = expert_idx.npu()
    scale_npu = None if scale is None else scale.npu()
    offset_npu = None if offset is None else offset.npu()
    x_np = x.numpy()
    expert_idx_np = expert_idx.numpy()
    scale_np = None if scale is None else scale.numpy()
    offset_np = None if offset is None else offset.numpy()

    golden = moe_init_routing_golden(x_np, expert_idx_np, scale_np, offset_np,
                                     active_num, expert_capacity, expert_num,
                                     drop_pad_mode, expert_tokens_num_type,
                                     expert_tokens_num_flag,
                                     active_expert_range, quant_mode,
                                     row_idx_type)
    actual = npu_pta(x_npu, expert_idx_npu, scale_npu, offset_npu, active_num,
                     expert_capacity, expert_num, drop_pad_mode,
                     expert_tokens_num_type, expert_tokens_num_flag,
                     quant_mode, active_expert_range, row_idx_type)
    expanded_x_g, expanded_row_idx_g, expert_tokens_g, expanded_scale_g = golden
    expanded_x, expanded_row_idx, expert_tokens, expanded_scale = actual

    # Quantized paths emit int8 activations; otherwise compare as float.
    x_cmp_dtype = "float32" if quant_mode == -1 else "int8"
    expanded_x_result = cmp_out_golden(expanded_x_g, expanded_x, x_cmp_dtype)
    expanded_row_idx_result = cmp_out_golden(expanded_row_idx_g,
                                             expanded_row_idx, "int32")
    # Counts are only produced when the flag is set; otherwise vacuously pass.
    expert_tokens_result = (cmp_out_golden(expert_tokens_g, expert_tokens,
                                           "int64")
                            if expert_tokens_num_flag else True)
    # A scale output exists for dynamic quant, or no-quant with a scale input.
    expand_scale_result = (cmp_out_golden(expanded_scale_g.flatten(),
                                          expanded_scale, "float32")
                           if quant_mode == 1
                           or (quant_mode == -1 and scale is not None) else
                           True)
    return (expanded_x_result and expanded_row_idx_result
            and expert_tokens_result and expand_scale_result)
def test_moe_init_routing_custom():
    """Sweep the op's mode space with random shapes against the golden model.

    Iterates the full cross product of drop/pad mode, count type, count
    flag, quant mode, row-index type and scale layout; each combination
    gets freshly randomized sizes and data. Fails if any case mismatches.
    """
    failed_test_cnt = 0
    drop_pad_mode = [0, 1]
    expert_tokens_num_type = [0, 1, 2]
    expert_tokens_num_flag = [True, False]
    quant_mode = [0, 1, -1]
    row_idx_type = [0, 1]
    scale_type = [0, 1, 2]
    product_result = itertools.product(drop_pad_mode, expert_tokens_num_type,
                                       expert_tokens_num_flag, quant_mode,
                                       row_idx_type, scale_type)
    # (The previous `enumerate(..., 5)` index was never used; dropped.)
    for (drop_pad_mode_, expert_tokens_num_type_, expert_tokens_num_flag_,
         quant_mode_, row_idx_type_, scale_type_) in product_result:
        # Random problem sizes for this combination.
        expert_num_ = random.randint(2, 500)
        expert_start = random.randint(0, expert_num_ - 1)
        expert_end = random.randint(expert_start + 1, expert_num_)
        active_expert_range_ = [expert_start, expert_end]
        N = random.randint(1, 100)
        H = random.randint(12, 100)
        K = random.randint(1, 12)
        x_ = torch.randn(N, H, dtype=torch.float16) * 5
        expert_capacity_ = random.randint(1, N - 1) if N > 1 else 1
        expert_idx_ = torch.randint(0,
                                    expert_num_ - 1, (N, K),
                                    dtype=torch.int32)
        active_num_ = N * K
        if drop_pad_mode_ == 1:
            # Drop/pad mode fixes the range, count type and row-index type.
            active_expert_range_ = [0, expert_num_]
            expert_tokens_num_type_ = 1
            row_idx_type_ = 0
        if quant_mode_ == 0:
            # Static quantization: one global scale and offset.
            scale_ = torch.randn(1, dtype=torch.float)
            offset_ = torch.randn(1, dtype=torch.float)
        elif quant_mode_ == -1:
            # No quantization.
            scale_ = None
            offset_ = None
        else:
            # Dynamic quantization: vary the smooth-scale layout.
            if scale_type_ == 0:
                scale_ = None
                offset_ = None
            elif scale_type_ == 1:
                scale_ = torch.randn(1, H, dtype=torch.float)
                offset_ = None
            else:
                scale_ = torch.randn(active_expert_range_[1] -
                                     active_expert_range_[0],
                                     H,
                                     dtype=torch.float)
                offset_ = None
        result_pta = test_moe_npu(x_, expert_idx_, scale_, offset_,
                                  active_num_, expert_capacity_, expert_num_,
                                  drop_pad_mode_, expert_tokens_num_type_,
                                  expert_tokens_num_flag_, quant_mode_,
                                  active_expert_range_, row_idx_type_)
        if not result_pta:
            failed_test_cnt += 1
    assert failed_test_cnt == 0