From df1ee8070d64df642d17c9f8dbdf9c5a14f5216e Mon Sep 17 00:00:00 2001
From: kx <1670186653@qq.com>
Date: Fri, 13 Mar 2026 14:07:35 +0800
Subject: [PATCH] [feat][spec decode] Unified draft parallel (#6766)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
Implement unified, parallelized speculative decoding in vLLM Ascend, which can simultaneously support parallel speculative inference schemes such as PARD and P-Eagle.
Refer to https://github.com/vllm-project/vllm-ascend/pull/6565 and https://github.com/vllm-project/vllm-ascend/pull/4078

### How was this patch tested?
Parallel drafting script:

export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811 \
    --speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

Baseline script (no speculative decoding):

export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811

Benchmark script:

MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234

Test results:
- Base (without spec decode): TTFT 79.46 ms, TPOT 26.99 ms, output token throughput 36.75 tok/s
- This PR (with parallel drafting): TTFT 72.24 ms, TPOT 13.45 ms, output token throughput 72.98 tok/s
- Per-position acceptance (positions 0 to 7): 79.48%, 56.93%, 40%, 27.90%, 19.79%, 14.25%, 10.57%, 7.61%.
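As a rough sanity check on these numbers (an illustrative back-of-envelope script, not part of this patch; it assumes the reported per-position acceptance rates behave as marginal acceptance probabilities per verification step):

# Illustrative only, not part of this patch. Assumes the reported
# per-position acceptance rates are marginal acceptance probabilities.
rates = [0.7948, 0.5693, 0.40, 0.2790, 0.1979, 0.1425, 0.1057, 0.0761]

# Expected tokens emitted per verification step: the accepted draft tokens
# plus the one bonus token sampled from the target model.
tokens_per_step = 1 + sum(rates)
print(f"expected tokens/step: {tokens_per_step:.2f}")  # ~3.57

# The measured TPOT speedup (26.99 ms -> 13.45 ms) is ~2.0x, below that
# bound because each step also pays for the draft-model forward passes
# and input expansion.
print(f"measured TPOT speedup: {26.99 / 13.45:.2f}x")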
----------------------------------------------------------------------
Run on a Qwen3 model:

export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1
vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811 \
    --speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

cc @NickJudyHvv

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/9562912cead1f11e8540fb91306c5cbda66f0007

---------

Signed-off-by: 01267596
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596
---
 csrc/build_aclnn.sh                                |   3 +-
 .../op_host/CMakeLists.txt                         |  22 +
 .../copy_and_expand_eagle_inputs_def.cpp           |  87 ++++
 ...opy_and_expand_eagle_inputs_infershape.cpp      | 107 ++++
 .../copy_and_expand_eagle_inputs_tiling.cpp        | 121 +++++
 .../copy_and_expand_eagle_inputs_tiling.h          |  37 ++
 .../copy_and_expand_eagle_inputs.cpp               | 386 ++++++++++++++
 csrc/torch_binding.cpp                             |  45 ++
 csrc/torch_binding_meta.cpp                        |  29 ++
 .../test_copy_and_expand_eagle_inputs.py           | 471 ++++++++++++++++++
 .../spec_decode/test_v1_spec_decode.py             | 284 +++++++----
 tests/ut/spec_decode/test_eagle_proposer.py        | 178 +++----
 vllm_ascend/attention/attention_v1.py              |   3 +
 vllm_ascend/ops/triton/spec_decode/utils.py        |   2 +
 vllm_ascend/spec_decode/__init__.py                |   4 +
 vllm_ascend/spec_decode/draft_proposer.py          |  71 +++
 vllm_ascend/spec_decode/eagle_proposer.py          | 373 ++++++++----
 vllm_ascend/worker/model_runner_v1.py              |  31 +-
 18 files changed, 1943 insertions(+), 311 deletions(-)
 create mode 100644 csrc/copy_and_expand_eagle_inputs/op_host/CMakeLists.txt
 create mode 100644 csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_def.cpp
 create mode 100644 csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_infershape.cpp
 create mode 100644 csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.cpp
 create mode 100644 csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.h
 create mode 100644 csrc/copy_and_expand_eagle_inputs/op_kernel/copy_and_expand_eagle_inputs.cpp
 create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/test_copy_and_expand_eagle_inputs.py
 create mode 100644 vllm_ascend/spec_decode/draft_proposer.py

diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh
index d3ad883a..5b11cbe2 100644
--- a/csrc/build_aclnn.sh
+++ b/csrc/build_aclnn.sh
@@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
     export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
 
-    CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;causal_conv1d;"
+    CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;"
     SOC_ARG="ascend910b"
 elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
     # ASCEND910C (A3) series
@@ -64,6 +64,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
         "add_rms_norm_bias"
         "apply_top_k_top_p_custom"
         "transpose_kv_cache_by_block"
+
"copy_and_expand_eagle_inputs" "causal_conv1d" "moe_grouped_matmul" ) diff --git a/csrc/copy_and_expand_eagle_inputs/op_host/CMakeLists.txt b/csrc/copy_and_expand_eagle_inputs/op_host/CMakeLists.txt new file mode 100644 index 00000000..89e288b6 --- /dev/null +++ b/csrc/copy_and_expand_eagle_inputs/op_host/CMakeLists.txt @@ -0,0 +1,22 @@ +add_ops_compile_options( + OP_NAME CopyAndExpandEagleInputs + OPTIONS --cce-auto-sync=on + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnn PRIVATE + copy_and_expand_eagle_inputs_def.cpp +) + +target_sources(optiling PRIVATE + copy_and_expand_eagle_inputs_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + copy_and_expand_eagle_inputs_infershape.cpp +) diff --git a/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_def.cpp b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_def.cpp new file mode 100644 index 00000000..c33301dd --- /dev/null +++ b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_def.cpp @@ -0,0 +1,87 @@ +/** + * @file copy_and_expand_eagle_inputs_def.cpp + * @brief CopyAndExpandEagleInputs OpDef registration + */ + +#include "register/op_def_registry.h" + +namespace ops { + +class CopyAndExpandEagleInputs : public OpDef { +public: + explicit CopyAndExpandEagleInputs(const char* name) : OpDef(name) + { + // -------------------- Inputs -------------------- + this->Input("target_token_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("target_positions") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("next_token_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("query_start_loc") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("query_end_loc") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + // -------------------- Outputs -------------------- + this->Output("out_input_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("out_positions") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("out_is_rejected_token_mask") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("out_is_masked_token_mask") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("out_new_token_indices") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("out_hidden_state_mapping") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + // -------------------- Attributes -------------------- + this->Attr("padding_token_id").Int(); + this->Attr("parallel_drafting_token_id").Int(); + this->Attr("num_padding_slots_per_request").Int(); + this->Attr("shift_input_ids").Bool(); + this->Attr("total_input_tokens").Int(); + + // 
-------------------- Platform --------------------
+        this->AICore().AddConfig("ascend910b");
+    }
+};
+
+OP_ADD(CopyAndExpandEagleInputs);
+
+} // namespace ops
diff --git a/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_infershape.cpp b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_infershape.cpp
new file mode 100644
index 00000000..3fc3d383
--- /dev/null
+++ b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_infershape.cpp
@@ -0,0 +1,107 @@
+/**
+ * @file copy_and_expand_eagle_inputs_infershape.cpp
+ * @brief InferShape and InferDataType for CopyAndExpandEagleInputs
+ */
+
+#include "register/op_def_registry.h"
+#include "log/ops_log.h"
+
+#define unlikely(x) __builtin_expect((x), 0)
+#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
+    do { \
+        if (unlikely((ptr) == nullptr)) { \
+            const char* name = (unlikely(((context) == nullptr) || (context)->GetNodeName() == nullptr)) ? \
+                "nil" : \
+                (context)->GetNodeName(); \
+            OPS_LOG_E(name, "%s is nullptr!", #ptr); \
+            return ge::GRAPH_FAILED; \
+        } \
+    } while (0)
+
+static constexpr int IDX_TARGET_TOKEN_IDS = 0;
+static constexpr int IDX_TARGET_POSITIONS = 1;
+static constexpr int IDX_NEXT_TOKEN_IDS = 2;
+static constexpr int IDX_QUERY_START_LOC = 3;
+static constexpr int IDX_QUERY_END_LOC = 4;
+
+static constexpr int OUT_INPUT_IDS = 0;
+static constexpr int OUT_POSITIONS = 1;
+static constexpr int OUT_REJECTED_MASK = 2;
+static constexpr int OUT_MASKED_MASK = 3;
+static constexpr int OUT_NEW_TOKEN_INDICES = 4;
+static constexpr int OUT_HIDDEN_STATE_MAPPING = 5;
+static constexpr int OUTPUT_NUM = 6;
+
+static constexpr int ATTR_NUM_PADDING_SLOTS = 2;
+static constexpr int ATTR_TOTAL_INPUT_TOKENS = 4;
+
+using namespace ge;
+
+namespace ops {
+
+static ge::graphStatus InferShape4CopyAndExpandEagleInputs(gert::InferShapeContext* context)
+{
+    // Get input shapes
+    const gert::Shape* targetTokenIdsShape = context->GetInputShape(IDX_TARGET_TOKEN_IDS);
+    OP_CHECK_NULL_WITH_CONTEXT(context, targetTokenIdsShape);
+    const gert::Shape* queryStartLocShape = context->GetInputShape(IDX_QUERY_START_LOC);
+    OP_CHECK_NULL_WITH_CONTEXT(context, queryStartLocShape);
+
+    // Derive dimensions from input shapes
+    int64_t totalInputTokens = targetTokenIdsShape->GetDim(0);
+    int64_t numReqs = queryStartLocShape->GetDim(0) - 1;
+
+    // Get num_padding_slots_per_request from attributes
+    auto attrs = context->GetAttrs();
+    OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
+    int64_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int64_t>(ATTR_NUM_PADDING_SLOTS));
+
+    // Compute total_draft_tokens = total_input_tokens + (num_padding_slots_per_request - 1) * num_reqs
+    int64_t totalDraftTokens = totalInputTokens + (numPaddingSlotsPerReq - 1) * numReqs;
+
+    // Get and validate all output shapes
+    gert::Shape* outShapes[OUTPUT_NUM];
+    for (int i = 0; i < OUTPUT_NUM; ++i) {
+        outShapes[i] = context->GetOutputShape(i);
+        OP_CHECK_NULL_WITH_CONTEXT(context, outShapes[i]);
+        outShapes[i]->SetDimNum(1);
+    }
+
+    // out_input_ids, out_positions, out_rejected_mask, out_masked_mask: [total_draft_tokens]
+    outShapes[OUT_INPUT_IDS]->SetDim(0, totalDraftTokens);
+    outShapes[OUT_POSITIONS]->SetDim(0, totalDraftTokens);
+    outShapes[OUT_REJECTED_MASK]->SetDim(0, totalDraftTokens);
+    outShapes[OUT_MASKED_MASK]->SetDim(0, totalDraftTokens);
+
+    // out_new_token_indices: [num_reqs * num_padding_slots_per_request]
+    outShapes[OUT_NEW_TOKEN_INDICES]->SetDim(0, numReqs * numPaddingSlotsPerReq);
+
+    // out_hidden_state_mapping: [total_input_tokens]
+    outShapes[OUT_HIDDEN_STATE_MAPPING]->SetDim(0, totalInputTokens);
+
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus InferDataType4CopyAndExpandEagleInputs(gert::InferDataTypeContext* context)
+{
+    // out_input_ids: INT32
+    context->SetOutputDataType(OUT_INPUT_IDS, DT_INT32);
+    // out_positions: INT32
+    context->SetOutputDataType(OUT_POSITIONS, DT_INT32);
+    // out_is_rejected_token_mask: INT8
+    context->SetOutputDataType(OUT_REJECTED_MASK, DT_INT8);
+    // out_is_masked_token_mask: INT8
+    context->SetOutputDataType(OUT_MASKED_MASK, DT_INT8);
+    // out_new_token_indices: INT32
+    context->SetOutputDataType(OUT_NEW_TOKEN_INDICES, DT_INT32);
+    // out_hidden_state_mapping: INT32
+    context->SetOutputDataType(OUT_HIDDEN_STATE_MAPPING, DT_INT32);
+
+    return GRAPH_SUCCESS;
+}
+
+IMPL_OP_INFERSHAPE(CopyAndExpandEagleInputs)
+    .InferShape(InferShape4CopyAndExpandEagleInputs)
+    .InferDataType(InferDataType4CopyAndExpandEagleInputs);
+
+} // namespace ops
diff --git a/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.cpp b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.cpp
new file mode 100644
index 00000000..ee9962b6
--- /dev/null
+++ b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.cpp
@@ -0,0 +1,121 @@
+/**
+ * @file copy_and_expand_eagle_inputs_tiling.cpp
+ * @brief CopyAndExpandEagleInputs TilingFunc implementation
+ */
+
+#include "copy_and_expand_eagle_inputs_tiling.h"
+#include "register/op_def_registry.h"
+#include "log/ops_log.h"
+
+#include <algorithm>
+
+namespace optiling {
+
+static void GetCompileParameters(
+    gert::TilingContext* context, uint32_t& coreNum)
+{
+    auto ptrCompileInfo = reinterpret_cast<const CopyAndExpandEagleInputsCompileInfo*>(context->GetCompileInfo());
+    if (ptrCompileInfo == nullptr) {
+        auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+        coreNum = ascendcPlatform.GetCoreNum();
+    } else {
+        coreNum = ptrCompileInfo->totalCoreNum;
+    }
+}
+
+static ge::graphStatus TilingFunc(gert::TilingContext* context)
+{
+    OPS_LOG_I(context, "Enter TilingFunc for CopyAndExpandEagleInputs");
+    OPS_LOG_D(context, "TilingFunc running.");
+
+    // ========== 1. Get hardware core count ==========
+    uint32_t coreNum;
+    GetCompileParameters(context, coreNum);
+
+    // ========== 2. Derive num_reqs from query_start_loc shape ==========
+    // query_start_loc is the 4th input (index 3), shape [num_reqs + 1]
+    auto queryStartLocShape = context->GetInputShape(3);
+    uint32_t numReqs = 0;
+    if (queryStartLocShape != nullptr &&
+        queryStartLocShape->GetStorageShape().GetDimNum() > 0) {
+        int64_t dim0 = queryStartLocShape->GetStorageShape().GetDim(0);
+        numReqs = (dim0 > 1) ? static_cast<uint32_t>(dim0 - 1) : 0;
+    }
+
+    // ========== 3. Get operator attributes ==========
+    auto attrs = context->GetAttrs();
+
+    int32_t paddingTokenId = *(attrs->GetAttrPointer<int64_t>(0));
+    int32_t parallelDraftingTokenId = *(attrs->GetAttrPointer<int64_t>(1));
+    int32_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int64_t>(2));
+    bool shiftInputIds = *(attrs->GetAttrPointer<bool>(3));
+    int32_t totalInputTokens = *(attrs->GetAttrPointer<int64_t>(4));
+
+    // ========== 4. Compute core distribution ==========
+    uint32_t usedCoreNum = std::min(coreNum, numReqs);
+    if (usedCoreNum == 0) {
+        usedCoreNum = 1;
+    }
+    uint32_t reqsPerCore = numReqs / usedCoreNum;
+    uint32_t remainderReqs = numReqs % usedCoreNum;
+
+    // ========== 5. Set tiling_key ==========
+    context->SetTilingKey(1);
+
+    // ========== 6. Get output shape ==========
+    uint32_t totalDraftTokens = 0;
+    auto outShape = context->GetOutputShape(0);
+    if (outShape != nullptr &&
+        outShape->GetStorageShape().GetDimNum() > 0) {
+        totalDraftTokens = static_cast<uint32_t>(outShape->GetStorageShape().GetDim(0));
+    }
+
+    // ========== 7. Fill TilingData ==========
+    CopyAndExpandEagleInputsTilingData tiling;
+    tiling.set_usedCoreNum(usedCoreNum);
+    tiling.set_numReqs(numReqs);
+    tiling.set_reqsPerCore(reqsPerCore);
+    tiling.set_remainderReqs(remainderReqs);
+    tiling.set_paddingTokenId(paddingTokenId);
+    tiling.set_parallelDraftingTokenId(parallelDraftingTokenId);
+    tiling.set_numPaddingSlotsPerReq(static_cast<uint32_t>(numPaddingSlotsPerReq));
+    tiling.set_totalInputTokens(static_cast<uint32_t>(totalInputTokens));
+    tiling.set_shiftInputIds(shiftInputIds ? 1u : 0u);
+    tiling.set_totalDraftTokens(totalDraftTokens);
+
+    tiling.SaveToBuffer(
+        context->GetRawTilingData()->GetData(),
+        context->GetRawTilingData()->GetCapacity());
+    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
+
+    // ========== 8. Set block_dim ==========
+    context->SetBlockDim(usedCoreNum);
+
+    OPS_LOG_I(context, "Block Dim: %u", usedCoreNum);
+    OPS_LOG_I(context,
+        "numReqs: %u, reqsPerCore: %u, remainderReqs: %u, totalInputTokens: %d, totalDraftTokens: %u",
+        numReqs, reqsPerCore, remainderReqs, totalInputTokens, totalDraftTokens);
+
+    return ge::GRAPH_SUCCESS;
+}
+
+static ge::graphStatus TilingPrepare4CopyAndExpandEagleInputs(gert::TilingParseContext* context)
+{
+    OPS_LOG_D(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
+    OPS_LOG_I(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
+    auto compileInfo = context->GetCompiledInfo<CopyAndExpandEagleInputsCompileInfo>();
+    OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
+    auto platformInfo = context->GetPlatformInfo();
+    OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
+    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
+
+    compileInfo->totalCoreNum = ascendcPlatform.GetCoreNum();
+
+    return ge::GRAPH_SUCCESS;
+}
+
+IMPL_OP_OPTILING(CopyAndExpandEagleInputs)
+    .Tiling(TilingFunc)
+    .TilingParse(TilingPrepare4CopyAndExpandEagleInputs);
+
+} // namespace optiling
diff --git a/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.h b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.h
new file mode 100644
index 00000000..f987fa15
--- /dev/null
+++ b/csrc/copy_and_expand_eagle_inputs/op_host/copy_and_expand_eagle_inputs_tiling.h
@@ -0,0 +1,37 @@
+#ifndef COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
+#define COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
+
+#include "register/tilingdata_base.h"
+#include "error_log.h"
+#include "register/op_impl_registry.h"
+#include "tiling/platform/platform_ascendc.h"
+
+namespace optiling {
+
+BEGIN_TILING_DATA_DEF(CopyAndExpandEagleInputsTilingData)
+    // ---- Core-split parameters ----
+    TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum);             // number of cores actually used
+    TILING_DATA_FIELD_DEF(uint32_t, numReqs);                 // total number of requests
+    TILING_DATA_FIELD_DEF(uint32_t, reqsPerCore);             // base number of requests per core
+    TILING_DATA_FIELD_DEF(uint32_t, remainderReqs);           // remainder (the first `remainderReqs` cores each handle one extra request)
+
+    // ---- Operator attributes ----
+    TILING_DATA_FIELD_DEF(int32_t, paddingTokenId);           // padding token ID
+    TILING_DATA_FIELD_DEF(int32_t, parallelDraftingTokenId);  // parallel speculative decoding token ID
+    TILING_DATA_FIELD_DEF(uint32_t, numPaddingSlotsPerReq);   // number of padding slots per request
+    TILING_DATA_FIELD_DEF(uint32_t, totalInputTokens);        // total number of input tokens (used for clamping)
+    TILING_DATA_FIELD_DEF(uint32_t, shiftInputIds);           // 0 = false, 1 = true
+
+    // ---- Output sizes ----
+    TILING_DATA_FIELD_DEF(uint32_t, totalDraftTokens);       // total number of output tokens
+END_TILING_DATA_DEF;
+
+struct CopyAndExpandEagleInputsCompileInfo {
+    uint32_t totalCoreNum = 0;
+};
+
+REGISTER_TILING_DATA_CLASS(CopyAndExpandEagleInputs, CopyAndExpandEagleInputsTilingData)
+
+} // namespace optiling
+
+#endif // COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
diff --git a/csrc/copy_and_expand_eagle_inputs/op_kernel/copy_and_expand_eagle_inputs.cpp b/csrc/copy_and_expand_eagle_inputs/op_kernel/copy_and_expand_eagle_inputs.cpp
new file mode 100644
index 00000000..98029a02
--- /dev/null
+++ b/csrc/copy_and_expand_eagle_inputs/op_kernel/copy_and_expand_eagle_inputs.cpp
@@ -0,0 +1,386 @@
+/**
+ * CopyAndExpandEagleInputs kernel implementation (DataCopy version)
+ *
+ * Multi-core strategy:
+ * All GM reads/writes go through DataCopy (GlobalTensor::SetValue/GetValue is never used on GM).
+ * Data is assembled in UB (LocalTensor) via SetValue/GetValue, then DataCopy'd to GM.
+ * Alignment handling follows the DataCopyCustom pattern of built-in CANN operators.
+ */
+
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+// ONE_BLK_SIZE comes from AscendC namespace (32 bytes per block)
+
+class CopyAndExpandEagleInputsKernel {
+public:
+    __aicore__ inline CopyAndExpandEagleInputsKernel() {}
+
+    __aicore__ inline void Init(GM_ADDR targetTokenIds, GM_ADDR targetPositions,
+                                GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
+                                GM_ADDR queryEndLoc,
+                                GM_ADDR outInputIds, GM_ADDR outPositions,
+                                GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
+                                GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
+                                const CopyAndExpandEagleInputsTilingData* tilingData)
+    {
+        usedCoreNum = tilingData->usedCoreNum;
+        numReqs = tilingData->numReqs;
+        reqsPerCore = tilingData->reqsPerCore;
+        remainderReqs = tilingData->remainderReqs;
+        paddingTokenId = tilingData->paddingTokenId;
+        parallelDraftingTokenId = tilingData->parallelDraftingTokenId;
+        numPaddingSlotsPerReq = tilingData->numPaddingSlotsPerReq;
+        totalInputTokens = tilingData->totalInputTokens;
+        totalDraftTokens = tilingData->totalDraftTokens;
+
+        uint32_t coreId = GetBlockIdx();
+        if (coreId < remainderReqs) {
+            myStartReq = coreId * (reqsPerCore + 1);
+            myNumReqs = reqsPerCore + 1;
+        } else {
+            myStartReq = remainderReqs * (reqsPerCore + 1) + (coreId - remainderReqs) * reqsPerCore;
+            myNumReqs = reqsPerCore;
+        }
+
+        // Bind GM tensors
+        gmTargetTokenIds.SetGlobalBuffer((__gm__ int32_t*)targetTokenIds, totalInputTokens);
+        gmTargetPositions.SetGlobalBuffer((__gm__ int32_t*)targetPositions, totalInputTokens);
+        gmNextTokenIds.SetGlobalBuffer((__gm__ int32_t*)nextTokenIds, numReqs);
+        gmQueryStartLoc.SetGlobalBuffer((__gm__ int32_t*)queryStartLoc, numReqs + 1);
+        gmQueryEndLoc.SetGlobalBuffer((__gm__ int32_t*)queryEndLoc, numReqs);
+        gmOutInputIds.SetGlobalBuffer((__gm__ int32_t*)outInputIds, totalDraftTokens);
+        gmOutPositions.SetGlobalBuffer((__gm__ int32_t*)outPositions, totalDraftTokens);
+        gmOutIsRejectedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsRejectedTokenMask, totalDraftTokens);
+        gmOutIsMaskedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsMaskedTokenMask, totalDraftTokens);
+        gmOutNewTokenIndices.SetGlobalBuffer((__gm__ int32_t*)outNewTokenIndices, numPaddingSlotsPerReq * numReqs);
+        gmOutHiddenStateMapping.SetGlobalBuffer((__gm__ int32_t*)outHiddenStateMapping, totalInputTokens);
+
+        // Allocate UB buffers -- each TBuf base address is automatically 32-byte aligned.
+        // Each piece of metadata gets its own TBuf to avoid misaligned UB addresses.
+        uint32_t metaAligned = AlignUp((myNumReqs + 1) * sizeof(int32_t), ONE_BLK_SIZE);
+        pipe.InitBuffer(qsBuf, metaAligned);
+        pipe.InitBuffer(qeBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
+        pipe.InitBuffer(ntBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
+
+        // I/O buffers
+        constexpr uint32_t MAX_PER_REQ = 4096;
+        pipe.InitBuffer(inputBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
+        pipe.InitBuffer(outIdsBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
+        pipe.InitBuffer(outPosBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
+        pipe.InitBuffer(outRejBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
+        pipe.InitBuffer(outMskBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
+        pipe.InitBuffer(ntiBuf, AlignUp(64 * sizeof(int32_t), ONE_BLK_SIZE));
+        pipe.InitBuffer(hsmBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
+
+        // DataCopy the metadata into its per-field UB buffers
+        if (myNumReqs > 0) {
+            LocalTensor<int32_t> lqs = qsBuf.Get<int32_t>();
+            DataCopyIn(lqs, gmQueryStartLoc, (int32_t)myStartReq, (int32_t)(myNumReqs + 1));
+
+            LocalTensor<int32_t> lqe = qeBuf.Get<int32_t>();
+            DataCopyIn(lqe, gmQueryEndLoc, (int32_t)myStartReq, (int32_t)myNumReqs);
+
+            LocalTensor<int32_t> lnt = ntBuf.Get<int32_t>();
+            DataCopyIn(lnt, gmNextTokenIds, (int32_t)myStartReq, (int32_t)myNumReqs);
+        }
+    }
+
+    __aicore__ inline void ProcessShiftFalse()
+    {
+        for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
+            ProcessOneRequestShiftFalse(myStartReq + rLocal, rLocal);
+        }
+    }
+
+    __aicore__ inline void ProcessShiftTrue()
+    {
+        for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
+            ProcessOneRequestShiftTrue(myStartReq + rLocal, rLocal);
+        }
+    }
+
+private:
+    // ============================================================
+    // AlignUp helper
+    // ============================================================
+    static __aicore__ inline uint32_t AlignUp(uint32_t x, uint32_t a)
+    {
+        return (x + a - 1) / a * a;
+    }
+
+    // ============================================================
+    // GM -> UB: standard DataCopy; count is rounded up to block alignment.
+    // The extra elements read into UB are never used, so this is safe.
+    // ============================================================
+    __aicore__ inline void DataCopyIn(LocalTensor<int32_t>& dst,
+                                      GlobalTensor<int32_t>& src,
+                                      int32_t gmOffset, int32_t count)
+    {
+        if (count <= 0) return;
+        constexpr int32_t ELEMS_PER_BLK = ONE_BLK_SIZE / (int32_t)sizeof(int32_t); // 8
+        int32_t aligned = (count + ELEMS_PER_BLK - 1) / ELEMS_PER_BLK * ELEMS_PER_BLK;
+        DataCopy(dst, src[gmOffset], aligned);
+        pipe_barrier(PIPE_ALL);
+    }
+
+    // ============================================================
+    // UB -> GM: DataCopyPad + DataCopyExtParams (C220 supports arbitrary byte counts).
+    // Writes exactly count elements without clobbering adjacent data.
+    // ============================================================
+    __aicore__ inline void DataCopyOut_int32(GlobalTensor<int32_t>& dst,
+                                             LocalTensor<int32_t>& src,
+                                             int32_t gmOffset, int32_t count)
+    {
+        if (count <= 0) return;
+        uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int32_t));
+        pipe_barrier(PIPE_ALL);
+        DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
+        pipe_barrier(PIPE_ALL);
+    }
+
+    __aicore__ inline void DataCopyOut_int8(GlobalTensor<int8_t>& dst,
+                                            LocalTensor<int8_t>& src,
+                                            int32_t gmOffset, int32_t count)
+    {
+        if (count <= 0) return;
+        uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int8_t));
+        pipe_barrier(PIPE_ALL);
+        DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
+        pipe_barrier(PIPE_ALL);
+    }
+
+    // ============================================================
+    // Metadata reads (from the per-field UB buffers)
+    // ============================================================
+    __aicore__ inline int32_t ReadQS(uint32_t rLocal) {
+        return qsBuf.Get<int32_t>().GetValue(rLocal);
+    }
+    __aicore__ inline int32_t ReadNextQS(uint32_t rLocal) {
+        return qsBuf.Get<int32_t>().GetValue(rLocal + 1);
+    }
+    __aicore__ inline int32_t ReadQE(uint32_t rLocal) {
+        return qeBuf.Get<int32_t>().GetValue(rLocal);
+    }
+    __aicore__ inline int32_t ReadNT(uint32_t rLocal) {
+        return ntBuf.Get<int32_t>().GetValue(rLocal);
+    }
+
+    // ============================================================
+    // shift_input_ids = false
+    // ============================================================
+    __aicore__ inline void ProcessOneRequestShiftFalse(uint32_t r, uint32_t rLocal)
+    {
+        int32_t queryStart = ReadQS(rLocal);
+        int32_t nextQueryStart = ReadNextQS(rLocal);
+        int32_t queryEnd = ReadQE(rLocal);
+
+        int32_t numRejected = nextQueryStart - queryEnd - 1;
+        if (numRejected < 0) numRejected = 0;
+        int32_t numValid = queryEnd - queryStart + 1;
+        if (numValid < 0) numValid = 0;
+
+        int32_t outputStart = queryStart + (int32_t)r * (int32_t)numPaddingSlotsPerReq;
+        int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
+
+        // Read this request's input tokens into UB
+        int32_t numInputTokensForReq = nextQueryStart - queryStart;
+        LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
+        if (numInputTokensForReq > 0) {
+            DataCopyIn(localInput, gmTargetTokenIds, queryStart, numInputTokensForReq);
+        }
+
+        // Read the starting position
+        LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
+        DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
+        int32_t startPos = localTmpPos.GetValue(0);
+
+        int32_t nextTokenId = ReadNT(rLocal);
+
+        // Build the outputs in UB
+        LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
+        LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
+        LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
+        LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
+
+        for (int32_t j = 0; j < numValid; j++) {
+            int32_t inIdx = j;
+            if (inIdx >= numInputTokensForReq) inIdx = numInputTokensForReq - 1;
+            lIds.SetValue(j, localInput.GetValue(inIdx));
+            lPos.SetValue(j, startPos + j);
+            lRej.SetValue(j, (int8_t)0);
+            lMsk.SetValue(j, (int8_t)0);
+        }
+        // Bonus
+        lIds.SetValue(numValid, nextTokenId);
+        lPos.SetValue(numValid, startPos + numValid);
+        lRej.SetValue(numValid, (int8_t)0);
+        lMsk.SetValue(numValid, (int8_t)0);
+        // Parallel draft
+        for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
+            int32_t j = numValid + k;
+            lIds.SetValue(j, parallelDraftingTokenId);
+            lPos.SetValue(j, startPos + j);
+            lRej.SetValue(j, (int8_t)0);
+            lMsk.SetValue(j, (int8_t)1);
+        }
+        // Rejected
+        for (int32_t k = 0; k < numRejected; k++) {
+            int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
+            lIds.SetValue(j, paddingTokenId);
+            lPos.SetValue(j, (int32_t)0);
+            lRej.SetValue(j, (int8_t)1);
+            lMsk.SetValue(j, (int8_t)0);
+        }
+
+        // UB -> GM
+        DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
+        DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
+        DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
+        DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
+
+        // NTI
+        LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
+        lNti.SetValue(0, outputStart + numValid);
+        for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
+            lNti.SetValue(k, outputStart + numValid + k);
+        }
+        int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
+        DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
+    }
+
+    // ============================================================
+    // shift_input_ids = true
+    // ============================================================
+    __aicore__ inline void ProcessOneRequestShiftTrue(uint32_t r, uint32_t rLocal)
+    {
+        int32_t queryStart = ReadQS(rLocal);
+        int32_t nextQueryStart = ReadNextQS(rLocal);
+        int32_t queryEnd = ReadQE(rLocal);
+
+        int32_t numRejected = nextQueryStart - queryEnd - 1;
+        if (numRejected < 0) numRejected = 0;
+        int32_t numValid = queryEnd - queryStart;
+        if (numValid < 0) numValid = 0;
+
+        int32_t outputStart = queryStart + (int32_t)r * ((int32_t)numPaddingSlotsPerReq - 1);
+        int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
+
+        int32_t numInputTokensForReq = nextQueryStart - queryStart;
+        LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
+        int32_t readStart = queryStart + 1;
+        int32_t readCount = numValid;
+        if (readStart + readCount > (int32_t)totalInputTokens) {
+            readCount = (int32_t)totalInputTokens - readStart;
+            if (readCount < 0) readCount = 0;
+        }
+        if (readCount > 0) {
+            DataCopyIn(localInput, gmTargetTokenIds, readStart, readCount);
+        }
+
+        LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
+        DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
+        int32_t startPos = localTmpPos.GetValue(0);
+
+        int32_t nextTokenId = ReadNT(rLocal);
+
+        LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
+        LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
+        LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
+        LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
+
+        for (int32_t j = 0; j < numValid; j++) {
+            int32_t inIdx = j;
+            if (inIdx >= readCount && readCount > 0) inIdx = readCount - 1;
+            lIds.SetValue(j, readCount > 0 ? localInput.GetValue(inIdx) : (int32_t)0);
+            lPos.SetValue(j, startPos + j);
+            lRej.SetValue(j, (int8_t)0);
+            lMsk.SetValue(j, (int8_t)0);
+        }
+        lIds.SetValue(numValid, nextTokenId);
+        lPos.SetValue(numValid, startPos + numValid);
+        lRej.SetValue(numValid, (int8_t)0);
+        lMsk.SetValue(numValid, (int8_t)0);
+        for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
+            int32_t j = numValid + k;
+            lIds.SetValue(j, parallelDraftingTokenId);
+            lPos.SetValue(j, startPos + j);
+            lRej.SetValue(j, (int8_t)0);
+            lMsk.SetValue(j, (int8_t)1);
+        }
+        for (int32_t k = 0; k < numRejected; k++) {
+            int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
+            lIds.SetValue(j, paddingTokenId);
+            lPos.SetValue(j, (int32_t)0);
+            lRej.SetValue(j, (int8_t)1);
+            lMsk.SetValue(j, (int8_t)0);
+        }
+
+        DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
+        DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
+        DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
+        DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
+
+        LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
+        lNti.SetValue(0, outputStart + numValid);
+        for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
+            lNti.SetValue(k, outputStart + numValid + k);
+        }
+        int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
+        DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
+
+        // hidden_state_mapping
+        LocalTensor<int32_t> lHsm = hsmBuf.Get<int32_t>();
+        for (int32_t j = 0; j < numInputTokensForReq; j++) {
+            lHsm.SetValue(j, outputStart + j);
+        }
+        DataCopyOut_int32(gmOutHiddenStateMapping, lHsm, queryStart, numInputTokensForReq);
+    }
+
+private:
+    GlobalTensor<int32_t> gmTargetTokenIds, gmTargetPositions, gmNextTokenIds;
+    GlobalTensor<int32_t> gmQueryStartLoc, gmQueryEndLoc;
+    GlobalTensor<int32_t> gmOutInputIds, gmOutPositions;
+    GlobalTensor<int8_t> gmOutIsRejectedTokenMask, gmOutIsMaskedTokenMask;
+    GlobalTensor<int32_t> gmOutNewTokenIndices, gmOutHiddenStateMapping;
+
+    uint32_t usedCoreNum, numReqs, reqsPerCore, remainderReqs;
+    int32_t paddingTokenId, parallelDraftingTokenId;
+    uint32_t numPaddingSlotsPerReq, totalInputTokens, totalDraftTokens;
+    uint32_t myStartReq, myNumReqs;
+
+    TPipe pipe;
+    TBuf<TPosition::VECCALC> qsBuf, qeBuf, ntBuf;
+    TBuf<TPosition::VECCALC> inputBuf, outIdsBuf, outPosBuf;
+    TBuf<TPosition::VECCALC> outRejBuf, outMskBuf, ntiBuf, hsmBuf;
+};
+
+extern "C" __global__ __aicore__ void copy_and_expand_eagle_inputs(
+    GM_ADDR targetTokenIds, GM_ADDR targetPositions,
+    GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
+    GM_ADDR queryEndLoc,
+    GM_ADDR outInputIds, GM_ADDR outPositions,
+    GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
+    GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
+    GM_ADDR workspace, GM_ADDR tiling)
+{
+    GET_TILING_DATA(tilingData, tiling);
+
+    if (GetBlockIdx() >= tilingData.usedCoreNum) {
+        return;
+    }
+
+    if (TILING_KEY_IS(1)) {
+        CopyAndExpandEagleInputsKernel op;
+        op.Init(targetTokenIds, targetPositions, nextTokenIds, queryStartLoc, queryEndLoc,
+                outInputIds, outPositions, outIsRejectedTokenMask, outIsMaskedTokenMask,
+                outNewTokenIndices, outHiddenStateMapping, &tilingData);
+
+        if (tilingData.shiftInputIds == 0) {
+            op.ProcessShiftFalse();
+        } else {
+            op.ProcessShiftTrue();
+        }
+    }
+}
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index af311c9c..bc2bb72c 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -597,6 +597,41 @@ void transpose_kv_cache_by_block(
 }
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+npu_copy_and_expand_eagle_inputs(
+    const at::Tensor &target_token_ids,
+    const at::Tensor &target_positions,
+    const at::Tensor &next_token_ids,
+    const at::Tensor &query_start_loc,
+    const at::Tensor &query_end_loc,
+    int64_t padding_token_id,
+    int64_t parallel_drafting_token_id,
+    int64_t num_padding_slots_per_request,
+    bool shift_input_ids,
+    int64_t total_draft_tokens)
+{
+    int64_t total_input_tokens = target_token_ids.size(0);
+    int64_t num_reqs = query_start_loc.size(0) - 1;
+
+    auto device = target_token_ids.device();
+    at::Tensor out_input_ids = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
+    at::Tensor out_positions = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
+    at::Tensor out_is_rejected_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
+    at::Tensor out_is_masked_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
+    at::Tensor out_new_token_indices = at::empty({num_reqs * num_padding_slots_per_request}, at::dtype(at::kInt).device(device));
+    at::Tensor out_hidden_state_mapping = at::empty({total_input_tokens}, at::dtype(at::kInt).device(device));
+
+    EXEC_NPU_CMD(aclnnCopyAndExpandEagleInputs,
+                 target_token_ids, target_positions, next_token_ids, query_start_loc, query_end_loc,
+                 padding_token_id, parallel_drafting_token_id, num_padding_slots_per_request,
+                 shift_input_ids, total_input_tokens,
+                 out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
+                 out_new_token_indices, out_hidden_state_mapping);
+
+    return {out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
+            out_new_token_indices, out_hidden_state_mapping};
+}
+
 at::Tensor causal_conv1d_fn(
     const at::Tensor& mixed_qkv_non_spec_T,
     const at::Tensor& conv_weights,
@@ -849,6 +884,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
         "transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
     );
     ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
+
+    ops.def(
+        "npu_copy_and_expand_eagle_inputs(Tensor target_token_ids, Tensor target_positions, "
+        "Tensor next_token_ids, Tensor query_start_loc, Tensor query_end_loc, "
+        "int padding_token_id, int parallel_drafting_token_id, int num_padding_slots_per_request, "
+        "bool shift_input_ids, int total_draft_tokens) -> "
+        "(Tensor out_input_ids, Tensor out_positions, Tensor out_is_rejected_token_mask, "
+        "Tensor out_is_masked_token_mask, Tensor out_new_token_indices, Tensor out_hidden_state_mapping)"
+    );
+    ops.impl("npu_copy_and_expand_eagle_inputs", torch::kPrivateUse1, &vllm_ascend::npu_copy_and_expand_eagle_inputs);
     // causal_conv1d_fn
     ops.def(
         "causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index a5ed22ea..f5980a01 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -458,6 +458,33 @@ void transpose_kv_cache_by_block_meta(
     return;
 }
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+npu_copy_and_expand_eagle_inputs_meta(
+    const at::Tensor &target_token_ids,
+    const at::Tensor &target_positions,
+    const at::Tensor &next_token_ids,
+    const at::Tensor &query_start_loc,
+    const at::Tensor &query_end_loc,
+    int64_t padding_token_id,
+    int64_t parallel_drafting_token_id,
+    int64_t num_padding_slots_per_request,
+    bool shift_input_ids,
+    int64_t total_draft_tokens)
+{
+    int64_t total_input_tokens = target_token_ids.size(0);
+    int64_t num_reqs = query_start_loc.size(0) - 1;
+
+    at::Tensor out_input_ids = at::empty({total_draft_tokens}, target_token_ids.options());
+    at::Tensor out_positions = at::empty({total_draft_tokens}, target_token_ids.options());
+    at::Tensor out_is_rejected_token_mask = at::empty({total_draft_tokens}, target_token_ids.options().dtype(at::kChar));
+    at::Tensor out_is_masked_token_mask = at::empty({total_draft_tokens}, target_token_ids.options().dtype(at::kChar));
+    at::Tensor out_new_token_indices = at::empty({num_reqs * num_padding_slots_per_request}, target_token_ids.options());
+    at::Tensor out_hidden_state_mapping = at::empty({total_input_tokens}, target_token_ids.options());
+
+    return {out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
+            out_new_token_indices, out_hidden_state_mapping};
+}
+
 at::Tensor causal_conv1d_fn_meta(
     const at::Tensor& mixed_qkv_non_spec_T,
     const at::Tensor& conv_weights,
@@ -543,6 +570,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
     ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
     // transpose_kv_cache_by_block
     ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
+    // CopyAndExpandEagleInputs
+    ops.impl("npu_copy_and_expand_eagle_inputs", &vllm_ascend::meta::npu_copy_and_expand_eagle_inputs_meta);
     // causal_conv1d_fn
     ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
     // moe_grouped_matmul
diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_copy_and_expand_eagle_inputs.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_copy_and_expand_eagle_inputs.py
new file mode 100644
index 00000000..4e65b876
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_copy_and_expand_eagle_inputs.py
@@ -0,0 +1,471 @@
+"""E2E accuracy test for CopyAndExpandEagleInputs custom operator.
+
+Tests the Ascend C kernel against a CPU golden reference implementation
+with parametrized test cases covering various configurations.
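+
+Per request, the expanded sequence produced by the op is laid out as:
+
+    [num_valid accepted tokens][1 bonus token from next_token_ids]
+    [num_padding_slots - 1 parallel-draft placeholder tokens]
+    [num_rejected padding tokens]
+
+out_new_token_indices points at the bonus and parallel-draft slots,
+out_is_masked_token_mask marks the parallel-draft placeholders,
+out_is_rejected_token_mask marks the rejected padding, and (with
+shift_input_ids=True) out_hidden_state_mapping maps each input token to
+its slot in the expanded sequence. See golden_copy_and_expand below for
+the exact semantics.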
+""" + +import numpy as np +import pytest +import torch + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + +SEED = 42 + + +# --------------------------------------------------------------------------- +# Golden reference (CPU, pure Python/NumPy) +# --------------------------------------------------------------------------- + +def golden_copy_and_expand( + target_token_ids: np.ndarray, + target_positions: np.ndarray, + next_token_ids: np.ndarray, + query_start_loc: np.ndarray, + query_end_loc: np.ndarray, + padding_token_id: int, + parallel_drafting_token_id: int, + num_padding_slots: int, + shift_input_ids: bool, +): + """CPU golden reference for CopyAndExpandEagleInputs. + + Returns: + (out_input_ids, out_positions, out_is_rejected_token_mask, + out_is_masked_token_mask, out_new_token_indices, + out_hidden_state_mapping) + """ + num_reqs = len(next_token_ids) + + # Compute total_draft_tokens + total_draft_tokens = 0 + for r in range(num_reqs): + qs = query_start_loc[r] + nqs = query_start_loc[r + 1] + qe = query_end_loc[r] + num_rejected = max(nqs - qe - 1, 0) + if shift_input_ids: + num_valid = max(qe - qs, 0) + else: + num_valid = max(qe - qs + 1, 0) + total_draft_tokens += num_valid + num_padding_slots + num_rejected + + out_ids = np.zeros(total_draft_tokens, dtype=np.int32) + out_pos = np.zeros(total_draft_tokens, dtype=np.int32) + out_rej = np.zeros(total_draft_tokens, dtype=np.int8) + out_msk = np.zeros(total_draft_tokens, dtype=np.int8) + out_nti = np.zeros(num_reqs * num_padding_slots, dtype=np.int32) + total_input_tokens = len(target_token_ids) + out_hsm = np.zeros(total_input_tokens, dtype=np.int32) + + for r in range(num_reqs): + qs = query_start_loc[r] + nqs = query_start_loc[r + 1] + qe = query_end_loc[r] + + num_rejected = max(nqs - qe - 1, 0) + + if shift_input_ids: + num_valid = max(qe - qs, 0) + output_start = qs + r * (num_padding_slots - 1) + else: + num_valid = max(qe - qs + 1, 0) + output_start = qs + r * num_padding_slots + + start_pos = target_positions[qs] + next_token_id = next_token_ids[r] + + # Valid region + if shift_input_ids: + read_start = qs + 1 + read_count = min(num_valid, total_input_tokens - read_start) + if read_count < 0: + read_count = 0 + for j in range(num_valid): + idx = min(j, read_count - 1) if read_count > 0 else 0 + out_ids[output_start + j] = target_token_ids[read_start + idx] if read_count > 0 else 0 + out_pos[output_start + j] = start_pos + j + out_rej[output_start + j] = 0 + out_msk[output_start + j] = 0 + else: + num_input = nqs - qs + for j in range(num_valid): + idx = min(j, num_input - 1) + out_ids[output_start + j] = target_token_ids[qs + idx] + out_pos[output_start + j] = start_pos + j + out_rej[output_start + j] = 0 + out_msk[output_start + j] = 0 + + # Bonus token + out_ids[output_start + num_valid] = next_token_id + out_pos[output_start + num_valid] = start_pos + num_valid + out_rej[output_start + num_valid] = 0 + out_msk[output_start + num_valid] = 0 + + # Parallel draft tokens + for k in range(1, num_padding_slots): + j = num_valid + k + out_ids[output_start + j] = parallel_drafting_token_id + out_pos[output_start + j] = start_pos + j + out_rej[output_start + j] = 0 + out_msk[output_start + j] = 1 + + # Rejected tokens + for k in range(num_rejected): + j = num_valid + num_padding_slots + k + out_ids[output_start + j] = padding_token_id + out_pos[output_start + j] = 0 + out_rej[output_start + j] = 1 + out_msk[output_start + j] = 0 + + # New token indices + for k in range(num_padding_slots): + out_nti[r * 
num_padding_slots + k] = output_start + num_valid + k + + # Hidden state mapping (shift_input_ids=true only) + if shift_input_ids: + num_input = nqs - qs + for j in range(num_input): + out_hsm[qs + j] = output_start + j + + return out_ids, out_pos, out_rej, out_msk, out_nti, out_hsm + + +# --------------------------------------------------------------------------- +# NPU operator wrapper +# --------------------------------------------------------------------------- + +def npu_op_exec( + target_token_ids, target_positions, next_token_ids, + query_start_loc, query_end_loc, + padding_token_id, parallel_drafting_token_id, + num_padding_slots, shift_input_ids, total_draft_tokens, +): + """Execute the custom Ascend NPU operator.""" + result = torch.ops._C_ascend.npu_copy_and_expand_eagle_inputs( + target_token_ids.to(torch.int32).npu(), + target_positions.to(torch.int32).npu(), + next_token_ids.to(torch.int32).npu(), + query_start_loc.to(torch.int32).npu(), + query_end_loc.to(torch.int32).npu(), + padding_token_id, + parallel_drafting_token_id, + num_padding_slots, + shift_input_ids, + total_draft_tokens, + ) + return tuple(t.cpu() for t in result) + + +# --------------------------------------------------------------------------- +# Test case generator +# --------------------------------------------------------------------------- + +def generate_test_case(rng, num_reqs, num_padding_slots, shift_input_ids, + min_tokens_per_req=2, max_tokens_per_req=64, + max_rejected_per_req=5): + """Generate a random test case. + + Returns dict with all input arrays and expected parameters. + """ + padding_token_id = 0 + parallel_drafting_token_id = 100 + + # Generate per-request token counts + tokens_per_req = rng.integers(min_tokens_per_req, max_tokens_per_req + 1, + size=num_reqs) + rejected_per_req = rng.integers(0, max_rejected_per_req + 1, size=num_reqs) + + # Build query_start_loc (cumulative) + query_start_loc = np.zeros(num_reqs + 1, dtype=np.int32) + for i in range(num_reqs): + query_start_loc[i + 1] = query_start_loc[i] + tokens_per_req[i] + rejected_per_req[i] + + total_input_tokens = int(query_start_loc[num_reqs]) + + # Build query_end_loc: queryEnd = queryStart + numAccepted - 1 + # where numAccepted = tokens_per_req[i] + # For shift=false: numValid = queryEnd - queryStart + 1 = tokens_per_req[i] + # For shift=true: numValid = queryEnd - queryStart = tokens_per_req[i] - 1 + query_end_loc = np.zeros(num_reqs, dtype=np.int32) + for i in range(num_reqs): + if shift_input_ids: + query_end_loc[i] = query_start_loc[i] + tokens_per_req[i] + else: + query_end_loc[i] = query_start_loc[i] + tokens_per_req[i] - 1 + + # Generate input tokens and positions + target_token_ids = rng.integers(1, 50000, size=total_input_tokens, dtype=np.int32) + target_positions = np.zeros(total_input_tokens, dtype=np.int32) + for i in range(num_reqs): + qs = query_start_loc[i] + nqs = query_start_loc[i + 1] + for j in range(nqs - qs): + target_positions[qs + j] = j + + next_token_ids = rng.integers(1, 50000, size=num_reqs, dtype=np.int32) + + # Compute total_draft_tokens + total_draft_tokens = 0 + for r in range(num_reqs): + qs = query_start_loc[r] + nqs = query_start_loc[r + 1] + qe = query_end_loc[r] + num_rejected = max(nqs - qe - 1, 0) + if shift_input_ids: + num_valid = max(qe - qs, 0) + else: + num_valid = max(qe - qs + 1, 0) + total_draft_tokens += num_valid + num_padding_slots + num_rejected + + return { + "target_token_ids": target_token_ids, + "target_positions": target_positions, + "next_token_ids": next_token_ids, + 
"query_start_loc": query_start_loc, + "query_end_loc": query_end_loc, + "padding_token_id": padding_token_id, + "parallel_drafting_token_id": parallel_drafting_token_id, + "num_padding_slots": num_padding_slots, + "shift_input_ids": shift_input_ids, + "total_draft_tokens": total_draft_tokens, + } + + +# --------------------------------------------------------------------------- +# Parametrized tests +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("num_reqs", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("num_padding_slots", [1, 2, 3, 5]) +@pytest.mark.parametrize("shift_input_ids", [False, True]) +@pytest.mark.parametrize("seed_offset", [0, 1]) +def test_copy_and_expand_eagle_inputs(num_reqs, num_padding_slots, + shift_input_ids, seed_offset): + """Test CopyAndExpandEagleInputs with parametrized configurations.""" + rng = np.random.default_rng(SEED + seed_offset) + + case = generate_test_case(rng, num_reqs, num_padding_slots, + shift_input_ids) + + # Golden reference + g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand( + case["target_token_ids"], + case["target_positions"], + case["next_token_ids"], + case["query_start_loc"], + case["query_end_loc"], + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + ) + + # NPU execution + n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec( + torch.from_numpy(case["target_token_ids"]), + torch.from_numpy(case["target_positions"]), + torch.from_numpy(case["next_token_ids"]), + torch.from_numpy(case["query_start_loc"]), + torch.from_numpy(case["query_end_loc"]), + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + case["total_draft_tokens"], + ) + + # Convert golden to tensors + g_ids_t = torch.from_numpy(g_ids) + g_pos_t = torch.from_numpy(g_pos) + g_rej_t = torch.from_numpy(g_rej) + g_msk_t = torch.from_numpy(g_msk) + g_nti_t = torch.from_numpy(g_nti) + g_hsm_t = torch.from_numpy(g_hsm) + + # Compare outputs + torch.testing.assert_close(n_ids, g_ids_t, atol=0, rtol=0, + msg="out_input_ids mismatch") + torch.testing.assert_close(n_pos, g_pos_t, atol=0, rtol=0, + msg="out_positions mismatch") + torch.testing.assert_close(n_rej, g_rej_t, atol=0, rtol=0, + msg="out_is_rejected_token_mask mismatch") + torch.testing.assert_close(n_msk, g_msk_t, atol=0, rtol=0, + msg="out_is_masked_token_mask mismatch") + torch.testing.assert_close(n_nti, g_nti_t, atol=0, rtol=0, + msg="out_new_token_indices mismatch") + + if shift_input_ids: + torch.testing.assert_close(n_hsm, g_hsm_t, atol=0, rtol=0, + msg="out_hidden_state_mapping mismatch") + + +@pytest.mark.parametrize("num_reqs", [1]) +@pytest.mark.parametrize("num_padding_slots", [1]) +@pytest.mark.parametrize("shift_input_ids", [False, True]) +def test_minimal_case(num_reqs, num_padding_slots, shift_input_ids): + """Test with minimal input (1 request, 1 padding slot).""" + rng = np.random.default_rng(SEED + 100) + case = generate_test_case(rng, num_reqs, num_padding_slots, + shift_input_ids, min_tokens_per_req=2, + max_tokens_per_req=3, max_rejected_per_req=1) + + g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand( + case["target_token_ids"], + case["target_positions"], + case["next_token_ids"], + case["query_start_loc"], + case["query_end_loc"], + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + ) + + n_ids, n_pos, 
n_rej, n_msk, n_nti, n_hsm = npu_op_exec( + torch.from_numpy(case["target_token_ids"]), + torch.from_numpy(case["target_positions"]), + torch.from_numpy(case["next_token_ids"]), + torch.from_numpy(case["query_start_loc"]), + torch.from_numpy(case["query_end_loc"]), + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + case["total_draft_tokens"], + ) + + torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0) + torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0) + torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0) + torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0) + torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0) + + +@pytest.mark.parametrize("num_reqs", [3, 7, 13]) +def test_large_tokens_per_request(num_reqs): + """Test with larger token counts per request.""" + rng = np.random.default_rng(SEED + 200) + case = generate_test_case(rng, num_reqs, num_padding_slots=3, + shift_input_ids=False, + min_tokens_per_req=100, + max_tokens_per_req=512, + max_rejected_per_req=10) + + g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand( + case["target_token_ids"], + case["target_positions"], + case["next_token_ids"], + case["query_start_loc"], + case["query_end_loc"], + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + ) + + n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec( + torch.from_numpy(case["target_token_ids"]), + torch.from_numpy(case["target_positions"]), + torch.from_numpy(case["next_token_ids"]), + torch.from_numpy(case["query_start_loc"]), + torch.from_numpy(case["query_end_loc"]), + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + case["total_draft_tokens"], + ) + + torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0) + torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0) + torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0) + torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0) + torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0) + + +@pytest.mark.parametrize("num_reqs", [3, 7, 13]) +def test_large_tokens_shift_true(num_reqs): + """Test with larger token counts and shift_input_ids=True.""" + rng = np.random.default_rng(SEED + 300) + case = generate_test_case(rng, num_reqs, num_padding_slots=4, + shift_input_ids=True, + min_tokens_per_req=50, + max_tokens_per_req=256, + max_rejected_per_req=8) + + g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand( + case["target_token_ids"], + case["target_positions"], + case["next_token_ids"], + case["query_start_loc"], + case["query_end_loc"], + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + ) + + n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec( + torch.from_numpy(case["target_token_ids"]), + torch.from_numpy(case["target_positions"]), + torch.from_numpy(case["next_token_ids"]), + torch.from_numpy(case["query_start_loc"]), + torch.from_numpy(case["query_end_loc"]), + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + case["total_draft_tokens"], + ) + + torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0) + 
torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0) + torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0) + torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0) + torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0) + torch.testing.assert_close(n_hsm, torch.from_numpy(g_hsm), atol=0, rtol=0) + + +@pytest.mark.parametrize("num_reqs", [1, 4, 8]) +def test_no_rejected_tokens(num_reqs): + """Test cases with zero rejected tokens.""" + rng = np.random.default_rng(SEED + 400) + case = generate_test_case(rng, num_reqs, num_padding_slots=2, + shift_input_ids=False, + min_tokens_per_req=5, + max_tokens_per_req=20, + max_rejected_per_req=0) + + g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand( + case["target_token_ids"], + case["target_positions"], + case["next_token_ids"], + case["query_start_loc"], + case["query_end_loc"], + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + ) + + n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec( + torch.from_numpy(case["target_token_ids"]), + torch.from_numpy(case["target_positions"]), + torch.from_numpy(case["next_token_ids"]), + torch.from_numpy(case["query_start_loc"]), + torch.from_numpy(case["query_end_loc"]), + case["padding_token_id"], + case["parallel_drafting_token_id"], + case["num_padding_slots"], + case["shift_input_ids"], + case["total_draft_tokens"], + ) + + torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0) + torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0) + torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0) + torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0) + torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0) diff --git a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py index bc988c2b..a24808cb 100644 --- a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py @@ -4,7 +4,7 @@ from __future__ import annotations import math import os import random -from typing import Any, Union +from typing import Any import pytest from transformers import AutoTokenizer @@ -17,23 +17,32 @@ from tests.e2e.conftest import VllmRunner os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" MODELS = { - #"eagle": { + # "eagle": { # "main": "LLM-Research/Meta-Llama-3.1-8B-Instruct", # "spec": "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B", - #}, + # }, "eagle3": { "main": "Qwen/Qwen3-8B", "spec": "RedHatAI/Qwen3-8B-speculator.eagle3", }, } +DRAFT_PARALLEL_MODELS = { + "draft_parallel": { + "main": "LLM-Research/Meta-Llama-3.1-8B-Instruct", + "spec": "amd/PARD-Llama-3.2-1B", + }, +} + # NOTE: golden may change (eagle_proposer only runs in eager mode currently), # thus please update it if ci fails but you have better acceptance BASELINES = { "eagle": [0.74, 0.44, 0.29], "eagle3": [0.68, 0.40, 0.18], + "draft_parallel": [0.83, 0.50, 0.33, 0.17, 0.17, 0.17, 0.17, 0.00], } + @pytest.fixture def test_prompts(): prompt_types = ["repeat", "sentence"] @@ -89,6 +98,7 @@ def eagle3_model_name(): def vl_model_name(): return "Qwen/Qwen3-VL-8B-Instruct" + def vl_eagle3_model_name(): return "MNN/Qwen3-VL-8B-Instruct-Eagle3" @@ -98,28 +108,28 @@ def test_ngram_correctness( sampling_config: SamplingParams, model_name: str, ): - ''' + """ Compare the outputs of a 
    reference LLM and a speculative LLM;
    they should be the same when using ngram speculative decoding.
-    '''
+    """
     with VllmRunner(
-            model_name,
-            max_model_len=1024,
-            cudagraph_capture_sizes=[1, 2, 4, 8],
+        model_name,
+        max_model_len=1024,
+        cudagraph_capture_sizes=[1, 2, 4, 8],
     ) as ref_llm:
         ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)

     with VllmRunner(
-            model_name,
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": 3,
-            },
-            max_model_len=1024,
-            cudagraph_capture_sizes=[1, 2, 4, 8],
+        model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+        cudagraph_capture_sizes=[1, 2, 4, 8],
     ) as runner:
         spec_outputs = runner.model.chat(test_prompts, sampling_config)
     matches = 0
@@ -142,27 +152,27 @@ def test_qwen3_vl_eagle_correctness(
     sampling_config: SamplingParams,
     vl_model_name: str,
 ):
-    '''
+    """
     Compare the outputs of a reference LLM and a speculative LLM;
     they should be the same when using eagle speculative decoding.
-    '''
+    """
     with VllmRunner(
-            vl_model_name,
-            max_model_len=1024,
-            cudagraph_capture_sizes=[1, 2, 4, 8],
+        vl_model_name,
+        max_model_len=1024,
+        cudagraph_capture_sizes=[1, 2, 4, 8],
     ) as ref_llm:
         ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)

     spec_model_name = vl_eagle3_model_name()
     with VllmRunner(
-            vl_model_name,
-            speculative_config={
-                "method": "eagle3",
-                "model": spec_model_name,
-                "num_speculative_tokens": 2,
-            },
-            max_model_len=1024,
-            cudagraph_capture_sizes=[1, 2, 4, 8],
+        vl_model_name,
+        speculative_config={
+            "method": "eagle3",
+            "model": spec_model_name,
+            "num_speculative_tokens": 2,
+        },
+        max_model_len=1024,
+        cudagraph_capture_sizes=[1, 2, 4, 8],
     ) as runner:
         spec_outputs = runner.model.chat(test_prompts, sampling_config)
     matches = 0
@@ -179,27 +189,28 @@ def test_qwen3_vl_eagle_correctness(
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))

+
 def test_suffix_correctness(
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
 ):
-    '''
+    """
     Compare the outputs of a reference LLM and a speculative LLM;
     they should be the same when using suffix speculative decoding.
-    '''
-    with VllmRunner(model_name,
-                    max_model_len=1024,
-                    cudagraph_capture_sizes=[1, 2, 4, 8]) as ref_llm:
+    """
+    with VllmRunner(model_name, max_model_len=1024, cudagraph_capture_sizes=[1, 2, 4, 8]) as ref_llm:
         ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
-    with VllmRunner(model_name,
-                    speculative_config={
-                        "method": "suffix",
-                        "num_speculative_tokens": 8,
-                    },
-                    cudagraph_capture_sizes=[1, 2, 4, 8],
-                    max_model_len=1024) as runner:
+    with VllmRunner(
+        model_name,
+        speculative_config={
+            "method": "suffix",
+            "num_speculative_tokens": 8,
+        },
+        cudagraph_capture_sizes=[1, 2, 4, 8],
+        max_model_len=1024,
+    ) as runner:
         spec_outputs = runner.model.chat(test_prompts, sampling_config)
     matches = 0
     misses = 0
@@ -221,22 +232,24 @@ def test_suffix_acceptance(
     sampling_config: SamplingParams,
     model_name: str,
 ):
-    '''
+    """
     Check that suffix decoding caching takes effect and improves
     acceptance lengths and acceptance rates over multiple runs of the
     same prompts.
- ''' + """ num_draft = [] num_accept = [] - with VllmRunner(model_name, - speculative_config={ - "method": "suffix", - "suffix_decoding_max_spec_factor": 2.0, - "suffix_decoding_max_cached_requests": 1000, - "num_speculative_tokens": 10, - }, - max_model_len=1024, - cudagraph_capture_sizes=[1, 2, 4, 8], - disable_log_stats=False) as runner: + with VllmRunner( + model_name, + speculative_config={ + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + "suffix_decoding_max_cached_requests": 1000, + "num_speculative_tokens": 10, + }, + max_model_len=1024, + cudagraph_capture_sizes=[1, 2, 4, 8], + disable_log_stats=False, + ) as runner: for i in range(10): runner.model.chat(test_prompts[i], sampling_config) metrics = runner.model.get_metrics() @@ -271,13 +284,10 @@ def test_suffix_acceptance( def test_eagle_logprobs( model_name: str, use_eagle3: bool, - draft_tensor_parallel_size: Union[None, int], + draft_tensor_parallel_size: None | int, ): prompt = {"role": "user", "content": "Hello world " * 10} - sampling_params = SamplingParams(temperature=0, - logprobs=1, - max_tokens=10, - ignore_eos=False) + sampling_params = SamplingParams(temperature=0, logprobs=1, max_tokens=10, ignore_eos=False) ref_llm = LLM(model=model_name, max_model_len=2048) ref_outputs = ref_llm.chat([prompt], sampling_params) @@ -290,19 +300,19 @@ def test_eagle_logprobs( spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name() with VllmRunner( - model_name, - max_num_seqs=1, - max_num_batched_tokens=2048, - gpu_memory_utilization=0.6, - speculative_config={ - "method": "eagle3" if use_eagle3 else "eagle", - "model": spec_model_name, - "num_speculative_tokens": 2, - "draft_tensor_parallel_size": draft_tensor_parallel_size, - "max_model_len": 128, - }, - max_model_len=128, - cudagraph_capture_sizes=[1, 2, 4, 8], + model_name, + max_num_seqs=1, + max_num_batched_tokens=2048, + gpu_memory_utilization=0.6, + speculative_config={ + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, + "num_speculative_tokens": 2, + "draft_tensor_parallel_size": draft_tensor_parallel_size, + "max_model_len": 128, + }, + max_model_len=128, + cudagraph_capture_sizes=[1, 2, 4, 8], ) as runner: spec_outputs = runner.model.chat([prompt], sampling_params) @@ -314,10 +324,7 @@ def test_eagle_logprobs( spec_logprobs.append(logprobs[token_id]) for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): - assert math.isclose(ref_logprob.logprob, - spec_logprob.logprob, - rel_tol=5e-2, - abs_tol=1e-1) + assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1) assert ref_logprob.rank == spec_logprob.rank assert ref_logprob.decoded_token == spec_logprob.decoded_token @@ -330,7 +337,7 @@ def test_eagle_logprobs( def test_llama_qwen_eagle_acceptance( method: str, num_speculative_tokens: int, - draft_tensor_parallel_size: Union[None, int], + draft_tensor_parallel_size: None | int, disable_padded_drafter_batch: bool, async_scheduling: bool, ): @@ -375,7 +382,8 @@ def test_llama_qwen_eagle_acceptance( [prompt], tokenize=False, add_generation_prompt=True, - ) for prompt in prompts + ) + for prompt in prompts ] speculative_config = { @@ -389,16 +397,16 @@ def test_llama_qwen_eagle_acceptance( compilation_config = CompilationConfig(cudagraph_capture_sizes=[12]) with VllmRunner( - main_model_name, - max_model_len=2048, - disable_log_stats=False, - tensor_parallel_size=1, - max_num_seqs=256, - distributed_executor_backend="mp", - gpu_memory_utilization=0.7, - 
speculative_config=speculative_config, - compilation_config=compilation_config, - async_scheduling=async_scheduling, + main_model_name, + max_model_len=2048, + disable_log_stats=False, + tensor_parallel_size=1, + max_num_seqs=256, + distributed_executor_backend="mp", + gpu_memory_utilization=0.7, + speculative_config=speculative_config, + compilation_config=compilation_config, + async_scheduling=async_scheduling, ) as llm: outputs = llm.model.generate(prompts, sampling_params) metrics = llm.model.get_metrics() @@ -419,10 +427,7 @@ def test_llama_qwen_eagle_acceptance( for pos in range(len(metric.values)): num_accepted_tokens_per_pos[pos] += metric.values[pos] - acceptance_per_pos = [ - num_accepted_tokens / num_drafts - for num_accepted_tokens in num_accepted_tokens_per_pos - ] + acceptance_per_pos = [num_accepted_tokens / num_drafts for num_accepted_tokens in num_accepted_tokens_per_pos] if method == "eagle": golden = [0.7313432835820896, 0.373134328358209, 0.19402985074626866] else: @@ -434,3 +439,98 @@ def test_llama_qwen_eagle_acceptance( print(f"golden: {golden}") assert match + + +@pytest.mark.parametrize("method", DRAFT_PARALLEL_MODELS.keys()) +@pytest.mark.parametrize("num_speculative_tokens", [8]) +@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1]) +def test_parallel_drafting_acceptance( + method: str, + num_speculative_tokens: int, + draft_tensor_parallel_size: None | int, +): + """ + Test acceptance rate for parallel drafting speculative decoding + using a smaller draft model with parallel_drafting enabled. + """ + main_model_name = DRAFT_PARALLEL_MODELS[method]["main"] + spec_model_name = DRAFT_PARALLEL_MODELS[method]["spec"] + + tokenizer = AutoTokenizer.from_pretrained( + main_model_name, + trust_remote_code=True, + ) + sampling_params = SamplingParams( + temperature=0, + ignore_eos=False, + max_tokens=256, + ) + + prompts = [ + { + "role": "user", + "content": "Hello, your name is", + }, + ] + prompts = [ + tokenizer.apply_chat_template( + [prompt], + tokenize=False, + add_generation_prompt=True, + ) + for prompt in prompts + ] + + speculative_config = { + "method": "draft_model", + "model": spec_model_name, + "num_speculative_tokens": num_speculative_tokens, + "draft_tensor_parallel_size": draft_tensor_parallel_size, + "parallel_drafting": True, + } + + compilation_config = CompilationConfig(cudagraph_capture_sizes=[12]) + + with VllmRunner( + main_model_name, + max_model_len=4096, + disable_log_stats=False, + tensor_parallel_size=1, + max_num_seqs=256, + distributed_executor_backend="mp", + gpu_memory_utilization=0.8, + speculative_config=speculative_config, + compilation_config=compilation_config, + enable_prefix_caching=False, + ) as llm: + outputs = llm.model.generate(prompts, sampling_params) + metrics = llm.model.get_metrics() + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + output_tokens = output.outputs[0].token_ids + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Output tokens: {output_tokens}") + + num_drafts = 0 + num_accepted_tokens_per_pos = [0] * num_speculative_tokens + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + num_accepted_tokens_per_pos[pos] += metric.values[pos] + + acceptance_per_pos = [num_accepted_tokens / num_drafts for 
num_accepted_tokens in num_accepted_tokens_per_pos] + + golden = BASELINES[method] + + match = all(abs(a - b) < 0.1 for a, b in zip(acceptance_per_pos, golden)) + if not match: + print(f"acceptance_per_pos: {acceptance_per_pos}") + print(f"golden: {golden}") + + assert match diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 66dfc8a1..e7e1ea68 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -10,14 +10,15 @@ from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer class TestEagleProposerInitialization(TestBase): - def setUp(self): self.vllm_config = MagicMock(spec=VllmConfig) self.vllm_config.speculative_config = MagicMock() self.vllm_config.cache_config = MagicMock(spec=CacheConfig) self.vllm_config.scheduler_config = MagicMock() self.vllm_config.model_config = MagicMock() - self.vllm_config.model_config.hf_text_config = MagicMock(spec=[]) # Empty spec to prevent hasattr from returning True + self.vllm_config.model_config.hf_text_config = MagicMock( + spec=[] + ) # Empty spec to prevent hasattr from returning True self.vllm_config.model_config.hf_text_config.to_dict = MagicMock(return_value={}) self.vllm_config.compilation_config = MagicMock() self.device = torch.device("cpu") @@ -40,20 +41,16 @@ class TestEagleProposerInitialization(TestBase): self.vllm_config.parallel_config.enable_expert_parallel = False self.vllm_config.speculative_config.draft_tensor_parallel_size = 1 self.vllm_config.speculative_config.num_speculative_tokens = 2 - self.vllm_config.speculative_config.speculative_token_tree = str([ - (i + 1) * (0, ) for i in range(2) - ]) + self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(2)]) self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0 self.vllm_config.speculative_config.draft_model_config.uses_mrope = False self.vllm_config.speculative_config.disable_padded_drafter_batch = False self.vllm_config.additional_config = None - self.mock_cpugpubuffer = patch( - "vllm.v1.spec_decode.eagle.CpuGpuBuffer") + self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") self.mock_cpugpubuffer.start() self.mock_supports_multimodal_inputs = patch( - "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", - return_value=False + "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False ) self.mock_supports_multimodal_inputs.start() @@ -78,18 +75,16 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = AscendEagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) self.assertEqual(proposer.hidden_size, 4096) self.assertTrue(proposer.use_cuda_graph) expected_max_num_tokens = proposer.max_num_tokens - self.assertEqual(proposer.input_ids.shape, (expected_max_num_tokens, )) - self.assertEqual(proposer.positions.shape, (expected_max_num_tokens, )) + self.assertEqual(proposer.input_ids.shape, (expected_max_num_tokens,)) + self.assertEqual(proposer.positions.shape, (expected_max_num_tokens,)) self.assertEqual(proposer.hidden_states.shape, (expected_max_num_tokens, 4096)) - self.assertEqual(proposer.arange.shape, (expected_max_num_tokens, )) + self.assertEqual(proposer.arange.shape, (expected_max_num_tokens,)) 
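+    # For reference: the speculative_token_tree string built in setUp encodes
+    # a linear draft chain rather than a branching tree. Concretely,
+    #     str([(i + 1) * (0,) for i in range(2)]) == "[(0,), (0, 0)]"
+    # i.e. one draft token at depth 1, extended by a second token at depth 2.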
def test_initialization_eagle3_enforce_eager(self): self.vllm_config.speculative_config.method = "eagle3" @@ -101,9 +96,7 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = AscendEagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) self.assertEqual(proposer.hidden_size, 2048) self.assertFalse(proposer.use_cuda_graph) @@ -120,9 +113,7 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = AscendEagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) self.assertEqual(proposer.hidden_size, 2048) self.assertTrue(proposer.use_cuda_graph) @@ -139,9 +130,7 @@ class TestEagleProposerInitialization(TestBase): init_ascend_config(self.vllm_config) with set_current_vllm_config(self.vllm_config): - proposer = AscendEagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) + proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) self.assertEqual(proposer.hidden_size, 2048) self.assertFalse(proposer.use_cuda_graph) @@ -150,7 +139,6 @@ class TestEagleProposerInitialization(TestBase): class TestEagleProposerLoadModel(TestBase): - def setUp(self): self.vllm_config = MagicMock(spec=VllmConfig) self.vllm_config.speculative_config = MagicMock() @@ -175,29 +163,24 @@ class TestEagleProposerLoadModel(TestBase): self.vllm_config.parallel_config.enable_expert_parallel = False self.vllm_config.speculative_config.draft_tensor_parallel_size = 1 self.vllm_config.speculative_config.num_speculative_tokens = 2 - self.vllm_config.speculative_config.speculative_token_tree = str([ - (i + 1) * (0, ) for i in range(2) - ]) + self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(2)]) self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0 self.vllm_config.speculative_config.draft_model_config.uses_mrope = False self.vllm_config.speculative_config.disable_padded_drafter_batch = False self.vllm_config.additional_config = None init_ascend_config(self.vllm_config) - self.mock_cpugpubuffer = patch( - "vllm.v1.spec_decode.eagle.CpuGpuBuffer") + self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") self.mock_cpugpubuffer.start() self.mock_supports_multimodal_inputs = patch( - "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", - return_value=False + "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False ) self.mock_supports_multimodal_inputs.start() # Set the current vllm config set_current_vllm_config(self.vllm_config) - self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) + self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) + self.proposer.parallel_drafting = False def tearDown(self): self.mock_cpugpubuffer.stop() @@ -205,24 +188,21 @@ class TestEagleProposerLoadModel(TestBase): # Clear the current vllm config set_current_vllm_config(None) - @patch( - "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") + 
@patch("vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") @patch("vllm_ascend.spec_decode.eagle_proposer.get_model") @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group") - def test_load_model_pp1(self, mock_pp_group, mock_get_model, - mock_get_layers): + def test_load_model_pp1(self, mock_pp_group, mock_get_model, mock_get_layers): mock_pp_group.return_value.world_size = 1 mock_target_layer1 = MagicMock() mock_target_layer2 = MagicMock() mock_draft_layer1 = MagicMock() mock_draft_layer3 = MagicMock() - mock_get_layers.side_effect = [{ - "layer1": mock_target_layer1, - "layer2": mock_target_layer2 - }, {}, {}, { - "layer1": mock_draft_layer1, - "layer3": mock_draft_layer3 - }] + mock_get_layers.side_effect = [ + {"layer1": mock_target_layer1, "layer2": mock_target_layer2}, + {}, + {}, + {"layer1": mock_draft_layer1, "layer3": mock_draft_layer3}, + ] weight = torch.zeros(0) @@ -241,61 +221,45 @@ class TestEagleProposerLoadModel(TestBase): self.proposer.load_model(mock_model) mock_get_model.assert_called_once() self.assertEqual(self.proposer.attn_layer_names, ["layer3"]) - self.assertIs(self.proposer.model.model.embed_tokens, - mock_model.model.embed_tokens) + self.assertIs(self.proposer.model.model.embed_tokens, mock_model.model.embed_tokens) - @patch( - "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") @patch("vllm_ascend.spec_decode.eagle_proposer.get_model") @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group") - def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model, - mock_get_layers): + def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model, mock_get_layers): mock_pp_group.return_value.world_size = 2 mock_target_layer1 = MagicMock() mock_draft_layer2 = MagicMock() - mock_get_layers.side_effect = [{ - "layer1": mock_target_layer1 - }, {}, {}, { - "layer2": mock_draft_layer2 - }] + mock_get_layers.side_effect = [{"layer1": mock_target_layer1}, {}, {}, {"layer2": mock_draft_layer2}] mock_model = MagicMock() original_embed = MagicMock() mock_model.multimodal_cpu_fields = None mock_model.merge_by_field_config = None - mock_get_model.return_value = MagicMock(model=MagicMock( - embed_tokens=original_embed)) + mock_get_model.return_value = MagicMock(model=MagicMock(embed_tokens=original_embed)) with set_current_vllm_config(self.vllm_config): self.proposer.load_model(mock_model) - self.assertIsNot(self.proposer.model.model.embed_tokens, - mock_model.model.embed_tokens) + self.assertIsNot(self.proposer.model.model.embed_tokens, mock_model.model.embed_tokens) self.assertEqual(self.proposer.attn_layer_names, ["layer2"]) - @patch( - "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") @patch("vllm_ascend.spec_decode.eagle_proposer.get_model") @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group") @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal") - def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group, - mock_get_model, mock_get_layers): + def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group, mock_get_model, mock_get_layers): mock_model = MagicMock() mock_model.get_language_model.return_value.lm_head = MagicMock() mock_supports_multi.return_value = True original_embed = MagicMock() - mock_get_model.return_value = MagicMock(model=MagicMock( - embed_tokens=original_embed)) + 
mock_get_model.return_value = MagicMock(model=MagicMock(embed_tokens=original_embed)) mock_target_layer1 = MagicMock() mock_draft_layer2 = MagicMock() - mock_get_layers.side_effect = [{ - "layer1": mock_target_layer1 - }, {}, {}, { - "layer2": mock_draft_layer2 - }] + mock_get_layers.side_effect = [{"layer1": mock_target_layer1}, {}, {}, {"layer2": mock_draft_layer2}] mock_pp_group.return_value.world_size = 2 self.proposer.model = MagicMock() @@ -303,12 +267,10 @@ class TestEagleProposerLoadModel(TestBase): with set_current_vllm_config(self.vllm_config): self.proposer.load_model(mock_model) self.assertEqual(mock_model.get_language_model.call_count, 2) - self.assertIs(self.proposer.model.lm_head, - mock_model.get_language_model.return_value.lm_head) + self.assertIs(self.proposer.model.lm_head, mock_model.get_language_model.return_value.lm_head) class TestEagleProposerDummyRun(TestBase): - def setUp(self): self.vllm_config = MagicMock(spec=VllmConfig) self.vllm_config.speculative_config = MagicMock() @@ -328,51 +290,43 @@ class TestEagleProposerDummyRun(TestBase): self.vllm_config.model_config.uses_mrope = False self.vllm_config.model_config.uses_xdrope_dim = 0 self.vllm_config.model_config.use_mla = False - self.vllm_config.model_config.hf_text_config = MagicMock(spec=[]) # Empty spec to prevent hasattr from returning True + self.vllm_config.model_config.hf_text_config = MagicMock( + spec=[] + ) # Empty spec to prevent hasattr from returning True self.vllm_config.model_config.hf_text_config.to_dict = MagicMock(return_value={}) self.vllm_config.parallel_config.tensor_parallel_size = 1 self.vllm_config.parallel_config.data_parallel_rank = 0 self.vllm_config.parallel_config.data_parallel_size = 1 self.vllm_config.parallel_config.prefill_context_parallel_size = 1 self.vllm_config.speculative_config.draft_tensor_parallel_size = 1 - self.vllm_config.speculative_config.speculative_token_tree = str([ - (i + 1) * (0, ) for i in range(4) - ]) + self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(4)]) self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0 self.vllm_config.speculative_config.draft_model_config.uses_mrope = False self.vllm_config.speculative_config.disable_padded_drafter_batch = False self.vllm_config.additional_config = None init_ascend_config(self.vllm_config) - self.mock_cpugpubuffer = patch( - "vllm.v1.spec_decode.eagle.CpuGpuBuffer") + self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") self.mock_cpugpubuffer.start() self.mock_supports_multimodal_inputs = patch( - "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", - return_value=False + "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False ) self.mock_supports_multimodal_inputs.start() # Mock parallel state functions self.mock_tp_world_size = patch( - "vllm_ascend.ascend_forward_context.get_tensor_model_parallel_world_size", - return_value=1 + "vllm_ascend.ascend_forward_context.get_tensor_model_parallel_world_size", return_value=1 ) self.mock_tp_world_size.start() mock_dp_group = MagicMock() mock_dp_group.world_size = 1 - self.mock_dp_group = patch( - "vllm_ascend.ascend_forward_context.get_dp_group", - return_value=mock_dp_group - ) + self.mock_dp_group = patch("vllm_ascend.ascend_forward_context.get_dp_group", return_value=mock_dp_group) self.mock_dp_group.start() # Set the current vllm config set_current_vllm_config(self.vllm_config) - self.proposer = 
AscendEagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) + self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) self.proposer.model = MagicMock() self.proposer._runnable = MagicMock() self.proposer.update_stream = MagicMock() @@ -397,8 +351,7 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` with set_current_vllm_config(self.vllm_config): self.proposer.enable_shared_expert_dp = False - self.proposer.dummy_run(num_tokens=num_tokens, - with_prefill=with_prefill) + self.proposer.dummy_run(num_tokens=num_tokens, with_prefill=with_prefill) self.assertTrue(self.proposer._runnable.call_count == 1) @@ -433,9 +386,7 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` with set_current_vllm_config(self.vllm_config): self.proposer.enable_shared_expert_dp = False - self.proposer.dummy_run(num_tokens=64, - in_graph_capturing=True, - aclgraph_runtime_mode=CUDAGraphMode.FULL) + self.proposer.dummy_run(num_tokens=64, in_graph_capturing=True, aclgraph_runtime_mode=CUDAGraphMode.FULL) self.assertTrue(self.proposer._runnable.call_count == 1) mock_update_full_graph_params.assert_not_called() self.proposer.use_cuda_graph = last_use_cuda_graph @@ -458,16 +409,13 @@ class TestEagleProposerDummyRun(TestBase): # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce` with set_current_vllm_config(self.vllm_config): self.proposer.enable_shared_expert_dp = False - self.proposer.dummy_run(num_tokens=64, - in_graph_capturing=False, - aclgraph_runtime_mode=CUDAGraphMode.FULL) + self.proposer.dummy_run(num_tokens=64, in_graph_capturing=False, aclgraph_runtime_mode=CUDAGraphMode.FULL) self.assertTrue(self.proposer._runnable.call_count == 1) self.assertTrue(mock_update_full_graph_params.call_count == 1) self.proposer.use_cuda_graph = last_use_cuda_graph class TestEagleProposerHelperMethods(TestBase): - # TODO: Can add some tests about prepare_next_token_ids in future. 
def setUp(self): @@ -497,29 +445,23 @@ class TestEagleProposerHelperMethods(TestBase): self.vllm_config.parallel_config.enable_expert_parallel = False self.vllm_config.speculative_config.draft_tensor_parallel_size = 1 self.vllm_config.speculative_config.num_speculative_tokens = 2 - self.vllm_config.speculative_config.speculative_token_tree = str([ - (i + 1) * (0, ) for i in range(2) - ]) + self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(2)]) self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0 self.vllm_config.speculative_config.draft_model_config.uses_mrope = False self.vllm_config.speculative_config.disable_padded_drafter_batch = False self.vllm_config.additional_config = None init_ascend_config(self.vllm_config) - self.mock_cpugpubuffer = patch( - "vllm.v1.spec_decode.eagle.CpuGpuBuffer") + self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") self.mock_cpugpubuffer.start() self.mock_supports_multimodal_inputs = patch( - "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", - return_value=False + "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False ) self.mock_supports_multimodal_inputs.start() # Set the current vllm config set_current_vllm_config(self.vllm_config) - self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) + self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) def tearDown(self): self.mock_cpugpubuffer.stop() @@ -536,11 +478,9 @@ class TestEagleProposerHelperMethods(TestBase): num_rejected = torch.tensor([1, 0, 1], device=self.device) mock_return_attn = MagicMock() - with set_current_vllm_config(self.vllm_config): - with patch.object(self.proposer, - 'prepare_inputs', - return_value=(mock_return_attn, - torch.tensor([1, 2, 4]))): - return_attn, indices = self.proposer.prepare_inputs( - mock_attn, num_rejected) - self.assertEqual(indices.tolist(), [1, 2, 4]) + with ( + set_current_vllm_config(self.vllm_config), + patch.object(self.proposer, "prepare_inputs", return_value=(mock_return_attn, torch.tensor([1, 2, 4]))), + ): + return_attn, indices = self.proposer.prepare_inputs(mock_attn, num_rejected) + self.assertEqual(indices.tolist(), [1, 2, 4]) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index ce9d3e1b..b887b114 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -284,6 +284,9 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): if isinstance(self.kv_cache_spec, CrossAttentionSpec): seq_lens = common_attn_metadata.seq_lens slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32) + elif self.speculative_config and self.speculative_config.parallel_drafting: + seq_lens = common_attn_metadata.seq_lens + attn_state = common_attn_metadata.attn_state # Get attn_mask and swa_mask from singleton AttentionMaskBuilder diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py index a66566da..15d42fa0 100644 --- a/vllm_ascend/ops/triton/spec_decode/utils.py +++ b/vllm_ascend/ops/triton/spec_decode/utils.py @@ -24,6 +24,7 @@ def prepare_inputs_padded_kernel( valid_sampled_tokens_count_ptr, # [num_reqs] query_start_loc_gpu_ptr, # [num_reqs + 1] token_indices_to_sample_ptr, # [num_reqs] (output) + num_rejected_tokens_gpu_ptr, num_reqs, # tl.int32 BLOCK_SIZE: 
tl.constexpr, ): @@ -61,3 +62,4 @@ def prepare_inputs_padded_kernel( index_to_sample = q_last_tok_idx - num_rejected tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask) + tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected, mask=mask) diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 78644448..c17e9398 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -16,6 +16,8 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # + +from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer @@ -31,5 +33,7 @@ def get_spec_decode_method(method, vllm_config, device, runner): return AscendMedusaProposer(vllm_config, device) elif method in ("eagle", "eagle3", "mtp"): return AscendEagleProposer(vllm_config, device, runner) + elif method == "draft_model": + return AscendDraftModelProposer(vllm_config, device, runner) else: raise ValueError(f"Unknown speculative decoding method: {method}") diff --git a/vllm_ascend/spec_decode/draft_proposer.py b/vllm_ascend/spec_decode/draft_proposer.py new file mode 100644 index 00000000..65348b20 --- /dev/null +++ b/vllm_ascend/spec_decode/draft_proposer.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +from typing_extensions import override +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model + +from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer + +logger = init_logger(__name__) + + +class AscendDraftModelProposer(SpecDecodeBaseProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__( + vllm_config=vllm_config, + device=device, + pass_hidden_states_to_model=False, + runner=runner, + ) + self._raise_if_vocab_size_mismatch() + self._raise_if_draft_tp_mismatch() + + def _raise_if_vocab_size_mismatch(self): + self.speculative_config.verify_equal_vocab_size_if_draft_model() + + def _raise_if_draft_tp_mismatch(self): + # Note(Tomas Ruiz) If we run the target model with TP > 1 and + # the draft model with TP = 1, then the different TP ranks collide. + # Specifically when all ranks compile the draft model on rank 0 + # (because TP=1), then the torch compile cache is overwritten and corrupted. + # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414 + # To prevent this error, we assert that both TP sizes must be the same. + spec_cfg = self.speculative_config + tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size + draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size + if draft_tp != tgt_tp: + raise ValueError( + f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' " + f"must be the same. Got {draft_tp} and {tgt_tp}. " + "Please pass 'draft_tensor_parallel_size' in the speculative_config." 
+ ) + + def _get_model(self) -> nn.Module: + # Draft models may be quantized or on different parallelism, + # so we load them with a modified vllm config + from vllm.compilation.backends import set_model_tag + + temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config) + with set_model_tag("draft_model"): + model = get_model( + vllm_config=temp_vllm_config, + prefix="draft_model", + ) + return model + + @override + def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None: + # Draft models don't share embeddings with the target model + pass + + @override + def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None: + # Draft models don't share lm_head with the target model + pass diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index c567b161..a60c2cef 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -30,8 +30,13 @@ from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer +from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.utils import ( + PADDING_SLOT_ID, + compute_new_slot_mapping, + extend_all_queries_by_N, +) from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_forward_context import _EXTRA_CTX, set_ascend_forward_context @@ -80,14 +85,14 @@ def split_inputs_tp_to_sp(hidden_states, out): return out[:padded_num_tokens_per_rank] -class AscendEagleProposer(EagleProposer): +class SpecDecodeBaseProposer(EagleProposer): _runnable: ACLGraphWrapper | Callable - def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): + def __init__(self, vllm_config: VllmConfig, device: torch.device, pass_hidden_states_to_model: bool, runner=None): super().__init__(vllm_config, device, runner) self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling - + self.pass_hidden_states_to_model = pass_hidden_states_to_model self.decode_threshold = 1 + self.num_speculative_tokens self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 2, dtype=torch.int32) self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32) @@ -140,7 +145,7 @@ class AscendEagleProposer(EagleProposer): if not self.use_cuda_graph and enable_sp(vllm_config): self.maybe_eager_context = _maybe_eager_context(vllm_config) - self.last_token_indices = torch.zeros( + self.token_indices_to_sample = torch.zeros( self.vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.int32, device=device ) slot_mapping_lens = self.runner.max_num_tokens + 2 * self.pcp_size * self.runner.max_num_reqs @@ -150,15 +155,38 @@ class AscendEagleProposer(EagleProposer): ] self._runnable = self._run_merged_draft + if self.uses_mrope: + self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int32, device=device) + elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0: + self.xdrope_positions = torch.zeros( + (self.uses_xdrope_dim, self.max_num_tokens + 1), + dtype=torch.int32, + device=device, + ) + else: + # RoPE need (max_num_tokens,) + self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device) + + 
def _get_model(self) -> nn.Module: + """ + Default method to call get_model(). Can be overridden by subclasses which + need to customize model loading. + """ + from vllm.compilation.backends import set_model_tag + + with set_model_tag("eagle_head"): + model = get_model( + vllm_config=self.vllm_config, + model_config=self.vllm_config.speculative_config.draft_model_config, + ) + return model def load_model(self, model: nn.Module) -> None: target_attn_layer_names = set(get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()) target_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys()) with self.maybe_eager_context: - self.model = get_model( - vllm_config=self.vllm_config, model_config=self.vllm_config.speculative_config.draft_model_config - ) + self.model = self._get_model() indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys() draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) @@ -167,7 +195,7 @@ class AscendEagleProposer(EagleProposer): draft_attn_layer_names = draft_attn_layers - target_attn_layer_names draft_indexer_layer_names = indexer_layers - target_indexer_layer_names draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names - assert len(draft_attn_layer_names) == 1 + self.attn_layer_names = list(sorted(draft_attn_layer_names)) self.kernel_block_size = ( @@ -202,6 +230,24 @@ class AscendEagleProposer(EagleProposer): target_language_model = model # share embed_tokens with the target model if needed + self._maybe_share_embeddings(target_language_model) + self._maybe_share_lm_head(model) + + if self.parallel_drafting and self.pass_hidden_states_to_model: + assert self.parallel_drafting_hidden_state_tensor is not None + self.parallel_drafting_hidden_state_tensor.copy_( + self.model.combine_hidden_states(self.model.mask_hidden.view(3 * self.hidden_size)) + if self.eagle3_use_aux_hidden_state + else self.model.mask_hidden.view(self.hidden_size) + ) + + def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None: + """ + Some draft models may not have their own embedding layers, and some may + have a duplicate copy of the target model's embedding layers. In these cases, + we share the target model's embedding layers with the draft model to save + memory. + """ if get_pp_group().world_size == 1: if hasattr(target_language_model.model, "embed_tokens"): target_embed_tokens = target_language_model.model.embed_tokens @@ -256,7 +302,9 @@ class AscendEagleProposer(EagleProposer): "Since PP > 1 or other reasons the model head loaded its own vocab embedding" " weights instead of sharing them with the target model." 
) - # share lm_head with the target model if needed + + # share lm_head with the target model if needed + def _maybe_share_lm_head(self, model: nn.Module) -> None: # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM if self.method == "eagle" and hasattr(model, "lm_head"): @@ -389,7 +437,7 @@ class AscendEagleProposer(EagleProposer): self._runnable( num_input_tokens=num_tokens, batch_size=batch_size, - last_token_indices=self.last_token_indices[:batch_size], + token_indices_to_sample=self.token_indices_to_sample[: batch_size * self.extra_slots_per_request], # The target_position's address is same as the model_positions's target_positions=model_positions, inputs_embeds=None, @@ -411,7 +459,7 @@ class AscendEagleProposer(EagleProposer): target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor | None, + token_indices_to_sample: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, @@ -421,31 +469,34 @@ class AscendEagleProposer(EagleProposer): num_decode_reqs=0, scheduler_output: SchedulerOutput = None, num_scheduled_tokens: int = 0, + num_rejected_tokens_gpu: torch.Tensor | None = None, ) -> torch.Tensor: - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] + batch_size = common_attn_metadata.batch_size() - if last_token_indices is None: - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + if token_indices_to_sample is None: + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.get_model(), Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states(target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=token_indices_to_sample, + cad=common_attn_metadata, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) assert self.runner is not None # update pcp related params if self.pcp_size * self.dcp_size > 1: assert long_seq_metadata is not None common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata - ori_last_token_indices = last_token_indices.clone() + ori_token_indices_to_sample = token_indices_to_sample.clone() query_lens_d = self.runner.query_lens[:num_decode_reqs] if self.pcp_size > 1: # 1. preprocess decode/prefill input_ids & target_hidden_states @@ -484,9 +535,11 @@ class AscendEagleProposer(EagleProposer): target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0) # 2. 
update sample_indices according to main model if num_decode_reqs: - last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]] + token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[ + token_indices_to_sample[:num_decode_reqs] + ] if num_prefill_reqs: - last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:] + token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:] # 3. update attn_metadata params that may be influenced by pcp common_attn_metadata.num_actual_tokens = num_tokens common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p) @@ -530,10 +583,6 @@ class AscendEagleProposer(EagleProposer): aclgraph_runtime_mode = CUDAGraphMode.NONE batch_descriptor = None - # copy inputs to buffer for cudagraph - self._set_positions(num_tokens, target_positions) - self.hidden_states[:num_tokens] = target_hidden_states - if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) inputs_embeds = self.model.embed_input_ids( @@ -559,15 +608,16 @@ class AscendEagleProposer(EagleProposer): attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) if self.uses_mrope: - used_update_positions = target_positions[:, last_token_indices] + used_update_positions = self.mrope_positions[:, token_indices_to_sample] else: - used_update_positions = target_positions[last_token_indices] + used_update_positions = self.positions[token_indices_to_sample] per_layer_attn_metadata = dict() # The first step of speculative. for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata multi_steps_attn_metadata = [per_layer_attn_metadata] + # Copy the old attn_metadata and update attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]] if self.pcp_size * self.dcp_size > 1: if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills: @@ -578,7 +628,7 @@ class AscendEagleProposer(EagleProposer): # to get corresponding slot_mapping in each step. 
num_reject_tokens = ( torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device) - - ori_last_token_indices + - ori_token_indices_to_sample - 1 ) num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens @@ -616,6 +666,27 @@ class AscendEagleProposer(EagleProposer): common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor[:batch_size] # Copy the old attn_metadata and update + if not self.parallel_drafting: + for draft_step in range(1, self.num_speculative_tokens): + common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( + draft_step, + attn_metadata, + common_attn_metadata, + batch_size, + num_input_tokens, + used_update_positions, + aclgraph_runtime_mode, + ori_seq_len, + slot_indices, + mtp_slot_mapping, + ) + per_layer_attn_metadata = dict() + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + multi_steps_attn_metadata.append(per_layer_attn_metadata) + else: + # Copy the old attn_metadata and update + if not self.parallel_drafting: for draft_step in range(1, self.num_speculative_tokens): common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( draft_step, @@ -625,33 +696,14 @@ class AscendEagleProposer(EagleProposer): num_input_tokens, used_update_positions, aclgraph_runtime_mode, - ori_seq_len, - slot_indices, - mtp_slot_mapping, ) per_layer_attn_metadata = dict() for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata multi_steps_attn_metadata.append(per_layer_attn_metadata) - else: - # Copy the old attn_metadata and update - for draft_step in range(1, self.num_speculative_tokens): - common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( - draft_step, - attn_metadata, - common_attn_metadata, - batch_size, - num_input_tokens, - used_update_positions, - aclgraph_runtime_mode, - ) - per_layer_attn_metadata = dict() - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata - multi_steps_attn_metadata.append(per_layer_attn_metadata) - last_token_indices_len = last_token_indices.shape[0] - self.last_token_indices[:last_token_indices_len].copy_(last_token_indices) + token_indices_to_sample_len = token_indices_to_sample.shape[0] + self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample) with set_ascend_forward_context( multi_steps_attn_metadata[0], @@ -672,7 +724,7 @@ class AscendEagleProposer(EagleProposer): draft_token_ids = self._runnable( num_input_tokens=num_input_tokens, batch_size=batch_size, - last_token_indices=self.last_token_indices[:last_token_indices_len], + token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len], target_positions=target_positions, inputs_embeds=inputs_embeds, multi_steps_attn_metadata=multi_steps_attn_metadata, @@ -689,7 +741,7 @@ class AscendEagleProposer(EagleProposer): self, num_input_tokens, batch_size, - last_token_indices, + token_indices_to_sample, target_positions, inputs_embeds, multi_steps_attn_metadata, @@ -702,17 +754,22 @@ class AscendEagleProposer(EagleProposer): # `model_hidden_states` represent the speculative model inputs. 
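+        # EAGLE-style heads are conditioned on the target model's hidden
+        # states, while a standalone draft model runs from token ids and
+        # positions alone, so hidden_states is only added to model_kwargs
+        # below when pass_hidden_states_to_model is set.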
model_input_ids = self.input_ids[:num_input_tokens] model_positions = self._get_positions(num_input_tokens) - model_hidden_states = self.hidden_states[:num_input_tokens] - model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions) + model_kwargs = { + "input_ids": model_input_ids, + "positions": model_positions, + "inputs_embeds": inputs_embeds, + } - ret_hidden_states = self.model( - input_ids=model_input_ids, - positions=model_positions, - hidden_states=model_hidden_states, - inputs_embeds=inputs_embeds, - ) - if self.method == "mtp": + if self.pass_hidden_states_to_model: + model_hidden_states = self.hidden_states[:num_input_tokens] + model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions) + model_kwargs["hidden_states"] = model_hidden_states + if self.method == "mtp": + model_kwargs["positions"] = model_positions + + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -722,6 +779,7 @@ class AscendEagleProposer(EagleProposer): last_hidden_states, model_positions, hidden_states ) + num_indices = token_indices_to_sample.shape[0] if self.pcp_size > 1: # remove graph padding before all_gather hidden_states = hidden_states[:num_tokens] @@ -741,26 +799,27 @@ class AscendEagleProposer(EagleProposer): self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]], ) - num_indices = last_token_indices.shape[0] if lmhead_tp_enable() and not is_dummy: max_num_reqs_across_dp = ( self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len ) - last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + token_indices_to_sample = nn.functional.pad( + token_indices_to_sample, (0, max_num_reqs_across_dp - num_indices) + ) - sample_hidden_states = last_hidden_states[last_token_indices] + sample_hidden_states = last_hidden_states[token_indices_to_sample] logits = self.model.compute_logits(sample_hidden_states) if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy: logits = logits[:num_indices] - last_token_indices = last_token_indices[:num_indices] + token_indices_to_sample = token_indices_to_sample[:num_indices] draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. 
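+        # With parallel drafting, the single forward pass above already
+        # produced logits for all num_speculative_tokens draft slots, so the
+        # flat argmax can be reshaped to [batch_size, num_speculative_tokens]
+        # and returned directly, skipping the autoregressive draft loop.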
- if self.num_speculative_tokens == 1: + if self.num_speculative_tokens == 1 or self.parallel_drafting: # [batch_size, 1] - return draft_token_ids.view(-1, 1) + return draft_token_ids.view(-1, self.num_speculative_tokens) if self.pcp_size * self.dcp_size > 1 and is_prefill: draft_token_ids = logits.argmax(dim=-1) @@ -775,11 +834,11 @@ class AscendEagleProposer(EagleProposer): ) draft_token_ids_tensor[0] = draft_token_ids if self.uses_mrope: - positions = target_positions[:, last_token_indices] + positions = self.mrope_positions[:, token_indices_to_sample] else: - positions = target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] - last_token_indices = self.arange[:batch_size] + positions = self.positions[token_indices_to_sample] + hidden_states = hidden_states[token_indices_to_sample] + token_indices_to_sample = self.arange[:batch_size] input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size @@ -843,13 +902,17 @@ class AscendEagleProposer(EagleProposer): forward_context.attn_metadata = ( multi_steps_attn_metadata[draft_step + 1] if multi_steps_attn_metadata else None ) - ret_hidden_states = self.model( - input_ids=model_input_ids, - positions=model_positions, - hidden_states=model_hidden_states, - inputs_embeds=inputs_embeds, - ) - if self.method == "mtp": + + model_kwargs = { + "input_ids": model_input_ids, + "positions": model_positions, + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = model_hidden_states + + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -859,22 +922,22 @@ class AscendEagleProposer(EagleProposer): last_hidden_states, model_positions, hidden_states ) - num_indices = last_token_indices.shape[0] + num_indices = token_indices_to_sample.shape[0] if lmhead_tp_enable() and not is_dummy: max_num_reqs_across_dp = ( self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len ) - last_token_indices = nn.functional.pad( - last_token_indices, + token_indices_to_sample = nn.functional.pad( + token_indices_to_sample, (0, max_num_reqs_across_dp - num_indices), ) - sample_hidden_states = last_hidden_states[last_token_indices] + sample_hidden_states = last_hidden_states[token_indices_to_sample] logits = self.model.compute_logits(sample_hidden_states) if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy: logits = logits[:num_indices] - last_token_indices = last_token_indices[:num_indices] + token_indices_to_sample = token_indices_to_sample[:num_indices] # TODO(wenlong): get more than one token for tree attention hidden_states = hidden_states[:batch_size] @@ -885,6 +948,122 @@ class AscendEagleProposer(EagleProposer): draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) return draft_token_ids + def set_inputs_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + token_indices_to_sample: torch.Tensor | None, + cad: CommonAttentionMetadata, + num_rejected_tokens_gpu: torch.Tensor | None, + ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: + if not self.needs_extra_input_slots: + # Default EAGLE pathway: no reshaping of input tensors needed. + # Simply rotate the input ids and leave the positions unchanged, + # Inserting the next token ids at the last slot in each request. 
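+            # (The else branch below handles the parallel-drafting pathway,
+            # which instead expands each request with extra mask-token slots
+            # via the npu_copy_and_expand_eagle_inputs custom op.)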
+ if token_indices_to_sample is None: + token_indices_to_sample = cad.query_start_loc[1:] - 1 + + num_tokens = target_token_ids.shape[0] + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[token_indices_to_sample] = next_token_ids + + # copy inputs to buffer for cudagraph + if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0: + target_positions = target_positions[0] + + self._set_positions(num_tokens, target_positions) + self.hidden_states[:num_tokens] = target_hidden_states + + return num_tokens, token_indices_to_sample, cad + else: + assert self.is_rejected_token_mask is not None + assert self.is_masked_token_mask is not None + # 1. + # Call the CopyAndExpandEagleInputs AscendC operator to copy + # input_ids and positions into the correct slots in the + # preallocated buffers self.input_ids, self.positions. + batch_size = cad.batch_size() + total_num_input_tokens = target_token_ids.shape[0] + total_num_output_tokens = total_num_input_tokens + (self.net_num_new_slots_per_request * batch_size) + + query_start_loc = cad.query_start_loc + query_end_loc = cad.query_start_loc[1:] - 1 + if num_rejected_tokens_gpu is not None: + query_end_loc = query_end_loc - num_rejected_tokens_gpu + + ( + out_input_ids, + out_positions, + out_is_rejected_token_mask, + out_is_masked_token_mask, + token_indices_to_sample, + out_hidden_state_mapping, + ) = torch.ops._C_ascend.npu_copy_and_expand_eagle_inputs( + target_token_ids, + target_positions.to(torch.int32), + next_token_ids, + query_start_loc, + query_end_loc, + 0, # padding_token_id + self.parallel_drafting_token_id, + self.extra_slots_per_request, + self.pass_hidden_states_to_model, + total_num_output_tokens, + ) + + # Copy returned tensors into pre-allocated buffers + self.input_ids[:total_num_output_tokens].copy_(out_input_ids) + self.positions[:total_num_output_tokens].copy_(out_positions) + self.is_rejected_token_mask[:total_num_output_tokens].copy_(out_is_rejected_token_mask) + self.is_masked_token_mask[:total_num_output_tokens].copy_(out_is_masked_token_mask) + if self.pass_hidden_states_to_model: + assert self.parallel_drafting_hidden_state_tensor is not None + self.hidden_states[out_hidden_state_mapping] = target_hidden_states + # Use torch.where to avoid DtoH sync from boolean indexing + mask = self.is_masked_token_mask[:total_num_output_tokens] + torch.where( + mask.unsqueeze(1), + self.parallel_drafting_hidden_state_tensor, + self.hidden_states[:total_num_output_tokens], + out=self.hidden_states[:total_num_output_tokens], + ) + + # 2. + # Recompute the slot mapping based on the new positions and + # rejection mask. + builder = ( + self._get_attention_metadata_builder() + if self.attn_metadata_builder is None + else self.attn_metadata_builder + ) + new_slot_mapping = compute_new_slot_mapping( + cad=cad, + new_positions=self.positions[:total_num_output_tokens], + is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens], + block_size=builder.kv_cache_spec.block_size, + num_new_tokens=self.net_num_new_slots_per_request, + max_model_len=self.max_model_len, + ) + + # 3. 
+
+            # 3. Update the common attention metadata with the extended queries and new slot mapping.
+            new_cad = extend_all_queries_by_N(
+                cad,
+                N=self.net_num_new_slots_per_request,
+                arange=self.arange,
+                new_slot_mapping=new_slot_mapping,
+            )
+
+            return total_num_output_tokens, token_indices_to_sample, new_cad
+
+    def model_returns_tuple(self) -> bool:
+        return self.method not in ("mtp", "draft_model")
+
     def attn_update_stack_num_spec_norm(
         self,
         # `draft_step` must start from `1`, not `0`
@@ -1201,7 +1380,7 @@ class AscendEagleProposer(EagleProposer):
         common_attn_metadata: CommonAttentionMetadata,
         spec_decode_metadata: SpecDecodeMetadata,
         valid_sampled_tokens_count: torch.Tensor,
-    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
+    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding.
         It updates the common_attn_metadata for speculative decoding,
@@ -1215,7 +1394,7 @@ class AscendEagleProposer(EagleProposer):
 
         device = valid_sampled_tokens_count.device
         token_indices_to_sample = torch.empty((num_reqs,), dtype=torch.int32, device=device)
-
+        num_rejected_tokens_gpu = torch.empty((num_reqs,), dtype=torch.int32, device=device)
         num_blocks_needed = triton.cdiv(num_reqs, _PREPARE_INPUTS_BLOCK_SIZE)
         num_vector_core = get_vectorcore_num()
         grid_size = min(num_blocks_needed, num_vector_core)
@@ -1226,6 +1405,7 @@ class AscendEagleProposer(EagleProposer):
             valid_sampled_tokens_count,
             common_attn_metadata.query_start_loc,
             token_indices_to_sample,
+            num_rejected_tokens_gpu,
             num_reqs,
             BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE,
         )
@@ -1274,7 +1454,7 @@ class AscendEagleProposer(EagleProposer):
             max_seq_len=0,
         )
 
-        return spec_common_attn_metadata, token_indices, token_indices_to_sample
+        return spec_common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu
 
     def _split_pcp_input(self, req_scheduled_tokens, input_ids, target_hidden_states):
         """
@@ -1394,3 +1574,18 @@ class AscendEagleProposer(EagleProposer):
         if hidden_states is not None:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), True)
         return last_hidden_states, positions, hidden_states
+
+
+class AscendEagleProposer(SpecDecodeBaseProposer):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        runner=None,
+    ):
+        super().__init__(
+            vllm_config,
+            device,
+            pass_hidden_states_to_model=True,
+            runner=runner,
+        )
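A sketch of the query-extension idea behind `extend_all_queries_by_N` above (assumed semantics; the real helper also swaps in the new slot mapping): each request gains N query slots, so boundary i of the `query_start_loc` prefix sum shifts right by i * N.

```python
import torch

def extend_query_start_loc(query_start_loc: torch.Tensor, n: int) -> torch.Tensor:
    # Request i gains n slots, so prefix-sum boundary i moves right by i * n.
    arange = torch.arange(query_start_loc.shape[0], device=query_start_loc.device)
    return query_start_loc + arange * n

# Query lengths [1, 2, 3] become [3, 4, 5] after adding N=2 slots each.
print(extend_query_start_loc(torch.tensor([0, 1, 3, 6]), 2).tolist())  # [0, 3, 7, 12]
```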
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 2c05ebbe..7631e2a9 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -108,6 +108,7 @@ from vllm_ascend.patch.worker.patch_draft_quarot import patch_load_weights
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
+from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer
 from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
 from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
 from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
@@ -406,7 +407,12 @@ class NPUModelRunner(GPUModelRunner):
 
     def _set_up_drafter(self):
         # Set up speculative decoding.
         self.drafter: (
-            AscendNgramProposer | AscendEagleProposer | AscendSuffixDecodingProposer | AscendMedusaProposer | None
+            AscendNgramProposer
+            | AscendEagleProposer
+            | AscendDraftModelProposer
+            | AscendSuffixDecodingProposer
+            | AscendMedusaProposer
+            | None
         ) = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
@@ -971,7 +977,7 @@ class NPUModelRunner(GPUModelRunner):
             draft_token_ids = self.drafter.propose(
                 valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
             )
-        elif self.speculative_config.use_eagle():
+        elif self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model():
             common_attn_metadata = spec_decode_common_attn_metadata
             sampled_token_ids = valid_sampled_token_ids
 
@@ -1018,6 +1024,8 @@ class NPUModelRunner(GPUModelRunner):
             long_seq_metadata = None  # type: ignore
             num_prefill_reqs = 0
             num_decode_reqs = 0
+
+            num_rejected_tokens_gpu = None
             if spec_decode_metadata is None:
                 # update pcp related params
                 if self.pcp_size > 1:
@@ -1053,8 +1061,10 @@ class NPUModelRunner(GPUModelRunner):
                 )
             else:
                 assert self.drafter is not None
-                common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded(
-                    common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = (
+                    self.drafter.prepare_inputs_padded(
+                        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                    )
                 )
                 if self.pcp_size > 1:
                     target_token_ids = input_ids_pcp_full[token_indices]
@@ -1075,7 +1085,7 @@ class NPUModelRunner(GPUModelRunner):
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
                 next_token_ids=next_token_ids,
-                last_token_indices=token_indices_to_sample,
+                token_indices_to_sample=token_indices_to_sample,
                 common_attn_metadata=common_attn_metadata,
                 sampling_metadata=sampling_metadata,
                 req_scheduled_tokens=req_scheduled_tokens,
@@ -1084,6 +1094,7 @@ class NPUModelRunner(GPUModelRunner):
                 num_decode_reqs=num_decode_reqs,
                 scheduler_output=scheduler_output,
                 num_scheduled_tokens=num_scheduled_tokens,
+                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
             )
         else:
             raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
@@ -1516,16 +1527,17 @@ class NPUModelRunner(GPUModelRunner):
 
         with record_function_or_nullcontext("draft_token"):
             if self.speculative_config:
-                use_padded_batch_for_eagle = (
+                use_padded_batch = (
                     self.speculative_config
-                    and self.speculative_config.use_eagle()
+                    and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model())
                     and not self.speculative_config.disable_padded_drafter_batch
                 )
-                if use_padded_batch_for_eagle:
-                    # EAGLE speculative decoding can use the GPU sampled tokens
-                    # as inputs, and does not need to wait for bookkeeping to finish.
+                if use_padded_batch:
+                    # EAGLE and draft-model speculative decoding can use the GPU
+                    # sampled tokens as inputs, and do not need to wait for
+                    # bookkeeping to finish.
                     propose_draft_token_ids(sampler_output.sampled_token_ids)
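For intuition about the new `num_rejected_tokens_gpu` plumbing above: with `num_speculative_tokens` drafts per request, a request whose sampler kept `valid` tokens (accepted drafts plus one bonus token) rejected the rest. A sketch of that bookkeeping (assumed semantics; the Triton kernel computes it per request on device alongside `token_indices_to_sample`):

```python
import torch

def rejected_counts(valid_sampled_tokens_count: torch.Tensor,
                    num_speculative_tokens: int) -> torch.Tensor:
    # valid = accepted drafts + 1 bonus token, so max valid is k + 1.
    return num_speculative_tokens + 1 - valid_sampled_tokens_count

# k=8 drafts: a fully accepted request rejects 0, a fully rejected one 8.
print(rejected_counts(torch.tensor([9, 1, 4]), 8).tolist())  # [0, 8, 5]
```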
-                if self.speculative_config and not use_padded_batch_for_eagle:
+                if self.speculative_config and not use_padded_batch:
                     # ngram and other speculative decoding methods use the sampled
                     # tokens on the CPU, so they are run after bookkeeping.
                     propose_draft_token_ids(valid_sampled_token_ids)
@@ -2165,7 +2176,7 @@ class NPUModelRunner(GPUModelRunner):
         if kv_cache_gid > 0:
             cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
         if self.speculative_config and spec_decode_common_attn_metadata is None:
-            if isinstance(self.drafter, AscendEagleProposer):
+            if isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer):
                 if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
                     spec_decode_common_attn_metadata = cm
                 else:
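Putting the pieces together, a simplified CPU-only sketch (assumed semantics and illustrative names, not the AscendC kernel's actual interface) of the per-request expansion that parallel drafting performs: keep the accepted tokens, append the sampled bonus token, then pad with the special drafting token id so a single draft forward pass scores every speculative position at once.

```python
import torch

def expand_request(tokens: torch.Tensor, next_token: int,
                   extra_slots: int, draft_token_id: int) -> torch.Tensor:
    # Rotate the accepted tokens left by one and append the sampled bonus
    # token, then pad with the drafting token id; the draft model then emits
    # logits for all speculative positions in one pass instead of k passes.
    shifted = torch.cat([tokens[1:], tokens.new_tensor([next_token])])
    pad = tokens.new_full((extra_slots,), draft_token_id)
    return torch.cat([shifted, pad])

print(expand_request(torch.tensor([1, 2, 3]), next_token=4,
                     extra_slots=3, draft_token_id=-1).tolist())
# [2, 3, 4, -1, -1, -1]
```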