add dispatch_ffn_combine kernel (#3532)

### What this PR does / why we need it?

This PR introduces the Ascend implementation of the
`dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime,
together with follow‑up fixes to ensure the kernel builds and runs
correctly in CI.

- Add full host and device implementation of the `dispatch_ffn_combine`
kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MoE
routing helpers, and kernel utilities for quantized FFN dispatch.
- Integrate the new kernel with the PyTorch binding
(`csrc/torch_binding.cpp`, `csrc/torch_binding_meta.cpp`) and the Ascend
runtime (`vllm_ascend/ascend_forward_context.py`,
`vllm_ascend/worker/model_runner_v1.py`); a Python-side sketch of the
resulting op follows this list.
- Extend fused MoE communication and token dispatch support in
`vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new
dispatch path.
- Update quantization logic in `vllm_ascend/quantization/w8a8_dynamic.py`
to support the new FFN dispatch flow.
- Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake
configuration, and include/namespace usage in the new kernel files.
- Add an end‑to‑end nightly test
`tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper
utilities in `vllm_ascend/utils.py` to validate the new kernel.
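
Once the extension is registered, the op is callable through the PyTorch dispatcher under the `_C_ascend` library (see the registration hunk below). A minimal Python-side sketch; shapes, dtypes, the group name, and the routing layout are illustrative assumptions rather than the kernel's verified contract:

```python
import torch
import vllm_ascend  # noqa: F401  -- assumed to load the _C_ascend extension and register the op

num_tokens, hidden, inter, num_experts, topk = 16, 1024, 2048, 8, 2
max_output_size = num_tokens

# Illustrative shapes/dtypes only; the kernel's quantized FFN path defines the real contract.
x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="npu")
weight1 = torch.randint(-128, 128, (num_experts, hidden, inter), dtype=torch.int8, device="npu")
weight2 = torch.randint(-128, 128, (num_experts, inter, hidden), dtype=torch.int8, device="npu")
scale1 = torch.ones(num_experts, inter, dtype=torch.float32, device="npu")
scale2 = torch.ones(num_experts, hidden, dtype=torch.float32, device="npu")
expert_idx = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="npu")
probs = torch.softmax(torch.randn(num_tokens, topk, device="npu"), dim=-1)
out = torch.empty(max_output_size, hidden, dtype=x.dtype, device="npu")

# Signature matches the schema registered in torch_binding.cpp;
# "ep_group" is a placeholder HCCL group name.
torch.ops._C_ascend.dispatch_ffn_combine(
    x, weight1, weight2, expert_idx, scale1, scale2, probs,
    "ep_group", max_output_size, out)
```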

### Does this PR introduce _any_ user-facing change?

No. The kernel and its runtime wiring are internal; no public API or configuration changes.
### How was this patch tested?

Covered by the new end-to-end nightly test
`tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` (see the reference
sketch at the end of this page) and by the CI builds that the
accompanying build fixes were verified against.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

---------

Signed-off-by: mojave2 <chenchen145@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Chen Chen authored on 2025-12-04 23:00:59 +08:00; committed by GitHub
parent 752a55473c, commit ad0607f900
61 changed files with 9795 additions and 53 deletions


@@ -37,59 +37,60 @@
namespace vllm_ascend {
const int64_t INT4_NUMS_IN_INT32 = 8;

void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
                      const torch::Tensor& block_mapping, aclrtStream stream)
{
    torch::Device src_device = src.device();
    torch::Device dst_device = dst.device();
    aclrtMemcpyKind memcpy_type;
    if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
        TORCH_CHECK(src_device.index() == dst_device.index(),
                    "src and dst must be on the same npu");
        memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
    } else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
        memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
    } else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
        memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
    } else {
        TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device,
                    ", dst tensor device: ", dst_device);
    }

    TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

    char* src_ptr = static_cast<char*>(src.data_ptr());
    char* dst_ptr = static_cast<char*>(dst.data_ptr());

    // Dim 0 is the block dimension, so one block spans stride(0) elements.
    const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
    const int64_t num_blocks = block_mapping.size(0);
    const int64_t max_src_block = src.size(0);
    const int64_t max_dst_block = dst.size(0);
    for (int64_t i = 0; i < num_blocks; i++) {
        int64_t src_block_number = block_mapping[i][0].item<int64_t>();
        int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
        // Valid indices are [0, size(0)); note the strict upper bound.
        TORCH_CHECK(src_block_number >= 0 && src_block_number < max_src_block,
                    "src block index ", src_block_number, " out of range (num blocks: ", max_src_block, ")");
        TORCH_CHECK(dst_block_number >= 0 && dst_block_number < max_dst_block,
                    "dst block index ", dst_block_number, " out of range (num blocks: ", max_dst_block, ")");
        int64_t src_offset = src_block_number * block_size_in_bytes;
        int64_t dst_offset = dst_block_number * block_size_in_bytes;
        aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
                         src_ptr + src_offset, block_size_in_bytes,
                         memcpy_type, stream);
    }
}

void swap_blocks(torch::Tensor& x, torch::Tensor& y, const torch::Tensor& z)
{
    // Guard on whichever side lives on the NPU; block_mapping (z) stays on CPU.
    const c10_npu::OptionalNPUGuard npuGuard(
        (!x.device().is_cpu()) ? x.device() : y.device()
    );
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    swap_blocks_impl(x, y, z, stream);
}
AscendType get_dtype_from_torch(at::ScalarType scalarType)
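
As context for the re-indented `swap_blocks` path above: `block_mapping` must be a CPU tensor whose rows are `(src_block, dst_block)` index pairs, which is exactly what the `TORCH_CHECK`s in `swap_blocks_impl` enforce. A hedged caller sketch; the registered op name and namespace are assumptions following vLLM's custom-op convention:

```python
import torch

# Two paged KV-cache pools; shapes are illustrative. Dim 0 is the block
# dimension, so element_size() * stride(0) bytes are copied per block.
src_cache = torch.randn(64, 16, 128, device="npu")
dst_cache = torch.empty(64, 16, 128)  # host-side pool (CPU)

# Must live on CPU; each row is (src_block_number, dst_block_number).
block_mapping = torch.tensor([[0, 3], [1, 7], [5, 0]], dtype=torch.int64)

# Hypothetical binding name; the C++ entry point is swap_blocks above.
torch.ops._C_ascend.swap_blocks(src_cache, dst_cache, block_mapping)
```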
@@ -617,7 +618,33 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
    });
    cmd.Run();
    return;
}

at::Tensor& dispatch_ffn_combine(
    const at::Tensor& x,
    const at::Tensor& weight1,
    const at::Tensor& weight2,
    const at::Tensor& expert_idx,
    const at::Tensor& scale1,
    const at::Tensor& scale2,
    const at::Tensor& probs,
    c10::string_view group,
    int64_t max_output_size,
    at::Tensor& out
) {
    // The aclnn host API takes a mutable char* for the HCCL group name;
    // cast away const (the string is not expected to be mutated).
    char *group_ep_ptr = const_cast<char *>(group.data());
    EXEC_NPU_CMD(aclnnDispatchFFNCombine,
                 x,
                 weight1,
                 weight2,
                 expert_idx,
                 scale1,
                 scale2,
                 probs,
                 group_ep_ptr,
                 max_output_size,
                 out);
    return out;
}

at::Tensor npu_lightning_indexer(
@@ -810,4 +837,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
        " int sparse_mode=3) -> Tensor"
    );
    ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention);

    // Schema note: out is mutated in place (Tensor!) and also returned.
    ops.def(
        "dispatch_ffn_combine(Tensor x, Tensor weight1, Tensor weight2, Tensor expert_idx,"
        " Tensor scale1, Tensor scale2, Tensor probs, str group,"
        " int max_output_size, Tensor! out) -> Tensor"
    );
    ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
}
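
To make the fused flow concrete, here is a dequantized eager-mode reference of what dispatch → per-expert FFN → combine computes for each token. The activation choice (SiLU) and shapes are assumptions for intuition only; the authoritative check is the nightly test added by this PR:

```python
import torch
import torch.nn.functional as F

def reference_dispatch_ffn_combine(x, w1, w2, expert_idx, probs):
    """Eager reference: route each token to its top-k experts, apply each
    expert's two-layer FFN, and sum the outputs weighted by routing probs.

    x:          [num_tokens, hidden]
    w1:         [num_experts, hidden, inter]
    w2:         [num_experts, inter, hidden]
    expert_idx: [num_tokens, topk] expert ids per token
    probs:      [num_tokens, topk] routing weights per token
    """
    num_tokens, topk = expert_idx.shape
    out = torch.zeros_like(x)
    for t in range(num_tokens):
        for k in range(topk):
            e = int(expert_idx[t, k])
            h = F.silu(x[t] @ w1[e])             # activation is an assumption
            out[t] += probs[t, k] * (h @ w2[e])  # combine weighted by routing prob
    return out
```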