[feat][spec decode]Unified draft parallel (#6766)

### What this PR does / why we need it?
Implement a unified parallelized speculative decoding in VLLM
Ascend,which can simultaneously support parallel speculative inference
schemes such as Pard, P-Eagle, etc. refer to
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078

### How was this patch tested?

run with parallel drafting script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
--speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

base script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811

benchmark script:
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234

test results :
base(without spec decode): TTFT 79.46ms TPOT 26.99ms
output_tokens_throughput 36.75 tok/s
this pr(with parallel drafting): TTFT 72.24ms TPOT 13.45ms
output_tokens_throughput 72.98 tok/s
per-position acceptance(from position 0 to 7):
79.48%、56.93%、40%、27.90%、19.79%、14.25%、10.57%、7.61%.

----------------------------------------------------------------------
run on qwen3 model script :
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1

vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
--speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

cc  @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main:
9562912cea

---------

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
This commit is contained in:
kx
2026-03-13 14:07:35 +08:00
committed by GitHub
parent 6ee7ffb98a
commit df1ee8070d
18 changed files with 1943 additions and 311 deletions

View File

@@ -0,0 +1,22 @@
add_ops_compile_options(
OP_NAME CopyAndExpandEagleInputs
OPTIONS --cce-auto-sync=on
-Wno-deprecated-declarations
-Werror
)
target_sources(op_host_aclnn PRIVATE
copy_and_expand_eagle_inputs_def.cpp
)
target_sources(optiling PRIVATE
copy_and_expand_eagle_inputs_tiling.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE
copy_and_expand_eagle_inputs_infershape.cpp
)

View File

@@ -0,0 +1,87 @@
/**
* @file copy_and_expand_eagle_inputs_def.cpp
* @brief CopyAndExpandEagleInputs OpDef registration
*/
#include "register/op_def_registry.h"
namespace ops {
class CopyAndExpandEagleInputs : public OpDef {
public:
explicit CopyAndExpandEagleInputs(const char* name) : OpDef(name)
{
// -------------------- Inputs --------------------
this->Input("target_token_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("target_positions")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("next_token_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("query_start_loc")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("query_end_loc")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
// -------------------- Outputs --------------------
this->Output("out_input_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_positions")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_is_rejected_token_mask")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_is_masked_token_mask")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_new_token_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_hidden_state_mapping")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
// -------------------- Attributes --------------------
this->Attr("padding_token_id").Int();
this->Attr("parallel_drafting_token_id").Int();
this->Attr("num_padding_slots_per_request").Int();
this->Attr("shift_input_ids").Bool();
this->Attr("total_input_tokens").Int();
// -------------------- Platform --------------------
this->AICore().AddConfig("ascend910b");
}
};
OP_ADD(CopyAndExpandEagleInputs);
} // namespace ops

View File

@@ -0,0 +1,107 @@
/**
* @file copy_and_expand_eagle_inputs_infershape.cpp
* @brief InferShape and InferDataType for CopyAndExpandEagleInputs
*/
#include "register/op_def_registry.h"
#include "log/ops_log.h"
#define unlikely(x) __builtin_expect((x), 0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if (unlikely((ptr) == nullptr)) { \
const char* name = (unlikely(((context) == nullptr) || (context)->GetNodeName() == nullptr)) ? \
"nil" : \
(context)->GetNodeName(); \
OPS_LOG_E(name, "%s is nullptr!", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
static constexpr int IDX_TARGET_TOKEN_IDS = 0;
static constexpr int IDX_TARGET_POSITIONS = 1;
static constexpr int IDX_NEXT_TOKEN_IDS = 2;
static constexpr int IDX_QUERY_START_LOC = 3;
static constexpr int IDX_QUERY_END_LOC = 4;
static constexpr int OUT_INPUT_IDS = 0;
static constexpr int OUT_POSITIONS = 1;
static constexpr int OUT_REJECTED_MASK = 2;
static constexpr int OUT_MASKED_MASK = 3;
static constexpr int OUT_NEW_TOKEN_INDICES = 4;
static constexpr int OUT_HIDDEN_STATE_MAPPING = 5;
static constexpr int OUTPUT_NUM = 6;
static constexpr int ATTR_NUM_PADDING_SLOTS = 2;
static constexpr int ATTR_TOTAL_INPUT_TOKENS = 4;
using namespace ge;
namespace ops {
static ge::graphStatus InferShape4CopyAndExpandEagleInputs(gert::InferShapeContext* context)
{
// Get input shapes
const gert::Shape* targetTokenIdsShape = context->GetInputShape(IDX_TARGET_TOKEN_IDS);
OP_CHECK_NULL_WITH_CONTEXT(context, targetTokenIdsShape);
const gert::Shape* queryStartLocShape = context->GetInputShape(IDX_QUERY_START_LOC);
OP_CHECK_NULL_WITH_CONTEXT(context, queryStartLocShape);
// Derive dimensions from input shapes
int64_t totalInputTokens = targetTokenIdsShape->GetDim(0);
int64_t numReqs = queryStartLocShape->GetDim(0) - 1;
// Get num_padding_slots_per_request from attributes
auto attrs = context->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
int64_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int64_t>(ATTR_NUM_PADDING_SLOTS));
// Compute total_draft_tokens = total_input_tokens + (num_padding_slots_per_request - 1) * num_reqs
int64_t totalDraftTokens = totalInputTokens + (numPaddingSlotsPerReq - 1) * numReqs;
// Get and validate all output shapes
gert::Shape* outShapes[OUTPUT_NUM];
for (int i = 0; i < OUTPUT_NUM; ++i) {
outShapes[i] = context->GetOutputShape(i);
OP_CHECK_NULL_WITH_CONTEXT(context, outShapes[i]);
outShapes[i]->SetDimNum(1);
}
// out_input_ids, out_positions, out_rejected_mask, out_masked_mask: [total_draft_tokens]
outShapes[OUT_INPUT_IDS]->SetDim(0, totalDraftTokens);
outShapes[OUT_POSITIONS]->SetDim(0, totalDraftTokens);
outShapes[OUT_REJECTED_MASK]->SetDim(0, totalDraftTokens);
outShapes[OUT_MASKED_MASK]->SetDim(0, totalDraftTokens);
// out_new_token_indices: [num_reqs * num_padding_slots_per_request]
outShapes[OUT_NEW_TOKEN_INDICES]->SetDim(0, numReqs * numPaddingSlotsPerReq);
// out_hidden_state_mapping: [total_input_tokens]
outShapes[OUT_HIDDEN_STATE_MAPPING]->SetDim(0, totalInputTokens);
return GRAPH_SUCCESS;
}
static ge::graphStatus InferDataType4CopyAndExpandEagleInputs(gert::InferDataTypeContext* context)
{
// out_input_ids: INT32
context->SetOutputDataType(OUT_INPUT_IDS, DT_INT32);
// out_positions: INT32
context->SetOutputDataType(OUT_POSITIONS, DT_INT32);
// out_is_rejected_token_mask: INT8
context->SetOutputDataType(OUT_REJECTED_MASK, DT_INT8);
// out_is_masked_token_mask: INT8
context->SetOutputDataType(OUT_MASKED_MASK, DT_INT8);
// out_new_token_indices: INT32
context->SetOutputDataType(OUT_NEW_TOKEN_INDICES, DT_INT32);
// out_hidden_state_mapping: INT32
context->SetOutputDataType(OUT_HIDDEN_STATE_MAPPING, DT_INT32);
return GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(CopyAndExpandEagleInputs)
.InferShape(InferShape4CopyAndExpandEagleInputs)
.InferDataType(InferDataType4CopyAndExpandEagleInputs);
} // namespace ops

View File

@@ -0,0 +1,121 @@
/**
* @file copy_and_expand_eagle_inputs_tiling.cpp
* @brief CopyAndExpandEagleInputs TilingFunc implementation
*/
#include "copy_and_expand_eagle_inputs_tiling.h"
#include "register/op_def_registry.h"
#include "log/ops_log.h"
#include <algorithm>
namespace optiling {
static void GetCompileParameters(
gert::TilingContext* context, uint32_t& coreNum)
{
auto ptrCompileInfo = reinterpret_cast<const CopyAndExpandEagleInputsCompileInfo*>(context->GetCompileInfo());
if (ptrCompileInfo == nullptr) {
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
coreNum = ascendcPlatform.GetCoreNum();
} else {
coreNum = ptrCompileInfo->totalCoreNum;
}
}
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
OPS_LOG_I(context, "Enter TilingFunc for CopyAndExpandEagleInputs");
OPS_LOG_D(context, "TilingFunc running.");
// ========== 1. Get hardware core count ==========
uint32_t coreNum;
GetCompileParameters(context, coreNum);
// ========== 2. Derive num_reqs from query_start_loc shape ==========
// query_start_loc is the 4th input (index 3), shape [num_reqs + 1]
auto queryStartLocShape = context->GetInputShape(3);
uint32_t numReqs = 0;
if (queryStartLocShape != nullptr &&
queryStartLocShape->GetStorageShape().GetDimNum() > 0) {
int64_t dim0 = queryStartLocShape->GetStorageShape().GetDim(0);
numReqs = (dim0 > 1) ? static_cast<uint32_t>(dim0 - 1) : 0;
}
// ========== 3. Get operator attributes ==========
auto attrs = context->GetAttrs();
int32_t paddingTokenId = *(attrs->GetAttrPointer<int32_t>(0));
int32_t parallelDraftingTokenId = *(attrs->GetAttrPointer<int32_t>(1));
int32_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int32_t>(2));
bool shiftInputIds = *(attrs->GetAttrPointer<bool>(3));
int32_t totalInputTokens = *(attrs->GetAttrPointer<int32_t>(4));
// ========== 4. Compute core distribution ==========
uint32_t usedCoreNum = std::min(coreNum, numReqs);
if (usedCoreNum == 0) {
usedCoreNum = 1;
}
uint32_t reqsPerCore = numReqs / usedCoreNum;
uint32_t remainderReqs = numReqs % usedCoreNum;
// ========== 5. Set tiling_key ==========
context->SetTilingKey(1);
// ========== 6. Get output shape ==========
uint32_t totalDraftTokens = 0;
auto outShape = context->GetOutputShape(0);
if (outShape != nullptr &&
outShape->GetStorageShape().GetDimNum() > 0) {
totalDraftTokens = static_cast<uint32_t>(outShape->GetStorageShape().GetDim(0));
}
// ========== 7. Fill TilingData ==========
CopyAndExpandEagleInputsTilingData tiling;
tiling.set_usedCoreNum(usedCoreNum);
tiling.set_numReqs(numReqs);
tiling.set_reqsPerCore(reqsPerCore);
tiling.set_remainderReqs(remainderReqs);
tiling.set_paddingTokenId(paddingTokenId);
tiling.set_parallelDraftingTokenId(parallelDraftingTokenId);
tiling.set_numPaddingSlotsPerReq(static_cast<uint32_t>(numPaddingSlotsPerReq));
tiling.set_totalInputTokens(static_cast<uint32_t>(totalInputTokens));
tiling.set_shiftInputIds(shiftInputIds ? 1u : 0u);
tiling.set_totalDraftTokens(totalDraftTokens);
tiling.SaveToBuffer(
context->GetRawTilingData()->GetData(),
context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
// ========== 8. Set block_dim ==========
context->SetBlockDim(usedCoreNum);
OPS_LOG_I(context, "Block Dim: %u", usedCoreNum);
OPS_LOG_I(context,
"numReqs: %u, reqsPerCore: %u, remainderReqs: %u, totalInputTokens: %d, totalDraftTokens: %u",
numReqs, reqsPerCore, remainderReqs, totalInputTokens, totalDraftTokens);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingPrepare4CopyAndExpandEagleInputs(gert::TilingParseContext* context)
{
OPS_LOG_D(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
OPS_LOG_I(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
auto compileInfo = context->GetCompiledInfo<CopyAndExpandEagleInputsCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
auto platformInfo = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
compileInfo->totalCoreNum = ascendcPlatform.GetCoreNum();
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(CopyAndExpandEagleInputs)
.Tiling(TilingFunc)
.TilingParse<CopyAndExpandEagleInputsCompileInfo>(TilingPrepare4CopyAndExpandEagleInputs);
} // namespace optiling

View File

@@ -0,0 +1,37 @@
#ifndef COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
#define COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
#include "register/tilingdata_base.h"
#include "error_log.h"
#include "register/op_impl_registry.h"
#include "tiling/platform/platform_ascendc.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(CopyAndExpandEagleInputsTilingData)
// ---- 分核参数 ----
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum); // 实际使用的核数
TILING_DATA_FIELD_DEF(uint32_t, numReqs); // 总请求数
TILING_DATA_FIELD_DEF(uint32_t, reqsPerCore); // 每核基础请求数
TILING_DATA_FIELD_DEF(uint32_t, remainderReqs); // 余数(前 remainder 个核多处理 1 个请求)
// ---- 算子属性 ----
TILING_DATA_FIELD_DEF(int32_t, paddingTokenId); // 填充 token ID
TILING_DATA_FIELD_DEF(int32_t, parallelDraftingTokenId); // 并行推测解码 token ID
TILING_DATA_FIELD_DEF(uint32_t, numPaddingSlotsPerReq); // 每个请求的 padding 槽位数
TILING_DATA_FIELD_DEF(uint32_t, totalInputTokens); // 输入 token 总数(用于 clamp
TILING_DATA_FIELD_DEF(uint32_t, shiftInputIds); // 0 = false, 1 = true
// ---- 输出尺寸 ----
TILING_DATA_FIELD_DEF(uint32_t, totalDraftTokens); // 输出 token 总数
END_TILING_DATA_DEF;
struct CopyAndExpandEagleInputsCompileInfo {
uint32_t totalCoreNum = 0;
};
REGISTER_TILING_DATA_CLASS(CopyAndExpandEagleInputs, CopyAndExpandEagleInputsTilingData)
} // namespace optiling
#endif // COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H