add dispatch_gmm_combine kernel (#3532)
### What this PR does / why we need it?

This PR introduces the Ascend implementation of the `dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime, together with follow-up fixes to ensure the kernel builds and runs correctly in CI.

- Add full host and device implementation of the `dispatch_ffn_combine` kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MOE routing helpers, and kernel utilities for quantized FFN dispatch.
- Integrate the new kernel with the PyTorch binding (`csrc/torch_binding.cpp`, `csrc/torch_binding_meta.cpp`) and the Ascend runtime (`vllm_ascend/ascend_forward_context.py`, `vllm_ascend/worker/model_runner_v1.py`).
- Extend fused MoE communication and token dispatch support in `vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new dispatch path.
- Update quantization logic in `vllm_ascend/quantization/w8a8_dynamic.py` to support the new FFN dispatch flow.
- Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake configuration, and include/namespace usage in the new kernel files.
- Add an end-to-end nightly test `tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper utilities in `vllm_ascend/utils.py` to validate the new kernel.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0

---------

Signed-off-by: mojave2 <chenchen145@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -37,59 +37,60 @@
|
||||
namespace vllm_ascend {
|
||||
const int64_t INT4_NUMS_IN_INT32 = 8;
|
||||
/**
 * @brief Copy whole blocks from `src` into `dst` as directed by `block_mapping`.
 *
 * Row `i` of `block_mapping` is read as the pair
 * (`block_mapping[i][0]`, `block_mapping[i][1]`) = (source block, destination block).
 * One block is `src.element_size() * src.stride(0)` bytes. Every copy is enqueued on
 * `stream` with aclrtMemcpyAsync; this function does not synchronize the stream.
 *
 * @param src            Source tensor, addressed block-wise along dimension 0.
 * @param dst            Destination tensor, addressed block-wise along dimension 0.
 * @param block_mapping  CPU tensor holding the (source, destination) block index pairs.
 * @param stream         ACL stream on which the copies are enqueued.
 *
 * @throws c10::Error (via TORCH_CHECK) when both tensors are on the CPU, when two NPU
 *         tensors have different device indices, when `block_mapping` is not on the CPU,
 *         when a block index is outside `[0, size(0))`, or when aclrtMemcpyAsync
 *         reports a failure.
 */
void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
                      const torch::Tensor& block_mapping, aclrtStream stream)
{
    const torch::Device src_device = src.device();
    const torch::Device dst_device = dst.device();
    aclrtMemcpyKind memcpy_type;

    // Derive the copy direction from where the two tensors live.
    if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
        TORCH_CHECK(src_device.index() == dst_device.index(),
                    "src and dst must be on the same npu");
        memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
    } else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
        memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
    } else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
        memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
    } else {
        TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device);
    }

    // The mapping is read element by element on the host below.
    TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

    char* src_ptr = static_cast<char*>(src.data_ptr());
    char* dst_ptr = static_cast<char*>(dst.data_ptr());

    // NOTE(review): the same byte size (taken from src) is used for dst offsets, so dst is
    // assumed to have the same per-block layout as src -- this is not verified here.
    const int64_t block_size_in_bytes = src.element_size() * src.stride(0);

    const int64_t num_blocks = block_mapping.size(0);
    const int64_t max_src_block = src.size(0);
    const int64_t max_dst_block = dst.size(0);

    // Signed index: num_blocks is int64_t, so avoid a signed/unsigned comparison.
    for (int64_t i = 0; i < num_blocks; ++i) {
        const int64_t src_block_number = block_mapping[i][0].item<int64_t>();
        const int64_t dst_block_number = block_mapping[i][1].item<int64_t>();

        // Valid indices are [0, size(0)); an index equal to size(0) would address the
        // block just past the end of the tensor, so the upper bound is exclusive.
        TORCH_CHECK(src_block_number >= 0 && src_block_number < max_src_block,
                    "src block index ", src_block_number, " out of range (max: ", max_src_block, ")");
        TORCH_CHECK(dst_block_number >= 0 && dst_block_number < max_dst_block,
                    "dst block index ", dst_block_number, " out of range (max: ", max_dst_block, ")");

        const int64_t src_offset = src_block_number * block_size_in_bytes;
        const int64_t dst_offset = dst_block_number * block_size_in_bytes;

        // Surface enqueue failures instead of silently dropping the error code.
        const aclError copy_status = aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
                                                      src_ptr + src_offset, block_size_in_bytes,
                                                      memcpy_type, stream);
        TORCH_CHECK(copy_status == ACL_SUCCESS,
                    "aclrtMemcpyAsync failed while copying src block ", src_block_number,
                    " to dst block ", dst_block_number, ", error code: ", copy_status);
    }
}
|
||||
|
||||
/**
 * @brief Public entry point for block swapping: selects the NPU device, picks up the
 *        current NPU stream and delegates the copy to swap_blocks_impl.
 *
 * @param x  Tensor passed to swap_blocks_impl as the copy source.
 * @param y  Tensor passed to swap_blocks_impl as the copy destination.
 * @param z  Tensor passed to swap_blocks_impl as the block mapping.
 */
void swap_blocks(torch::Tensor &x, torch::Tensor &y, const torch::Tensor &z)
{
    // Guard on x's device unless x lives on the CPU, in which case y's device is used.
    const bool x_on_cpu = x.device().is_cpu();
    const c10_npu::OptionalNPUGuard npuGuard(x_on_cpu ? y.device() : x.device());

    // The stream is fetched after the guard is in place.
    aclrtStream current_stream = c10_npu::getCurrentNPUStream().stream();
    swap_blocks_impl(x, y, z, current_stream);
}
|
||||
|
||||
AscendType get_dtype_from_torch(at::ScalarType scalarType)
|
||||
@@ -617,7 +618,33 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
||||
});
|
||||
cmd.Run();
|
||||
return;
|
||||
}
|
||||
|
||||
/**
 * @brief Launch the aclnnDispatchFFNCombine kernel and return the caller-supplied output.
 *
 * Every argument is forwarded to EXEC_NPU_CMD unchanged and in declaration order; this
 * wrapper performs no validation of its own. Shapes and dtypes of the tensors are
 * dictated by the kernel and are not visible from this function.
 *
 * @param x                Forwarded to the kernel as its first tensor argument.
 * @param weight1          Forwarded to the kernel.
 * @param weight2          Forwarded to the kernel.
 * @param expert_idx       Forwarded to the kernel.
 * @param scale1           Forwarded to the kernel.
 * @param scale2           Forwarded to the kernel.
 * @param probs            Forwarded to the kernel.
 * @param group            Group name; its character buffer is handed to the kernel.
 * @param max_output_size  Forwarded to the kernel.
 * @param out              Output tensor handed to the kernel.
 * @return Reference to `out`.
 */
at::Tensor& dispatch_ffn_combine(
    const at::Tensor& x,
    const at::Tensor& weight1,
    const at::Tensor& weight2,
    const at::Tensor& expert_idx,
    const at::Tensor& scale1,
    const at::Tensor& scale2,
    const at::Tensor& probs,
    c10::string_view group,
    int64_t max_output_size,
    at::Tensor& out)
{
    // The command macro is given a mutable char*, so constness is cast away from the
    // view's buffer.
    // NOTE(review): string_view::data() is not guaranteed to be null-terminated --
    // confirm the callee does not depend on a terminator.
    char *group_name = const_cast<char *>(group.data());

    EXEC_NPU_CMD(aclnnDispatchFFNCombine,
                 x, weight1, weight2, expert_idx,
                 scale1, scale2, probs,
                 group_name, max_output_size, out);

    return out;
}
|
||||
|
||||
at::Tensor npu_lightning_indexer(
|
||||
@@ -810,4 +837,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
" int sparse_mode=3) -> Tensor"
|
||||
);
|
||||
ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention);
|
||||
|
||||
ops.def(
|
||||
"dispatch_ffn_combine(Tensor x, Tensor weight1, Tensor weight2, Tensor expert_idx,"
|
||||
" Tensor scale1, Tensor scale2, Tensor probs, str group,"
|
||||
" int max_output_size, Tensor! out) -> Tensor"
|
||||
);
|
||||
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user