[feat][spec decode] Unified draft parallel (#6766)

### What this PR does / why we need it?
Implement unified parallel speculative decoding in vLLM Ascend, which can
simultaneously support parallel speculative inference schemes such as PARD
and P-Eagle. Refer to
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078.

### How was this patch tested?

Run with the parallel-drafting script:

```bash
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```

Baseline script (no speculative decoding):

```bash
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811
```

Benchmark script:

```bash
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234
```

Test results:
- base (without spec decode): TTFT 79.46 ms, TPOT 26.99 ms, output token throughput 36.75 tok/s
- this PR (with parallel drafting): TTFT 72.24 ms, TPOT 13.45 ms, output token throughput 72.98 tok/s
- per-position acceptance (position 0 to 7): 79.48%, 56.93%, 40%, 27.90%, 19.79%, 14.25%, 10.57%, 7.61%
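
As a rough sanity check (not part of this PR; it assumes each per-position rate is the marginal probability that the draft token at that position is accepted), the expected number of tokens emitted per target step is about one bonus token plus the sum of the per-position acceptance rates:

```python
# Back-of-the-envelope check of the reported numbers (illustrative only).
# Assumes each per-position rate is the marginal acceptance probability
# of the draft token at that position.
acceptance = [0.7948, 0.5693, 0.40, 0.2790, 0.1979, 0.1425, 0.1057, 0.0761]

# Every target step emits one sampled/bonus token; each accepted draft
# token adds one more on top of that.
expected_tokens_per_step = 1 + sum(acceptance)
print(f"~{expected_tokens_per_step:.2f} tokens per target step")  # ~3.57

# Roughly consistent with TPOT dropping from 26.99 ms to 13.45 ms (~2x)
# once the cost of running the draft model is taken into account.
```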

----------------------------------------------------------------------
Run on a Qwen3 model:

```bash
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1

vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```

cc @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
Author: kx
Date: 2026-03-13 14:07:35 +08:00 (committed by GitHub)
Parent: 6ee7ffb98a
Commit: df1ee8070d
18 changed files with 1943 additions and 311 deletions

View File

@@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
    export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
    CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;"
    SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
    # ASCEND910C (A3) series
@@ -64,6 +64,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
        "add_rms_norm_bias"
        "apply_top_k_top_p_custom"
        "transpose_kv_cache_by_block"
        "copy_and_expand_eagle_inputs"
        "causal_conv1d"
        "moe_grouped_matmul"
    )

View File

@@ -0,0 +1,22 @@
add_ops_compile_options(
OP_NAME CopyAndExpandEagleInputs
OPTIONS --cce-auto-sync=on
-Wno-deprecated-declarations
-Werror
)
target_sources(op_host_aclnn PRIVATE
copy_and_expand_eagle_inputs_def.cpp
)
target_sources(optiling PRIVATE
copy_and_expand_eagle_inputs_tiling.cpp
)
target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_sources(opsproto PRIVATE
copy_and_expand_eagle_inputs_infershape.cpp
)

View File

@@ -0,0 +1,87 @@
/**
* @file copy_and_expand_eagle_inputs_def.cpp
* @brief CopyAndExpandEagleInputs OpDef registration
*/
#include "register/op_def_registry.h"
namespace ops {
class CopyAndExpandEagleInputs : public OpDef {
public:
explicit CopyAndExpandEagleInputs(const char* name) : OpDef(name)
{
// -------------------- Inputs --------------------
this->Input("target_token_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("target_positions")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("next_token_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("query_start_loc")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Input("query_end_loc")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
// -------------------- Outputs --------------------
this->Output("out_input_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_positions")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_is_rejected_token_mask")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_is_masked_token_mask")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_new_token_indices")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("out_hidden_state_mapping")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
// -------------------- Attributes --------------------
this->Attr("padding_token_id").Int();
this->Attr("parallel_drafting_token_id").Int();
this->Attr("num_padding_slots_per_request").Int();
this->Attr("shift_input_ids").Bool();
this->Attr("total_input_tokens").Int();
// -------------------- Platform --------------------
this->AICore().AddConfig("ascend910b");
}
};
OP_ADD(CopyAndExpandEagleInputs);
} // namespace ops

View File

@@ -0,0 +1,107 @@
/**
* @file copy_and_expand_eagle_inputs_infershape.cpp
* @brief InferShape and InferDataType for CopyAndExpandEagleInputs
*/
#include "register/op_def_registry.h"
#include "log/ops_log.h"
#define unlikely(x) __builtin_expect((x), 0)
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
do { \
if (unlikely((ptr) == nullptr)) { \
const char* name = (unlikely(((context) == nullptr) || (context)->GetNodeName() == nullptr)) ? \
"nil" : \
(context)->GetNodeName(); \
OPS_LOG_E(name, "%s is nullptr!", #ptr); \
return ge::GRAPH_FAILED; \
} \
} while (0)
static constexpr int IDX_TARGET_TOKEN_IDS = 0;
static constexpr int IDX_TARGET_POSITIONS = 1;
static constexpr int IDX_NEXT_TOKEN_IDS = 2;
static constexpr int IDX_QUERY_START_LOC = 3;
static constexpr int IDX_QUERY_END_LOC = 4;
static constexpr int OUT_INPUT_IDS = 0;
static constexpr int OUT_POSITIONS = 1;
static constexpr int OUT_REJECTED_MASK = 2;
static constexpr int OUT_MASKED_MASK = 3;
static constexpr int OUT_NEW_TOKEN_INDICES = 4;
static constexpr int OUT_HIDDEN_STATE_MAPPING = 5;
static constexpr int OUTPUT_NUM = 6;
static constexpr int ATTR_NUM_PADDING_SLOTS = 2;
static constexpr int ATTR_TOTAL_INPUT_TOKENS = 4;
using namespace ge;
namespace ops {
static ge::graphStatus InferShape4CopyAndExpandEagleInputs(gert::InferShapeContext* context)
{
// Get input shapes
const gert::Shape* targetTokenIdsShape = context->GetInputShape(IDX_TARGET_TOKEN_IDS);
OP_CHECK_NULL_WITH_CONTEXT(context, targetTokenIdsShape);
const gert::Shape* queryStartLocShape = context->GetInputShape(IDX_QUERY_START_LOC);
OP_CHECK_NULL_WITH_CONTEXT(context, queryStartLocShape);
// Derive dimensions from input shapes
int64_t totalInputTokens = targetTokenIdsShape->GetDim(0);
int64_t numReqs = queryStartLocShape->GetDim(0) - 1;
// Get num_padding_slots_per_request from attributes
auto attrs = context->GetAttrs();
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
int64_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int64_t>(ATTR_NUM_PADDING_SLOTS));
// Compute total_draft_tokens = total_input_tokens + (num_padding_slots_per_request - 1) * num_reqs
int64_t totalDraftTokens = totalInputTokens + (numPaddingSlotsPerReq - 1) * numReqs;
// Get and validate all output shapes
gert::Shape* outShapes[OUTPUT_NUM];
for (int i = 0; i < OUTPUT_NUM; ++i) {
outShapes[i] = context->GetOutputShape(i);
OP_CHECK_NULL_WITH_CONTEXT(context, outShapes[i]);
outShapes[i]->SetDimNum(1);
}
// out_input_ids, out_positions, out_rejected_mask, out_masked_mask: [total_draft_tokens]
outShapes[OUT_INPUT_IDS]->SetDim(0, totalDraftTokens);
outShapes[OUT_POSITIONS]->SetDim(0, totalDraftTokens);
outShapes[OUT_REJECTED_MASK]->SetDim(0, totalDraftTokens);
outShapes[OUT_MASKED_MASK]->SetDim(0, totalDraftTokens);
// out_new_token_indices: [num_reqs * num_padding_slots_per_request]
outShapes[OUT_NEW_TOKEN_INDICES]->SetDim(0, numReqs * numPaddingSlotsPerReq);
// out_hidden_state_mapping: [total_input_tokens]
outShapes[OUT_HIDDEN_STATE_MAPPING]->SetDim(0, totalInputTokens);
return GRAPH_SUCCESS;
}
static ge::graphStatus InferDataType4CopyAndExpandEagleInputs(gert::InferDataTypeContext* context)
{
// out_input_ids: INT32
context->SetOutputDataType(OUT_INPUT_IDS, DT_INT32);
// out_positions: INT32
context->SetOutputDataType(OUT_POSITIONS, DT_INT32);
// out_is_rejected_token_mask: INT8
context->SetOutputDataType(OUT_REJECTED_MASK, DT_INT8);
// out_is_masked_token_mask: INT8
context->SetOutputDataType(OUT_MASKED_MASK, DT_INT8);
// out_new_token_indices: INT32
context->SetOutputDataType(OUT_NEW_TOKEN_INDICES, DT_INT32);
// out_hidden_state_mapping: INT32
context->SetOutputDataType(OUT_HIDDEN_STATE_MAPPING, DT_INT32);
return GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(CopyAndExpandEagleInputs)
.InferShape(InferShape4CopyAndExpandEagleInputs)
.InferDataType(InferDataType4CopyAndExpandEagleInputs);
} // namespace ops
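
For readability, here is the same shape arithmetic as a small Python sketch (names are illustrative, not part of the PR):

```python
# Python sketch of the InferShape arithmetic above (illustrative only).
def infer_output_sizes(total_input_tokens: int, query_start_loc_len: int,
                       num_padding_slots_per_request: int) -> dict[str, int]:
    num_reqs = query_start_loc_len - 1  # query_start_loc holds num_reqs + 1 offsets
    total_draft_tokens = (total_input_tokens
                          + (num_padding_slots_per_request - 1) * num_reqs)
    return {
        "out_input_ids": total_draft_tokens,
        "out_positions": total_draft_tokens,
        "out_is_rejected_token_mask": total_draft_tokens,
        "out_is_masked_token_mask": total_draft_tokens,
        "out_new_token_indices": num_reqs * num_padding_slots_per_request,
        "out_hidden_state_mapping": total_input_tokens,
    }

# e.g. 100 input tokens, 4 requests (query_start_loc has 5 entries), 8 slots:
# total_draft_tokens = 100 + (8 - 1) * 4 = 128
print(infer_output_sizes(100, 5, 8))
```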

View File

@@ -0,0 +1,121 @@
/**
* @file copy_and_expand_eagle_inputs_tiling.cpp
* @brief CopyAndExpandEagleInputs TilingFunc implementation
*/
#include "copy_and_expand_eagle_inputs_tiling.h"
#include "register/op_def_registry.h"
#include "log/ops_log.h"
#include <algorithm>
namespace optiling {
static void GetCompileParameters(
gert::TilingContext* context, uint32_t& coreNum)
{
auto ptrCompileInfo = reinterpret_cast<const CopyAndExpandEagleInputsCompileInfo*>(context->GetCompileInfo());
if (ptrCompileInfo == nullptr) {
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
coreNum = ascendcPlatform.GetCoreNum();
} else {
coreNum = ptrCompileInfo->totalCoreNum;
}
}
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
OPS_LOG_I(context, "Enter TilingFunc for CopyAndExpandEagleInputs");
OPS_LOG_D(context, "TilingFunc running.");
// ========== 1. Get hardware core count ==========
uint32_t coreNum;
GetCompileParameters(context, coreNum);
// ========== 2. Derive num_reqs from query_start_loc shape ==========
// query_start_loc is the 4th input (index 3), shape [num_reqs + 1]
auto queryStartLocShape = context->GetInputShape(3);
uint32_t numReqs = 0;
if (queryStartLocShape != nullptr &&
queryStartLocShape->GetStorageShape().GetDimNum() > 0) {
int64_t dim0 = queryStartLocShape->GetStorageShape().GetDim(0);
numReqs = (dim0 > 1) ? static_cast<uint32_t>(dim0 - 1) : 0;
}
// ========== 3. Get operator attributes ==========
auto attrs = context->GetAttrs();
int32_t paddingTokenId = *(attrs->GetAttrPointer<int32_t>(0));
int32_t parallelDraftingTokenId = *(attrs->GetAttrPointer<int32_t>(1));
int32_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int32_t>(2));
bool shiftInputIds = *(attrs->GetAttrPointer<bool>(3));
int32_t totalInputTokens = *(attrs->GetAttrPointer<int32_t>(4));
// ========== 4. Compute core distribution ==========
uint32_t usedCoreNum = std::min(coreNum, numReqs);
if (usedCoreNum == 0) {
usedCoreNum = 1;
}
uint32_t reqsPerCore = numReqs / usedCoreNum;
uint32_t remainderReqs = numReqs % usedCoreNum;
// ========== 5. Set tiling_key ==========
context->SetTilingKey(1);
// ========== 6. Get output shape ==========
uint32_t totalDraftTokens = 0;
auto outShape = context->GetOutputShape(0);
if (outShape != nullptr &&
outShape->GetStorageShape().GetDimNum() > 0) {
totalDraftTokens = static_cast<uint32_t>(outShape->GetStorageShape().GetDim(0));
}
// ========== 7. Fill TilingData ==========
CopyAndExpandEagleInputsTilingData tiling;
tiling.set_usedCoreNum(usedCoreNum);
tiling.set_numReqs(numReqs);
tiling.set_reqsPerCore(reqsPerCore);
tiling.set_remainderReqs(remainderReqs);
tiling.set_paddingTokenId(paddingTokenId);
tiling.set_parallelDraftingTokenId(parallelDraftingTokenId);
tiling.set_numPaddingSlotsPerReq(static_cast<uint32_t>(numPaddingSlotsPerReq));
tiling.set_totalInputTokens(static_cast<uint32_t>(totalInputTokens));
tiling.set_shiftInputIds(shiftInputIds ? 1u : 0u);
tiling.set_totalDraftTokens(totalDraftTokens);
tiling.SaveToBuffer(
context->GetRawTilingData()->GetData(),
context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
// ========== 8. Set block_dim ==========
context->SetBlockDim(usedCoreNum);
OPS_LOG_I(context, "Block Dim: %u", usedCoreNum);
OPS_LOG_I(context,
"numReqs: %u, reqsPerCore: %u, remainderReqs: %u, totalInputTokens: %d, totalDraftTokens: %u",
numReqs, reqsPerCore, remainderReqs, totalInputTokens, totalDraftTokens);
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus TilingPrepare4CopyAndExpandEagleInputs(gert::TilingParseContext* context)
{
OPS_LOG_D(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
OPS_LOG_I(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
auto compileInfo = context->GetCompiledInfo<CopyAndExpandEagleInputsCompileInfo>();
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
auto platformInfo = context->GetPlatformInfo();
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
compileInfo->totalCoreNum = ascendcPlatform.GetCoreNum();
return ge::GRAPH_SUCCESS;
}
IMPL_OP_OPTILING(CopyAndExpandEagleInputs)
.Tiling(TilingFunc)
.TilingParse<CopyAndExpandEagleInputsCompileInfo>(TilingPrepare4CopyAndExpandEagleInputs);
} // namespace optiling
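
The request-to-core distribution above (a base share per core, plus one extra request for each of the first remainderReqs cores) can be summarized with a short Python sketch (illustrative, not part of the PR):

```python
# Python sketch of the core-distribution logic in TilingFunc (illustrative).
def split_requests(num_reqs: int, core_num: int) -> list[tuple[int, int]]:
    used_core_num = max(min(core_num, num_reqs), 1)  # never use zero cores
    reqs_per_core, remainder = divmod(num_reqs, used_core_num)
    # The first `remainder` cores each process reqs_per_core + 1 requests.
    ranges, start = [], 0
    for core in range(used_core_num):
        n = reqs_per_core + (1 if core < remainder else 0)
        ranges.append((start, start + n))
        start += n
    return ranges

# e.g. 10 requests on 4 cores -> [(0, 3), (3, 6), (6, 8), (8, 10)]
print(split_requests(10, 4))
```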

View File

@@ -0,0 +1,37 @@
#ifndef COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
#define COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
#include "register/tilingdata_base.h"
#include "error_log.h"
#include "register/op_impl_registry.h"
#include "tiling/platform/platform_ascendc.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(CopyAndExpandEagleInputsTilingData)
// ---- Core-splitting parameters ----
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum); // number of cores actually used
TILING_DATA_FIELD_DEF(uint32_t, numReqs); // total number of requests
TILING_DATA_FIELD_DEF(uint32_t, reqsPerCore); // base number of requests per core
TILING_DATA_FIELD_DEF(uint32_t, remainderReqs); // remainder (the first `remainderReqs` cores each handle one extra request)
// ---- Operator attributes ----
TILING_DATA_FIELD_DEF(int32_t, paddingTokenId); // padding token ID
TILING_DATA_FIELD_DEF(int32_t, parallelDraftingTokenId); // parallel-drafting token ID
TILING_DATA_FIELD_DEF(uint32_t, numPaddingSlotsPerReq); // number of padding slots per request
TILING_DATA_FIELD_DEF(uint32_t, totalInputTokens); // total number of input tokens (used for clamping)
TILING_DATA_FIELD_DEF(uint32_t, shiftInputIds); // 0 = false, 1 = true
// ---- Output sizes ----
TILING_DATA_FIELD_DEF(uint32_t, totalDraftTokens); // total number of output tokens
END_TILING_DATA_DEF;
struct CopyAndExpandEagleInputsCompileInfo {
uint32_t totalCoreNum = 0;
};
REGISTER_TILING_DATA_CLASS(CopyAndExpandEagleInputs, CopyAndExpandEagleInputsTilingData)
} // namespace optiling
#endif // COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H

View File

@@ -0,0 +1,386 @@
/**
 * CopyAndExpandEagleInputs kernel implementation (DataCopy version).
 *
 * Multi-core strategy:
 * All GM reads/writes go through DataCopy (GlobalTensor::SetValue/GetValue
 * is never used to access GM directly). Data is assembled in UB (LocalTensor)
 * via SetValue/GetValue, then DataCopy'd back to GM.
 * Alignment handling follows the DataCopyCustom pattern of built-in CANN operators.
 */
#include "kernel_operator.h"
using namespace AscendC;
// ONE_BLK_SIZE comes from AscendC namespace (32 bytes per block)
class CopyAndExpandEagleInputsKernel {
public:
__aicore__ inline CopyAndExpandEagleInputsKernel() {}
__aicore__ inline void Init(GM_ADDR targetTokenIds, GM_ADDR targetPositions,
GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
GM_ADDR queryEndLoc,
GM_ADDR outInputIds, GM_ADDR outPositions,
GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
const CopyAndExpandEagleInputsTilingData* tilingData)
{
usedCoreNum = tilingData->usedCoreNum;
numReqs = tilingData->numReqs;
reqsPerCore = tilingData->reqsPerCore;
remainderReqs = tilingData->remainderReqs;
paddingTokenId = tilingData->paddingTokenId;
parallelDraftingTokenId = tilingData->parallelDraftingTokenId;
numPaddingSlotsPerReq = tilingData->numPaddingSlotsPerReq;
totalInputTokens = tilingData->totalInputTokens;
totalDraftTokens = tilingData->totalDraftTokens;
uint32_t coreId = GetBlockIdx();
if (coreId < remainderReqs) {
myStartReq = coreId * (reqsPerCore + 1);
myNumReqs = reqsPerCore + 1;
} else {
myStartReq = remainderReqs * (reqsPerCore + 1) + (coreId - remainderReqs) * reqsPerCore;
myNumReqs = reqsPerCore;
}
// Bind GM tensors
gmTargetTokenIds.SetGlobalBuffer((__gm__ int32_t*)targetTokenIds, totalInputTokens);
gmTargetPositions.SetGlobalBuffer((__gm__ int32_t*)targetPositions, totalInputTokens);
gmNextTokenIds.SetGlobalBuffer((__gm__ int32_t*)nextTokenIds, numReqs);
gmQueryStartLoc.SetGlobalBuffer((__gm__ int32_t*)queryStartLoc, numReqs + 1);
gmQueryEndLoc.SetGlobalBuffer((__gm__ int32_t*)queryEndLoc, numReqs);
gmOutInputIds.SetGlobalBuffer((__gm__ int32_t*)outInputIds, totalDraftTokens);
gmOutPositions.SetGlobalBuffer((__gm__ int32_t*)outPositions, totalDraftTokens);
gmOutIsRejectedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsRejectedTokenMask, totalDraftTokens);
gmOutIsMaskedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsMaskedTokenMask, totalDraftTokens);
gmOutNewTokenIndices.SetGlobalBuffer((__gm__ int32_t*)outNewTokenIndices, numPaddingSlotsPerReq * numReqs);
gmOutHiddenStateMapping.SetGlobalBuffer((__gm__ int32_t*)outHiddenStateMapping, totalInputTokens);
// Allocate UB buffers -- each TBuf base address is automatically 32-byte aligned.
// Metadata gets its own TBuf each, to avoid unaligned UB addresses.
uint32_t metaAligned = AlignUp((myNumReqs + 1) * sizeof(int32_t), ONE_BLK_SIZE);
pipe.InitBuffer(qsBuf, metaAligned);
pipe.InitBuffer(qeBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
pipe.InitBuffer(ntBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
// I/O buffers
constexpr uint32_t MAX_PER_REQ = 4096;
pipe.InitBuffer(inputBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
pipe.InitBuffer(outIdsBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
pipe.InitBuffer(outPosBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
pipe.InitBuffer(outRejBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
pipe.InitBuffer(outMskBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
pipe.InitBuffer(ntiBuf, AlignUp(64 * sizeof(int32_t), ONE_BLK_SIZE));
pipe.InitBuffer(hsmBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
// DataCopy the metadata into their respective UB buffers
if (myNumReqs > 0) {
LocalTensor<int32_t> lqs = qsBuf.Get<int32_t>();
DataCopyIn(lqs, gmQueryStartLoc, (int32_t)myStartReq, (int32_t)(myNumReqs + 1));
LocalTensor<int32_t> lqe = qeBuf.Get<int32_t>();
DataCopyIn(lqe, gmQueryEndLoc, (int32_t)myStartReq, (int32_t)myNumReqs);
LocalTensor<int32_t> lnt = ntBuf.Get<int32_t>();
DataCopyIn(lnt, gmNextTokenIds, (int32_t)myStartReq, (int32_t)myNumReqs);
}
}
__aicore__ inline void ProcessShiftFalse()
{
for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
ProcessOneRequestShiftFalse(myStartReq + rLocal, rLocal);
}
}
__aicore__ inline void ProcessShiftTrue()
{
for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
ProcessOneRequestShiftTrue(myStartReq + rLocal, rLocal);
}
}
private:
// ============================================================
// AlignUp helper
// ============================================================
static __aicore__ inline uint32_t AlignUp(uint32_t x, uint32_t a)
{
return (x + a - 1) / a * a;
}
// ============================================================
// GM → UB: standard DataCopy; count is automatically rounded up to block alignment.
// The extra elements read into UB are never used, so the over-read is harmless.
// ============================================================
__aicore__ inline void DataCopyIn(LocalTensor<int32_t>& dst,
GlobalTensor<int32_t>& src,
int32_t gmOffset, int32_t count)
{
if (count <= 0) return;
constexpr int32_t ELEMS_PER_BLK = ONE_BLK_SIZE / (int32_t)sizeof(int32_t); // 8
int32_t aligned = (count + ELEMS_PER_BLK - 1) / ELEMS_PER_BLK * ELEMS_PER_BLK;
DataCopy(dst, src[gmOffset], aligned);
pipe_barrier(PIPE_ALL);
}
// ============================================================
// UB → GM: DataCopyPad + DataCopyExtParams (C220 supports arbitrary byte counts).
// Writes exactly `count` elements without clobbering adjacent data.
// ============================================================
__aicore__ inline void DataCopyOut_int32(GlobalTensor<int32_t>& dst,
LocalTensor<int32_t>& src,
int32_t gmOffset, int32_t count)
{
if (count <= 0) return;
uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int32_t));
pipe_barrier(PIPE_ALL);
DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
pipe_barrier(PIPE_ALL);
}
__aicore__ inline void DataCopyOut_int8(GlobalTensor<int8_t>& dst,
LocalTensor<int8_t>& src,
int32_t gmOffset, int32_t count)
{
if (count <= 0) return;
uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int8_t));
pipe_barrier(PIPE_ALL);
DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
pipe_barrier(PIPE_ALL);
}
// ============================================================
// Metadata reads (from their respective UB buffers)
// ============================================================
__aicore__ inline int32_t ReadQS(uint32_t rLocal) {
return qsBuf.Get<int32_t>().GetValue(rLocal);
}
__aicore__ inline int32_t ReadNextQS(uint32_t rLocal) {
return qsBuf.Get<int32_t>().GetValue(rLocal + 1);
}
__aicore__ inline int32_t ReadQE(uint32_t rLocal) {
return qeBuf.Get<int32_t>().GetValue(rLocal);
}
__aicore__ inline int32_t ReadNT(uint32_t rLocal) {
return ntBuf.Get<int32_t>().GetValue(rLocal);
}
// ============================================================
// shift_input_ids = false
// ============================================================
__aicore__ inline void ProcessOneRequestShiftFalse(uint32_t r, uint32_t rLocal)
{
int32_t queryStart = ReadQS(rLocal);
int32_t nextQueryStart = ReadNextQS(rLocal);
int32_t queryEnd = ReadQE(rLocal);
int32_t numRejected = nextQueryStart - queryEnd - 1;
if (numRejected < 0) numRejected = 0;
int32_t numValid = queryEnd - queryStart + 1;
if (numValid < 0) numValid = 0;
int32_t outputStart = queryStart + (int32_t)r * (int32_t)numPaddingSlotsPerReq;
int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
// Read this request's input tokens into UB
int32_t numInputTokensForReq = nextQueryStart - queryStart;
LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
if (numInputTokensForReq > 0) {
DataCopyIn(localInput, gmTargetTokenIds, queryStart, numInputTokensForReq);
}
// Read the starting position
LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
int32_t startPos = localTmpPos.GetValue(0);
int32_t nextTokenId = ReadNT(rLocal);
// Build the outputs in UB
LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
for (int32_t j = 0; j < numValid; j++) {
int32_t inIdx = j;
if (inIdx >= numInputTokensForReq) inIdx = numInputTokensForReq - 1;
lIds.SetValue(j, localInput.GetValue(inIdx));
lPos.SetValue(j, startPos + j);
lRej.SetValue(j, (int8_t)0);
lMsk.SetValue(j, (int8_t)0);
}
// Bonus
lIds.SetValue(numValid, nextTokenId);
lPos.SetValue(numValid, startPos + numValid);
lRej.SetValue(numValid, (int8_t)0);
lMsk.SetValue(numValid, (int8_t)0);
// Parallel Draft
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
int32_t j = numValid + k;
lIds.SetValue(j, parallelDraftingTokenId);
lPos.SetValue(j, startPos + j);
lRej.SetValue(j, (int8_t)0);
lMsk.SetValue(j, (int8_t)1);
}
// Rejected
for (int32_t k = 0; k < numRejected; k++) {
int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
lIds.SetValue(j, paddingTokenId);
lPos.SetValue(j, (int32_t)0);
lRej.SetValue(j, (int8_t)1);
lMsk.SetValue(j, (int8_t)0);
}
// UB → GM
DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
// NTI
LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
lNti.SetValue(0, outputStart + numValid);
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
lNti.SetValue(k, outputStart + numValid + k);
}
int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
}
// ============================================================
// shift_input_ids = true
// ============================================================
__aicore__ inline void ProcessOneRequestShiftTrue(uint32_t r, uint32_t rLocal)
{
int32_t queryStart = ReadQS(rLocal);
int32_t nextQueryStart = ReadNextQS(rLocal);
int32_t queryEnd = ReadQE(rLocal);
int32_t numRejected = nextQueryStart - queryEnd - 1;
if (numRejected < 0) numRejected = 0;
int32_t numValid = queryEnd - queryStart;
if (numValid < 0) numValid = 0;
int32_t outputStart = queryStart + (int32_t)r * ((int32_t)numPaddingSlotsPerReq - 1);
int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
int32_t numInputTokensForReq = nextQueryStart - queryStart;
LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
int32_t readStart = queryStart + 1;
int32_t readCount = numValid;
if (readStart + readCount > (int32_t)totalInputTokens) {
readCount = (int32_t)totalInputTokens - readStart;
if (readCount < 0) readCount = 0;
}
if (readCount > 0) {
DataCopyIn(localInput, gmTargetTokenIds, readStart, readCount);
}
LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
int32_t startPos = localTmpPos.GetValue(0);
int32_t nextTokenId = ReadNT(rLocal);
LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
for (int32_t j = 0; j < numValid; j++) {
int32_t inIdx = j;
if (inIdx >= readCount && readCount > 0) inIdx = readCount - 1;
lIds.SetValue(j, readCount > 0 ? localInput.GetValue(inIdx) : (int32_t)0);
lPos.SetValue(j, startPos + j);
lRej.SetValue(j, (int8_t)0);
lMsk.SetValue(j, (int8_t)0);
}
lIds.SetValue(numValid, nextTokenId);
lPos.SetValue(numValid, startPos + numValid);
lRej.SetValue(numValid, (int8_t)0);
lMsk.SetValue(numValid, (int8_t)0);
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
int32_t j = numValid + k;
lIds.SetValue(j, parallelDraftingTokenId);
lPos.SetValue(j, startPos + j);
lRej.SetValue(j, (int8_t)0);
lMsk.SetValue(j, (int8_t)1);
}
for (int32_t k = 0; k < numRejected; k++) {
int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
lIds.SetValue(j, paddingTokenId);
lPos.SetValue(j, (int32_t)0);
lRej.SetValue(j, (int8_t)1);
lMsk.SetValue(j, (int8_t)0);
}
DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
lNti.SetValue(0, outputStart + numValid);
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
lNti.SetValue(k, outputStart + numValid + k);
}
int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
// hidden_state_mapping
LocalTensor<int32_t> lHsm = hsmBuf.Get<int32_t>();
for (int32_t j = 0; j < numInputTokensForReq; j++) {
lHsm.SetValue(j, outputStart + j);
}
DataCopyOut_int32(gmOutHiddenStateMapping, lHsm, queryStart, numInputTokensForReq);
}
private:
GlobalTensor<int32_t> gmTargetTokenIds, gmTargetPositions, gmNextTokenIds;
GlobalTensor<int32_t> gmQueryStartLoc, gmQueryEndLoc;
GlobalTensor<int32_t> gmOutInputIds, gmOutPositions;
GlobalTensor<int8_t> gmOutIsRejectedTokenMask, gmOutIsMaskedTokenMask;
GlobalTensor<int32_t> gmOutNewTokenIndices, gmOutHiddenStateMapping;
uint32_t usedCoreNum, numReqs, reqsPerCore, remainderReqs;
int32_t paddingTokenId, parallelDraftingTokenId;
uint32_t numPaddingSlotsPerReq, totalInputTokens, totalDraftTokens;
uint32_t myStartReq, myNumReqs;
TPipe pipe;
TBuf<QuePosition::VECCALC> qsBuf, qeBuf, ntBuf;
TBuf<QuePosition::VECCALC> inputBuf, outIdsBuf, outPosBuf;
TBuf<QuePosition::VECCALC> outRejBuf, outMskBuf, ntiBuf, hsmBuf;
};
extern "C" __global__ __aicore__ void copy_and_expand_eagle_inputs(
GM_ADDR targetTokenIds, GM_ADDR targetPositions,
GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
GM_ADDR queryEndLoc,
GM_ADDR outInputIds, GM_ADDR outPositions,
GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
GM_ADDR workspace, GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
if (GetBlockIdx() >= tilingData.usedCoreNum) {
return;
}
if (TILING_KEY_IS(1)) {
CopyAndExpandEagleInputsKernel op;
op.Init(targetTokenIds, targetPositions, nextTokenIds, queryStartLoc, queryEndLoc,
outInputIds, outPositions, outIsRejectedTokenMask, outIsMaskedTokenMask,
outNewTokenIndices, outHiddenStateMapping, &tilingData);
if (tilingData.shiftInputIds == 0) {
op.ProcessShiftFalse();
} else {
op.ProcessShiftTrue();
}
}
}
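
To summarize the per-request output layout the kernel builds (accepted tokens, then the bonus token, then masked parallel-draft slots, then rejected padding), here is a simplified Python mirror of the shift_input_ids=false path; the authoritative golden reference is in the E2E test below:

```python
# Simplified Python mirror of ProcessOneRequestShiftFalse (illustrative).
def expand_one_request(accepted_tokens, start_pos, next_token_id,
                       num_rejected, num_padding_slots,
                       parallel_draft_id, padding_id):
    ids, pos, rejected, masked = [], [], [], []
    # 1) Accepted tokens for this request.
    for j, tok in enumerate(accepted_tokens):
        ids.append(tok); pos.append(start_pos + j)
        rejected.append(0); masked.append(0)
    n = len(accepted_tokens)
    # 2) Bonus token sampled by the target model.
    ids.append(next_token_id); pos.append(start_pos + n)
    rejected.append(0); masked.append(0)
    # 3) Parallel-draft placeholder slots (masked; filled by the draft model).
    for k in range(1, num_padding_slots):
        ids.append(parallel_draft_id); pos.append(start_pos + n + k)
        rejected.append(0); masked.append(1)
    # 4) Rejected slots become flagged padding.
    for _ in range(num_rejected):
        ids.append(padding_id); pos.append(0)
        rejected.append(1); masked.append(0)
    return ids, pos, rejected, masked

# One request: 3 accepted tokens, 2 rejected slots, 4 padding slots.
print(expand_one_request([11, 12, 13], 5, 99, 2, 4, -2, 0))
```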

View File

@@ -597,6 +597,41 @@ void transpose_kv_cache_by_block(
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
npu_copy_and_expand_eagle_inputs(
const at::Tensor &target_token_ids,
const at::Tensor &target_positions,
const at::Tensor &next_token_ids,
const at::Tensor &query_start_loc,
const at::Tensor &query_end_loc,
int64_t padding_token_id,
int64_t parallel_drafting_token_id,
int64_t num_padding_slots_per_request,
bool shift_input_ids,
int64_t total_draft_tokens)
{
int64_t total_input_tokens = target_token_ids.size(0);
int64_t num_reqs = query_start_loc.size(0) - 1;
auto device = target_token_ids.device();
at::Tensor out_input_ids = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
at::Tensor out_positions = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
at::Tensor out_is_rejected_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
at::Tensor out_is_masked_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
at::Tensor out_new_token_indices = at::empty({num_reqs * num_padding_slots_per_request}, at::dtype(at::kInt).device(device));
at::Tensor out_hidden_state_mapping = at::empty({total_input_tokens}, at::dtype(at::kInt).device(device));
EXEC_NPU_CMD(aclnnCopyAndExpandEagleInputs,
target_token_ids, target_positions, next_token_ids, query_start_loc, query_end_loc,
padding_token_id, parallel_drafting_token_id, num_padding_slots_per_request,
shift_input_ids, total_input_tokens,
out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
out_new_token_indices, out_hidden_state_mapping);
return {out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
out_new_token_indices, out_hidden_state_mapping};
}
at::Tensor causal_conv1d_fn(
    const at::Tensor& mixed_qkv_non_spec_T,
    const at::Tensor& conv_weights,
@@ -849,6 +884,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
        "transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
    );
    ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
ops.def(
"npu_copy_and_expand_eagle_inputs(Tensor target_token_ids, Tensor target_positions, "
"Tensor next_token_ids, Tensor query_start_loc, Tensor query_end_loc, "
"int padding_token_id, int parallel_drafting_token_id, int num_padding_slots_per_request, "
"bool shift_input_ids, int total_draft_tokens) -> "
"(Tensor out_input_ids, Tensor out_positions, Tensor out_is_rejected_token_mask, "
"Tensor out_is_masked_token_mask, Tensor out_new_token_indices, Tensor out_hidden_state_mapping)"
);
ops.impl("npu_copy_and_expand_eagle_inputs", torch::kPrivateUse1, &vllm_ascend::npu_copy_and_expand_eagle_inputs);
    // causal_conv1d_fn
    ops.def(
        "causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "

View File

@@ -458,6 +458,33 @@ void transpose_kv_cache_by_block_meta(
    return;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
npu_copy_and_expand_eagle_inputs_meta(
const at::Tensor &target_token_ids,
const at::Tensor &target_positions,
const at::Tensor &next_token_ids,
const at::Tensor &query_start_loc,
const at::Tensor &query_end_loc,
int64_t padding_token_id,
int64_t parallel_drafting_token_id,
int64_t num_padding_slots_per_request,
bool shift_input_ids,
int64_t total_draft_tokens)
{
int64_t total_input_tokens = target_token_ids.size(0);
int64_t num_reqs = query_start_loc.size(0) - 1;
at::Tensor out_input_ids = at::empty({total_draft_tokens}, target_token_ids.options());
at::Tensor out_positions = at::empty({total_draft_tokens}, target_token_ids.options());
at::Tensor out_is_rejected_token_mask = at::empty({total_draft_tokens}, target_token_ids.options().dtype(at::kChar));
at::Tensor out_is_masked_token_mask = at::empty({total_draft_tokens}, target_token_ids.options().dtype(at::kChar));
at::Tensor out_new_token_indices = at::empty({num_reqs * num_padding_slots_per_request}, target_token_ids.options());
at::Tensor out_hidden_state_mapping = at::empty({total_input_tokens}, target_token_ids.options());
return {out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
out_new_token_indices, out_hidden_state_mapping};
}
at::Tensor causal_conv1d_fn_meta(
    const at::Tensor& mixed_qkv_non_spec_T,
    const at::Tensor& conv_weights,
@@ -543,6 +570,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
    ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
    // transpose_kv_cache_by_block
    ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
// CopyAndExpandEagleInputs
ops.impl("npu_copy_and_expand_eagle_inputs", &vllm_ascend::meta::npu_copy_and_expand_eagle_inputs_meta);
    // causal_conv1d_fn
    ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
    // moe_grouped_matmul

View File

@@ -0,0 +1,471 @@
"""E2E accuracy test for CopyAndExpandEagleInputs custom operator.
Tests the Ascend C kernel against a CPU golden reference implementation
with parametrized test cases covering various configurations.
"""
import numpy as np
import pytest
import torch
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
SEED = 42
# ---------------------------------------------------------------------------
# Golden reference (CPU, pure Python/NumPy)
# ---------------------------------------------------------------------------
def golden_copy_and_expand(
target_token_ids: np.ndarray,
target_positions: np.ndarray,
next_token_ids: np.ndarray,
query_start_loc: np.ndarray,
query_end_loc: np.ndarray,
padding_token_id: int,
parallel_drafting_token_id: int,
num_padding_slots: int,
shift_input_ids: bool,
):
"""CPU golden reference for CopyAndExpandEagleInputs.
Returns:
(out_input_ids, out_positions, out_is_rejected_token_mask,
out_is_masked_token_mask, out_new_token_indices,
out_hidden_state_mapping)
"""
num_reqs = len(next_token_ids)
# Compute total_draft_tokens
total_draft_tokens = 0
for r in range(num_reqs):
qs = query_start_loc[r]
nqs = query_start_loc[r + 1]
qe = query_end_loc[r]
num_rejected = max(nqs - qe - 1, 0)
if shift_input_ids:
num_valid = max(qe - qs, 0)
else:
num_valid = max(qe - qs + 1, 0)
total_draft_tokens += num_valid + num_padding_slots + num_rejected
out_ids = np.zeros(total_draft_tokens, dtype=np.int32)
out_pos = np.zeros(total_draft_tokens, dtype=np.int32)
out_rej = np.zeros(total_draft_tokens, dtype=np.int8)
out_msk = np.zeros(total_draft_tokens, dtype=np.int8)
out_nti = np.zeros(num_reqs * num_padding_slots, dtype=np.int32)
total_input_tokens = len(target_token_ids)
out_hsm = np.zeros(total_input_tokens, dtype=np.int32)
for r in range(num_reqs):
qs = query_start_loc[r]
nqs = query_start_loc[r + 1]
qe = query_end_loc[r]
num_rejected = max(nqs - qe - 1, 0)
if shift_input_ids:
num_valid = max(qe - qs, 0)
output_start = qs + r * (num_padding_slots - 1)
else:
num_valid = max(qe - qs + 1, 0)
output_start = qs + r * num_padding_slots
start_pos = target_positions[qs]
next_token_id = next_token_ids[r]
# Valid region
if shift_input_ids:
read_start = qs + 1
read_count = min(num_valid, total_input_tokens - read_start)
if read_count < 0:
read_count = 0
for j in range(num_valid):
idx = min(j, read_count - 1) if read_count > 0 else 0
out_ids[output_start + j] = target_token_ids[read_start + idx] if read_count > 0 else 0
out_pos[output_start + j] = start_pos + j
out_rej[output_start + j] = 0
out_msk[output_start + j] = 0
else:
num_input = nqs - qs
for j in range(num_valid):
idx = min(j, num_input - 1)
out_ids[output_start + j] = target_token_ids[qs + idx]
out_pos[output_start + j] = start_pos + j
out_rej[output_start + j] = 0
out_msk[output_start + j] = 0
# Bonus token
out_ids[output_start + num_valid] = next_token_id
out_pos[output_start + num_valid] = start_pos + num_valid
out_rej[output_start + num_valid] = 0
out_msk[output_start + num_valid] = 0
# Parallel draft tokens
for k in range(1, num_padding_slots):
j = num_valid + k
out_ids[output_start + j] = parallel_drafting_token_id
out_pos[output_start + j] = start_pos + j
out_rej[output_start + j] = 0
out_msk[output_start + j] = 1
# Rejected tokens
for k in range(num_rejected):
j = num_valid + num_padding_slots + k
out_ids[output_start + j] = padding_token_id
out_pos[output_start + j] = 0
out_rej[output_start + j] = 1
out_msk[output_start + j] = 0
# New token indices
for k in range(num_padding_slots):
out_nti[r * num_padding_slots + k] = output_start + num_valid + k
# Hidden state mapping (shift_input_ids=true only)
if shift_input_ids:
num_input = nqs - qs
for j in range(num_input):
out_hsm[qs + j] = output_start + j
return out_ids, out_pos, out_rej, out_msk, out_nti, out_hsm
# ---------------------------------------------------------------------------
# NPU operator wrapper
# ---------------------------------------------------------------------------
def npu_op_exec(
target_token_ids, target_positions, next_token_ids,
query_start_loc, query_end_loc,
padding_token_id, parallel_drafting_token_id,
num_padding_slots, shift_input_ids, total_draft_tokens,
):
"""Execute the custom Ascend NPU operator."""
result = torch.ops._C_ascend.npu_copy_and_expand_eagle_inputs(
target_token_ids.to(torch.int32).npu(),
target_positions.to(torch.int32).npu(),
next_token_ids.to(torch.int32).npu(),
query_start_loc.to(torch.int32).npu(),
query_end_loc.to(torch.int32).npu(),
padding_token_id,
parallel_drafting_token_id,
num_padding_slots,
shift_input_ids,
total_draft_tokens,
)
return tuple(t.cpu() for t in result)
# ---------------------------------------------------------------------------
# Test case generator
# ---------------------------------------------------------------------------
def generate_test_case(rng, num_reqs, num_padding_slots, shift_input_ids,
min_tokens_per_req=2, max_tokens_per_req=64,
max_rejected_per_req=5):
"""Generate a random test case.
Returns dict with all input arrays and expected parameters.
"""
padding_token_id = 0
parallel_drafting_token_id = 100
# Generate per-request token counts
tokens_per_req = rng.integers(min_tokens_per_req, max_tokens_per_req + 1,
size=num_reqs)
rejected_per_req = rng.integers(0, max_rejected_per_req + 1, size=num_reqs)
# Build query_start_loc (cumulative)
query_start_loc = np.zeros(num_reqs + 1, dtype=np.int32)
for i in range(num_reqs):
query_start_loc[i + 1] = query_start_loc[i] + tokens_per_req[i] + rejected_per_req[i]
total_input_tokens = int(query_start_loc[num_reqs])
# Build query_end_loc: queryEnd = queryStart + numAccepted - 1
# where numAccepted = tokens_per_req[i]
# For shift=false: numValid = queryEnd - queryStart + 1 = tokens_per_req[i]
# For shift=true: numValid = queryEnd - queryStart = tokens_per_req[i] - 1
query_end_loc = np.zeros(num_reqs, dtype=np.int32)
for i in range(num_reqs):
if shift_input_ids:
query_end_loc[i] = query_start_loc[i] + tokens_per_req[i]
else:
query_end_loc[i] = query_start_loc[i] + tokens_per_req[i] - 1
# Generate input tokens and positions
target_token_ids = rng.integers(1, 50000, size=total_input_tokens, dtype=np.int32)
target_positions = np.zeros(total_input_tokens, dtype=np.int32)
for i in range(num_reqs):
qs = query_start_loc[i]
nqs = query_start_loc[i + 1]
for j in range(nqs - qs):
target_positions[qs + j] = j
next_token_ids = rng.integers(1, 50000, size=num_reqs, dtype=np.int32)
# Compute total_draft_tokens
total_draft_tokens = 0
for r in range(num_reqs):
qs = query_start_loc[r]
nqs = query_start_loc[r + 1]
qe = query_end_loc[r]
num_rejected = max(nqs - qe - 1, 0)
if shift_input_ids:
num_valid = max(qe - qs, 0)
else:
num_valid = max(qe - qs + 1, 0)
total_draft_tokens += num_valid + num_padding_slots + num_rejected
return {
"target_token_ids": target_token_ids,
"target_positions": target_positions,
"next_token_ids": next_token_ids,
"query_start_loc": query_start_loc,
"query_end_loc": query_end_loc,
"padding_token_id": padding_token_id,
"parallel_drafting_token_id": parallel_drafting_token_id,
"num_padding_slots": num_padding_slots,
"shift_input_ids": shift_input_ids,
"total_draft_tokens": total_draft_tokens,
}
# ---------------------------------------------------------------------------
# Parametrized tests
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("num_reqs", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("num_padding_slots", [1, 2, 3, 5])
@pytest.mark.parametrize("shift_input_ids", [False, True])
@pytest.mark.parametrize("seed_offset", [0, 1])
def test_copy_and_expand_eagle_inputs(num_reqs, num_padding_slots,
shift_input_ids, seed_offset):
"""Test CopyAndExpandEagleInputs with parametrized configurations."""
rng = np.random.default_rng(SEED + seed_offset)
case = generate_test_case(rng, num_reqs, num_padding_slots,
shift_input_ids)
# Golden reference
g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand(
case["target_token_ids"],
case["target_positions"],
case["next_token_ids"],
case["query_start_loc"],
case["query_end_loc"],
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
)
# NPU execution
n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec(
torch.from_numpy(case["target_token_ids"]),
torch.from_numpy(case["target_positions"]),
torch.from_numpy(case["next_token_ids"]),
torch.from_numpy(case["query_start_loc"]),
torch.from_numpy(case["query_end_loc"]),
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
case["total_draft_tokens"],
)
# Convert golden to tensors
g_ids_t = torch.from_numpy(g_ids)
g_pos_t = torch.from_numpy(g_pos)
g_rej_t = torch.from_numpy(g_rej)
g_msk_t = torch.from_numpy(g_msk)
g_nti_t = torch.from_numpy(g_nti)
g_hsm_t = torch.from_numpy(g_hsm)
# Compare outputs
torch.testing.assert_close(n_ids, g_ids_t, atol=0, rtol=0,
msg="out_input_ids mismatch")
torch.testing.assert_close(n_pos, g_pos_t, atol=0, rtol=0,
msg="out_positions mismatch")
torch.testing.assert_close(n_rej, g_rej_t, atol=0, rtol=0,
msg="out_is_rejected_token_mask mismatch")
torch.testing.assert_close(n_msk, g_msk_t, atol=0, rtol=0,
msg="out_is_masked_token_mask mismatch")
torch.testing.assert_close(n_nti, g_nti_t, atol=0, rtol=0,
msg="out_new_token_indices mismatch")
if shift_input_ids:
torch.testing.assert_close(n_hsm, g_hsm_t, atol=0, rtol=0,
msg="out_hidden_state_mapping mismatch")
@pytest.mark.parametrize("num_reqs", [1])
@pytest.mark.parametrize("num_padding_slots", [1])
@pytest.mark.parametrize("shift_input_ids", [False, True])
def test_minimal_case(num_reqs, num_padding_slots, shift_input_ids):
"""Test with minimal input (1 request, 1 padding slot)."""
rng = np.random.default_rng(SEED + 100)
case = generate_test_case(rng, num_reqs, num_padding_slots,
shift_input_ids, min_tokens_per_req=2,
max_tokens_per_req=3, max_rejected_per_req=1)
g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand(
case["target_token_ids"],
case["target_positions"],
case["next_token_ids"],
case["query_start_loc"],
case["query_end_loc"],
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
)
n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec(
torch.from_numpy(case["target_token_ids"]),
torch.from_numpy(case["target_positions"]),
torch.from_numpy(case["next_token_ids"]),
torch.from_numpy(case["query_start_loc"]),
torch.from_numpy(case["query_end_loc"]),
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
case["total_draft_tokens"],
)
torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0)
torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0)
torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0)
torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0)
torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0)
@pytest.mark.parametrize("num_reqs", [3, 7, 13])
def test_large_tokens_per_request(num_reqs):
"""Test with larger token counts per request."""
rng = np.random.default_rng(SEED + 200)
case = generate_test_case(rng, num_reqs, num_padding_slots=3,
shift_input_ids=False,
min_tokens_per_req=100,
max_tokens_per_req=512,
max_rejected_per_req=10)
g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand(
case["target_token_ids"],
case["target_positions"],
case["next_token_ids"],
case["query_start_loc"],
case["query_end_loc"],
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
)
n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec(
torch.from_numpy(case["target_token_ids"]),
torch.from_numpy(case["target_positions"]),
torch.from_numpy(case["next_token_ids"]),
torch.from_numpy(case["query_start_loc"]),
torch.from_numpy(case["query_end_loc"]),
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
case["total_draft_tokens"],
)
torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0)
torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0)
torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0)
torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0)
torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0)
@pytest.mark.parametrize("num_reqs", [3, 7, 13])
def test_large_tokens_shift_true(num_reqs):
"""Test with larger token counts and shift_input_ids=True."""
rng = np.random.default_rng(SEED + 300)
case = generate_test_case(rng, num_reqs, num_padding_slots=4,
shift_input_ids=True,
min_tokens_per_req=50,
max_tokens_per_req=256,
max_rejected_per_req=8)
g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand(
case["target_token_ids"],
case["target_positions"],
case["next_token_ids"],
case["query_start_loc"],
case["query_end_loc"],
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
)
n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec(
torch.from_numpy(case["target_token_ids"]),
torch.from_numpy(case["target_positions"]),
torch.from_numpy(case["next_token_ids"]),
torch.from_numpy(case["query_start_loc"]),
torch.from_numpy(case["query_end_loc"]),
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
case["total_draft_tokens"],
)
torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0)
torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0)
torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0)
torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0)
torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0)
torch.testing.assert_close(n_hsm, torch.from_numpy(g_hsm), atol=0, rtol=0)
@pytest.mark.parametrize("num_reqs", [1, 4, 8])
def test_no_rejected_tokens(num_reqs):
"""Test cases with zero rejected tokens."""
rng = np.random.default_rng(SEED + 400)
case = generate_test_case(rng, num_reqs, num_padding_slots=2,
shift_input_ids=False,
min_tokens_per_req=5,
max_tokens_per_req=20,
max_rejected_per_req=0)
g_ids, g_pos, g_rej, g_msk, g_nti, g_hsm = golden_copy_and_expand(
case["target_token_ids"],
case["target_positions"],
case["next_token_ids"],
case["query_start_loc"],
case["query_end_loc"],
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
)
n_ids, n_pos, n_rej, n_msk, n_nti, n_hsm = npu_op_exec(
torch.from_numpy(case["target_token_ids"]),
torch.from_numpy(case["target_positions"]),
torch.from_numpy(case["next_token_ids"]),
torch.from_numpy(case["query_start_loc"]),
torch.from_numpy(case["query_end_loc"]),
case["padding_token_id"],
case["parallel_drafting_token_id"],
case["num_padding_slots"],
case["shift_input_ids"],
case["total_draft_tokens"],
)
torch.testing.assert_close(n_ids, torch.from_numpy(g_ids), atol=0, rtol=0)
torch.testing.assert_close(n_pos, torch.from_numpy(g_pos), atol=0, rtol=0)
torch.testing.assert_close(n_rej, torch.from_numpy(g_rej), atol=0, rtol=0)
torch.testing.assert_close(n_msk, torch.from_numpy(g_msk), atol=0, rtol=0)
torch.testing.assert_close(n_nti, torch.from_numpy(g_nti), atol=0, rtol=0)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import math
import os
import random
from typing import Any
import pytest
from transformers import AutoTokenizer
@@ -17,23 +17,32 @@ from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODELS = { MODELS = {
#"eagle": { # "eagle": {
# "main": "LLM-Research/Meta-Llama-3.1-8B-Instruct", # "main": "LLM-Research/Meta-Llama-3.1-8B-Instruct",
# "spec": "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B", # "spec": "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B",
#}, # },
"eagle3": { "eagle3": {
"main": "Qwen/Qwen3-8B", "main": "Qwen/Qwen3-8B",
"spec": "RedHatAI/Qwen3-8B-speculator.eagle3", "spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
}, },
} }
DRAFT_PARALLEL_MODELS = {
"draft_parallel": {
"main": "LLM-Research/Meta-Llama-3.1-8B-Instruct",
"spec": "amd/PARD-Llama-3.2-1B",
},
}
# NOTE: golden may change (eagle_proposer only runs in eager mode currently), # NOTE: golden may change (eagle_proposer only runs in eager mode currently),
# thus please update it if ci fails but you have better acceptance # thus please update it if ci fails but you have better acceptance
BASELINES = { BASELINES = {
"eagle": [0.74, 0.44, 0.29], "eagle": [0.74, 0.44, 0.29],
"eagle3": [0.68, 0.40, 0.18], "eagle3": [0.68, 0.40, 0.18],
"draft_parallel": [0.83, 0.50, 0.33, 0.17, 0.17, 0.17, 0.17, 0.00],
} }
@pytest.fixture
def test_prompts():
    prompt_types = ["repeat", "sentence"]
@@ -89,6 +98,7 @@ def eagle3_model_name():
def vl_model_name():
    return "Qwen/Qwen3-VL-8B-Instruct"
def vl_eagle3_model_name():
    return "MNN/Qwen3-VL-8B-Instruct-Eagle3"
@@ -98,28 +108,28 @@ def test_ngram_correctness(
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Compare the outputs of a original LLM and a speculative LLM
    should be the same when using ngram speculative decoding.
    """
    with VllmRunner(
        model_name,
        max_model_len=1024,
        cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as ref_llm:
        ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
    with VllmRunner(
        model_name,
        speculative_config={
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        max_model_len=1024,
        cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as runner:
        spec_outputs = runner.model.chat(test_prompts, sampling_config)
    matches = 0
@@ -142,27 +152,27 @@ def test_qwen3_vl_eagle_correctness(
    sampling_config: SamplingParams,
    vl_model_name: str,
):
    """
    Compare the outputs of a original LLM and a speculative LLM
    should be the same when using eagle speculative decoding.
    """
    with VllmRunner(
        vl_model_name,
        max_model_len=1024,
        cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as ref_llm:
        ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
    spec_model_name = vl_eagle3_model_name()
    with VllmRunner(
        vl_model_name,
        speculative_config={
            "method": "eagle3",
            "model": spec_model_name,
            "num_speculative_tokens": 2,
        },
        max_model_len=1024,
        cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as runner:
        spec_outputs = runner.model.chat(test_prompts, sampling_config)
    matches = 0
@@ -179,27 +189,28 @@ def test_qwen3_vl_eagle_correctness(
    # Upon failure, inspect the outputs to check for inaccuracy.
    assert matches > int(0.66 * len(ref_outputs))
def test_suffix_correctness(
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Compare the outputs of a original LLM and a speculative LLM
    should be the same when using ngram speculative decoding.
    """
    with VllmRunner(model_name, max_model_len=1024, cudagraph_capture_sizes=[1, 2, 4, 8]) as ref_llm:
        ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
    with VllmRunner(
        model_name,
        speculative_config={
            "method": "suffix",
            "num_speculative_tokens": 8,
        },
        cudagraph_capture_sizes=[1, 2, 4, 8],
        max_model_len=1024,
    ) as runner:
        spec_outputs = runner.model.chat(test_prompts, sampling_config)
    matches = 0
    misses = 0
@@ -221,22 +232,24 @@ def test_suffix_acceptance(
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Check that suffix decoding caching takes effect and improves acceptance
    lengths and acceptance rates over multiple runs of the same prompts.
    """
    num_draft = []
    num_accept = []
    with VllmRunner(
        model_name,
        speculative_config={
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
            "num_speculative_tokens": 10,
        },
        max_model_len=1024,
        cudagraph_capture_sizes=[1, 2, 4, 8],
        disable_log_stats=False,
    ) as runner:
        for i in range(10):
            runner.model.chat(test_prompts[i], sampling_config)
            metrics = runner.model.get_metrics()
@@ -271,13 +284,10 @@ def test_suffix_acceptance(
def test_eagle_logprobs(
    model_name: str,
    use_eagle3: bool,
    draft_tensor_parallel_size: None | int,
):
    prompt = {"role": "user", "content": "Hello world " * 10}
    sampling_params = SamplingParams(temperature=0, logprobs=1, max_tokens=10, ignore_eos=False)

    ref_llm = LLM(model=model_name, max_model_len=2048)
    ref_outputs = ref_llm.chat([prompt], sampling_params)
@@ -290,19 +300,19 @@ def test_eagle_logprobs(
    spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
    with VllmRunner(
        model_name,
        max_num_seqs=1,
        max_num_batched_tokens=2048,
        gpu_memory_utilization=0.6,
        speculative_config={
            "method": "eagle3" if use_eagle3 else "eagle",
            "model": spec_model_name,
            "num_speculative_tokens": 2,
            "draft_tensor_parallel_size": draft_tensor_parallel_size,
            "max_model_len": 128,
        },
        max_model_len=128,
        cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as runner:
        spec_outputs = runner.model.chat([prompt], sampling_params)
@@ -314,10 +324,7 @@ def test_eagle_logprobs(
            spec_logprobs.append(logprobs[token_id])

    for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1)
        assert ref_logprob.rank == spec_logprob.rank
        assert ref_logprob.decoded_token == spec_logprob.decoded_token
@@ -330,7 +337,7 @@ def test_eagle_logprobs(
def test_llama_qwen_eagle_acceptance(
    method: str,
    num_speculative_tokens: int,
    draft_tensor_parallel_size: None | int,
    disable_padded_drafter_batch: bool,
    async_scheduling: bool,
):
@@ -375,7 +382,8 @@ def test_llama_qwen_eagle_acceptance(
            [prompt],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]

    speculative_config = {
@@ -389,16 +397,16 @@ def test_llama_qwen_eagle_acceptance(
    compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
    with VllmRunner(
        main_model_name,
        max_model_len=2048,
        disable_log_stats=False,
        tensor_parallel_size=1,
        max_num_seqs=256,
        distributed_executor_backend="mp",
        gpu_memory_utilization=0.7,
        speculative_config=speculative_config,
        compilation_config=compilation_config,
        async_scheduling=async_scheduling,
    ) as llm:
        outputs = llm.model.generate(prompts, sampling_params)
        metrics = llm.model.get_metrics()
@@ -419,10 +427,7 @@ def test_llama_qwen_eagle_acceptance(
            for pos in range(len(metric.values)):
                num_accepted_tokens_per_pos[pos] += metric.values[pos]

    acceptance_per_pos = [num_accepted_tokens / num_drafts for num_accepted_tokens in num_accepted_tokens_per_pos]
    if method == "eagle":
        golden = [0.7313432835820896, 0.373134328358209, 0.19402985074626866]
    else:
@@ -434,3 +439,98 @@ def test_llama_qwen_eagle_acceptance(
print(f"golden: {golden}") print(f"golden: {golden}")
assert match assert match
@pytest.mark.parametrize("method", DRAFT_PARALLEL_MODELS.keys())
@pytest.mark.parametrize("num_speculative_tokens", [8])
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
def test_parallel_drafting_acceptance(
    method: str,
    num_speculative_tokens: int,
    draft_tensor_parallel_size: None | int,
):
    """
    Test acceptance rate for parallel drafting speculative decoding
    using a smaller draft model with parallel_drafting enabled.
    """
    main_model_name = DRAFT_PARALLEL_MODELS[method]["main"]
    spec_model_name = DRAFT_PARALLEL_MODELS[method]["spec"]
    tokenizer = AutoTokenizer.from_pretrained(
        main_model_name,
        trust_remote_code=True,
    )
    sampling_params = SamplingParams(
        temperature=0,
        ignore_eos=False,
        max_tokens=256,
    )
    prompts = [
        {
            "role": "user",
            "content": "Hello, your name is",
        },
    ]
    prompts = [
        tokenizer.apply_chat_template(
            [prompt],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]

    speculative_config = {
        "method": "draft_model",
        "model": spec_model_name,
        "num_speculative_tokens": num_speculative_tokens,
        "draft_tensor_parallel_size": draft_tensor_parallel_size,
        "parallel_drafting": True,
    }
    compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
    with VllmRunner(
        main_model_name,
        max_model_len=4096,
        disable_log_stats=False,
        tensor_parallel_size=1,
        max_num_seqs=256,
        distributed_executor_backend="mp",
        gpu_memory_utilization=0.8,
        speculative_config=speculative_config,
        compilation_config=compilation_config,
        enable_prefix_caching=False,
    ) as llm:
        outputs = llm.model.generate(prompts, sampling_params)
        metrics = llm.model.get_metrics()

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        output_tokens = output.outputs[0].token_ids
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        print(f"Output tokens: {output_tokens}")

    num_drafts = 0
    num_accepted_tokens_per_pos = [0] * num_speculative_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos in range(len(metric.values)):
                num_accepted_tokens_per_pos[pos] += metric.values[pos]

    acceptance_per_pos = [num_accepted_tokens / num_drafts for num_accepted_tokens in num_accepted_tokens_per_pos]
    golden = BASELINES[method]
    match = all(abs(a - b) < 0.1 for a, b in zip(acceptance_per_pos, golden))
    if not match:
        print(f"acceptance_per_pos: {acceptance_per_pos}")
        print(f"golden: {golden}")
    assert match
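The golden vectors in BASELINES are per-position acceptance rates: entry i is the fraction of drafts whose i-th speculative token was accepted. Acceptance is prefix-based, so summing a vector gives the mean number of draft tokens accepted per verification step. A quick plain-Python sanity check against the draft_parallel golden values above (illustrative only, not part of the test):

golden = [0.83, 0.50, 0.33, 0.17, 0.17, 0.17, 0.17, 0.00]
# One bonus token from the verifier plus the accepted draft tokens:
expected_tokens_per_step = 1 + sum(golden)
print(expected_tokens_per_step)  # ~3.34 tokens emitted per target-model step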

View File

@@ -10,14 +10,15 @@ from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
class TestEagleProposerInitialization(TestBase):
    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
        self.vllm_config.cache_config = MagicMock(spec=CacheConfig)
        self.vllm_config.scheduler_config = MagicMock()
        self.vllm_config.model_config = MagicMock()
        self.vllm_config.model_config.hf_text_config = MagicMock(
            spec=[]
        )  # Empty spec to prevent hasattr from returning True
        self.vllm_config.model_config.hf_text_config.to_dict = MagicMock(return_value={})
        self.vllm_config.compilation_config = MagicMock()
        self.device = torch.device("cpu")
@@ -40,20 +41,16 @@ class TestEagleProposerInitialization(TestBase):
        self.vllm_config.parallel_config.enable_expert_parallel = False
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(2)])
        self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0
        self.vllm_config.speculative_config.draft_model_config.uses_mrope = False
        self.vllm_config.speculative_config.disable_padded_drafter_batch = False
        self.vllm_config.additional_config = None

        self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False
        )
        self.mock_supports_multimodal_inputs.start()
@@ -78,18 +75,16 @@ class TestEagleProposerInitialization(TestBase):
        init_ascend_config(self.vllm_config)
        with set_current_vllm_config(self.vllm_config):
            proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner)

        self.assertEqual(proposer.hidden_size, 4096)
        self.assertTrue(proposer.use_cuda_graph)
        expected_max_num_tokens = proposer.max_num_tokens
        self.assertEqual(proposer.input_ids.shape, (expected_max_num_tokens,))
        self.assertEqual(proposer.positions.shape, (expected_max_num_tokens,))
        self.assertEqual(proposer.hidden_states.shape, (expected_max_num_tokens, 4096))
        self.assertEqual(proposer.arange.shape, (expected_max_num_tokens,))

    def test_initialization_eagle3_enforce_eager(self):
        self.vllm_config.speculative_config.method = "eagle3"
@@ -101,9 +96,7 @@ class TestEagleProposerInitialization(TestBase):
        init_ascend_config(self.vllm_config)
        with set_current_vllm_config(self.vllm_config):
            proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
@@ -120,9 +113,7 @@ class TestEagleProposerInitialization(TestBase):
        init_ascend_config(self.vllm_config)
        with set_current_vllm_config(self.vllm_config):
            proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertTrue(proposer.use_cuda_graph)
@@ -139,9 +130,7 @@ class TestEagleProposerInitialization(TestBase):
        init_ascend_config(self.vllm_config)
        with set_current_vllm_config(self.vllm_config):
            proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner)

        self.assertEqual(proposer.hidden_size, 2048)
        self.assertFalse(proposer.use_cuda_graph)
@@ -150,7 +139,6 @@ class TestEagleProposerInitialization(TestBase):
class TestEagleProposerLoadModel(TestBase):
    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
@@ -175,29 +163,24 @@ class TestEagleProposerLoadModel(TestBase):
        self.vllm_config.parallel_config.enable_expert_parallel = False
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(2)])
        self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0
        self.vllm_config.speculative_config.draft_model_config.uses_mrope = False
        self.vllm_config.speculative_config.disable_padded_drafter_batch = False
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False
        )
        self.mock_supports_multimodal_inputs.start()

        # Set the current vllm config
        set_current_vllm_config(self.vllm_config)
        self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner)
        self.proposer.parallel_drafting = False
    def tearDown(self):
        self.mock_cpugpubuffer.stop()
@@ -205,24 +188,21 @@ class TestEagleProposerLoadModel(TestBase):
        # Clear the current vllm config
        set_current_vllm_config(None)

    @patch("vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp1(self, mock_pp_group, mock_get_model, mock_get_layers):
        mock_pp_group.return_value.world_size = 1
        mock_target_layer1 = MagicMock()
        mock_target_layer2 = MagicMock()
        mock_draft_layer1 = MagicMock()
        mock_draft_layer3 = MagicMock()
        mock_get_layers.side_effect = [
            {"layer1": mock_target_layer1, "layer2": mock_target_layer2},
            {},
            {},
            {"layer1": mock_draft_layer1, "layer3": mock_draft_layer3},
        ]

        weight = torch.zeros(0)
@@ -241,61 +221,45 @@ class TestEagleProposerLoadModel(TestBase):
            self.proposer.load_model(mock_model)
        mock_get_model.assert_called_once()
        self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
        self.assertIs(self.proposer.model.model.embed_tokens, mock_model.model.embed_tokens)

    @patch("vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model, mock_get_layers):
        mock_pp_group.return_value.world_size = 2
        mock_target_layer1 = MagicMock()
        mock_draft_layer2 = MagicMock()
        mock_get_layers.side_effect = [{"layer1": mock_target_layer1}, {}, {}, {"layer2": mock_draft_layer2}]

        mock_model = MagicMock()
        original_embed = MagicMock()
        mock_model.multimodal_cpu_fields = None
        mock_model.merge_by_field_config = None
        mock_get_model.return_value = MagicMock(model=MagicMock(embed_tokens=original_embed))

        with set_current_vllm_config(self.vllm_config):
            self.proposer.load_model(mock_model)
        self.assertIsNot(self.proposer.model.model.embed_tokens, mock_model.model.embed_tokens)
        self.assertEqual(self.proposer.attn_layer_names, ["layer2"])

    @patch("vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_model")
    @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group")
    @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal")
    def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group, mock_get_model, mock_get_layers):
        mock_model = MagicMock()
        mock_model.get_language_model.return_value.lm_head = MagicMock()
        mock_supports_multi.return_value = True
        original_embed = MagicMock()
        mock_get_model.return_value = MagicMock(model=MagicMock(embed_tokens=original_embed))
        mock_target_layer1 = MagicMock()
        mock_draft_layer2 = MagicMock()
        mock_get_layers.side_effect = [{"layer1": mock_target_layer1}, {}, {}, {"layer2": mock_draft_layer2}]

        mock_pp_group.return_value.world_size = 2
        self.proposer.model = MagicMock()
@@ -303,12 +267,10 @@ class TestEagleProposerLoadModel(TestBase):
        with set_current_vllm_config(self.vllm_config):
            self.proposer.load_model(mock_model)
        self.assertEqual(mock_model.get_language_model.call_count, 2)
        self.assertIs(self.proposer.model.lm_head, mock_model.get_language_model.return_value.lm_head)


class TestEagleProposerDummyRun(TestBase):
    def setUp(self):
        self.vllm_config = MagicMock(spec=VllmConfig)
        self.vllm_config.speculative_config = MagicMock()
@@ -328,51 +290,43 @@ class TestEagleProposerDummyRun(TestBase):
        self.vllm_config.model_config.uses_mrope = False
        self.vllm_config.model_config.uses_xdrope_dim = 0
        self.vllm_config.model_config.use_mla = False
        self.vllm_config.model_config.hf_text_config = MagicMock(
            spec=[]
        )  # Empty spec to prevent hasattr from returning True
        self.vllm_config.model_config.hf_text_config.to_dict = MagicMock(return_value={})
        self.vllm_config.parallel_config.tensor_parallel_size = 1
        self.vllm_config.parallel_config.data_parallel_rank = 0
        self.vllm_config.parallel_config.data_parallel_size = 1
        self.vllm_config.parallel_config.prefill_context_parallel_size = 1
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(4)])
        self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0
        self.vllm_config.speculative_config.draft_model_config.uses_mrope = False
        self.vllm_config.speculative_config.disable_padded_drafter_batch = False
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False
        )
        self.mock_supports_multimodal_inputs.start()

        # Mock parallel state functions
        self.mock_tp_world_size = patch(
            "vllm_ascend.ascend_forward_context.get_tensor_model_parallel_world_size", return_value=1
        )
        self.mock_tp_world_size.start()
        mock_dp_group = MagicMock()
        mock_dp_group.world_size = 1
        self.mock_dp_group = patch("vllm_ascend.ascend_forward_context.get_dp_group", return_value=mock_dp_group)
        self.mock_dp_group.start()

        # Set the current vllm config
        set_current_vllm_config(self.vllm_config)
        self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner)
        self.proposer.model = MagicMock()
        self.proposer._runnable = MagicMock()
        self.proposer.update_stream = MagicMock()
@@ -397,8 +351,7 @@ class TestEagleProposerDummyRun(TestBase):
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        with set_current_vllm_config(self.vllm_config):
            self.proposer.enable_shared_expert_dp = False
            self.proposer.dummy_run(num_tokens=num_tokens, with_prefill=with_prefill)

        self.assertTrue(self.proposer._runnable.call_count == 1)
@@ -433,9 +386,7 @@ class TestEagleProposerDummyRun(TestBase):
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        with set_current_vllm_config(self.vllm_config):
            self.proposer.enable_shared_expert_dp = False
            self.proposer.dummy_run(num_tokens=64, in_graph_capturing=True, aclgraph_runtime_mode=CUDAGraphMode.FULL)

        self.assertTrue(self.proposer._runnable.call_count == 1)
        mock_update_full_graph_params.assert_not_called()
        self.proposer.use_cuda_graph = last_use_cuda_graph
@@ -458,16 +409,13 @@ class TestEagleProposerDummyRun(TestBase):
        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
        with set_current_vllm_config(self.vllm_config):
            self.proposer.enable_shared_expert_dp = False
            self.proposer.dummy_run(num_tokens=64, in_graph_capturing=False, aclgraph_runtime_mode=CUDAGraphMode.FULL)

        self.assertTrue(self.proposer._runnable.call_count == 1)
        self.assertTrue(mock_update_full_graph_params.call_count == 1)
        self.proposer.use_cuda_graph = last_use_cuda_graph


class TestEagleProposerHelperMethods(TestBase):
    # TODO: Can add some tests about prepare_next_token_ids in future.
    def setUp(self):
@@ -497,29 +445,23 @@ class TestEagleProposerHelperMethods(TestBase):
        self.vllm_config.parallel_config.enable_expert_parallel = False
        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
        self.vllm_config.speculative_config.num_speculative_tokens = 2
        self.vllm_config.speculative_config.speculative_token_tree = str([(i + 1) * (0,) for i in range(2)])
        self.vllm_config.speculative_config.draft_model_config.uses_xdrope_dim = 0
        self.vllm_config.speculative_config.draft_model_config.uses_mrope = False
        self.vllm_config.speculative_config.disable_padded_drafter_batch = False
        self.vllm_config.additional_config = None
        init_ascend_config(self.vllm_config)

        self.mock_cpugpubuffer = patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
        self.mock_cpugpubuffer.start()
        self.mock_supports_multimodal_inputs = patch(
            "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", return_value=False
        )
        self.mock_supports_multimodal_inputs.start()

        # Set the current vllm config
        set_current_vllm_config(self.vllm_config)
        self.proposer = AscendEagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner)

    def tearDown(self):
        self.mock_cpugpubuffer.stop()
@@ -536,11 +478,9 @@ class TestEagleProposerHelperMethods(TestBase):
        num_rejected = torch.tensor([1, 0, 1], device=self.device)
        mock_return_attn = MagicMock()

        with (
            set_current_vllm_config(self.vllm_config),
            patch.object(self.proposer, "prepare_inputs", return_value=(mock_return_attn, torch.tensor([1, 2, 4]))),
        ):
            return_attn, indices = self.proposer.prepare_inputs(mock_attn, num_rejected)
            self.assertEqual(indices.tolist(), [1, 2, 4])

View File

@@ -284,6 +284,9 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
        if isinstance(self.kv_cache_spec, CrossAttentionSpec):
            seq_lens = common_attn_metadata.seq_lens
            slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
        elif self.speculative_config and self.speculative_config.parallel_drafting:
            seq_lens = common_attn_metadata.seq_lens

        attn_state = common_attn_metadata.attn_state
        # Get attn_mask and swa_mask from singleton AttentionMaskBuilder

View File

@@ -24,6 +24,7 @@ def prepare_inputs_padded_kernel(
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
    BLOCK_SIZE: tl.constexpr,
):
@@ -61,3 +62,4 @@ def prepare_inputs_padded_kernel(
    index_to_sample = q_last_tok_idx - num_rejected
    tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask)
    tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected, mask=mask)
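For reference, a host-side sketch of what the kernel computes per request (a plain torch equivalent, assuming num_rejected has already been derived upstream from valid_sampled_tokens_count; the Triton kernel above is the actual implementation):

import torch

def token_indices_to_sample_ref(query_start_loc: torch.Tensor, num_rejected: torch.Tensor) -> torch.Tensor:
    # Last token index of each request's query span ...
    q_last_tok_idx = query_start_loc[1:] - 1
    # ... stepped back past that request's rejected speculative tokens.
    return q_last_tok_idx - num_rejected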

View File

@@ -16,6 +16,8 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
@@ -31,5 +33,7 @@ def get_spec_decode_method(method, vllm_config, device, runner):
        return AscendMedusaProposer(vllm_config, device)
    elif method in ("eagle", "eagle3", "mtp"):
        return AscendEagleProposer(vllm_config, device, runner)
    elif method == "draft_model":
        return AscendDraftModelProposer(vllm_config, device, runner)
    else:
        raise ValueError(f"Unknown speculative decoding method: {method}")

View File

@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
from typing_extensions import override
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model

from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer

logger = init_logger(__name__)


class AscendDraftModelProposer(SpecDecodeBaseProposer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        self._raise_if_vocab_size_mismatch()
        self._raise_if_draft_tp_mismatch()

    def _raise_if_vocab_size_mismatch(self):
        self.speculative_config.verify_equal_vocab_size_if_draft_model()

    def _raise_if_draft_tp_mismatch(self):
        # Note(Tomas Ruiz): If we run the target model with TP > 1 and the
        # draft model with TP = 1, the different TP ranks collide. Specifically,
        # when all ranks compile the draft model on rank 0 (because TP=1), the
        # torch compile cache is overwritten and corrupted. We need a mechanism
        # like https://github.com/vllm-project/vllm/pull/5414.
        # To prevent this error, we assert that both TP sizes must be the same.
        spec_cfg = self.speculative_config
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp != tgt_tp:
            raise ValueError(
                f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
                f"must be the same. Got {draft_tp} and {tgt_tp}. "
                "Please pass 'draft_tensor_parallel_size' in the speculative_config."
            )

    def _get_model(self) -> nn.Module:
        # Draft models may be quantized or use different parallelism,
        # so we load them with a modified vllm config.
        from vllm.compilation.backends import set_model_tag

        temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
        with set_model_tag("draft_model"):
            model = get_model(
                vllm_config=temp_vllm_config,
                prefix="draft_model",
            )
        return model

    @override
    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        # Draft models don't share embeddings with the target model
        pass

    @override
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        # Draft models don't share lm_head with the target model
        pass
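Both guards fail fast at engine start-up. A standalone sketch of the TP check (a hypothetical simplification of _raise_if_draft_tp_mismatch, with plain ints in place of the parallel configs):

def check_draft_tp(target_tp: int, draft_tp: int) -> None:
    # Mismatched sizes would let every rank compile the TP=1 draft model on
    # rank 0 and corrupt the torch compile cache, hence the hard error.
    if draft_tp != target_tp:
        raise ValueError(
            f"'draft_tensor_parallel_size' and 'tensor_parallel_size' must match, got {draft_tp} and {target_tp}."
        )


check_draft_tp(target_tp=2, draft_tp=2)  # passes
# check_draft_tp(target_tp=2, draft_tp=1)  # raises ValueError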

View File

@@ -30,8 +30,13 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    extend_all_queries_by_N,
)
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

from vllm_ascend.ascend_forward_context import _EXTRA_CTX, set_ascend_forward_context
@@ -80,14 +85,14 @@ def split_inputs_tp_to_sp(hidden_states, out):
    return out[:padded_num_tokens_per_rank]


class SpecDecodeBaseProposer(EagleProposer):
    _runnable: ACLGraphWrapper | Callable

    def __init__(self, vllm_config: VllmConfig, device: torch.device, pass_hidden_states_to_model: bool, runner=None):
        super().__init__(vllm_config, device, runner)
        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
        self.pass_hidden_states_to_model = pass_hidden_states_to_model
        self.decode_threshold = 1 + self.num_speculative_tokens
        self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 2, dtype=torch.int32)
        self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32)
@@ -140,7 +145,7 @@ class AscendEagleProposer(EagleProposer):
        if not self.use_cuda_graph and enable_sp(vllm_config):
            self.maybe_eager_context = _maybe_eager_context(vllm_config)

        self.token_indices_to_sample = torch.zeros(
            self.vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.int32, device=device
        )
        slot_mapping_lens = self.runner.max_num_tokens + 2 * self.pcp_size * self.runner.max_num_reqs
@@ -150,15 +155,38 @@ class AscendEagleProposer(EagleProposer):
        ]
        self._runnable = self._run_merged_draft

        if self.uses_mrope:
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int32, device=device)
        elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
            self.xdrope_positions = torch.zeros(
                (self.uses_xdrope_dim, self.max_num_tokens + 1),
                dtype=torch.int32,
                device=device,
            )
        else:
            # RoPE needs (max_num_tokens,)
            self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)

    def _get_model(self) -> nn.Module:
        """
        Default method to call get_model(). Can be overridden by subclasses
        which need to customize model loading.
        """
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            model = get_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.speculative_config.draft_model_config,
            )
        return model
    def load_model(self, model: nn.Module) -> None:
        target_attn_layer_names = set(get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys())
        target_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())

        with self.maybe_eager_context:
            self.model = self._get_model()

        indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys()
        draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
@@ -167,7 +195,7 @@ class AscendEagleProposer(EagleProposer):
        draft_attn_layer_names = draft_attn_layers - target_attn_layer_names
        draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
        draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
        self.attn_layer_names = list(sorted(draft_attn_layer_names))

        self.kernel_block_size = (
@@ -202,6 +230,24 @@ class AscendEagleProposer(EagleProposer):
            target_language_model = model
        # share embed_tokens with the target model if needed
        self._maybe_share_embeddings(target_language_model)
        self._maybe_share_lm_head(model)
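        # (An assumption based on this diff, not documented here: with parallel
        # drafting the drafter fills its padded speculative positions with a
        # precomputed "mask" hidden state, so one forward pass can propose all
        # num_speculative_tokens tokens at once.)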
        if self.parallel_drafting and self.pass_hidden_states_to_model:
            assert self.parallel_drafting_hidden_state_tensor is not None
            self.parallel_drafting_hidden_state_tensor.copy_(
                self.model.combine_hidden_states(self.model.mask_hidden.view(3 * self.hidden_size))
                if self.eagle3_use_aux_hidden_state
                else self.model.mask_hidden.view(self.hidden_size)
            )

    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        """
        Some draft models may not have their own embedding layers, and some may
        have a duplicate copy of the target model's embedding layers. In these
        cases, we share the target model's embedding layers with the draft model
        to save memory.
        """
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
@@ -256,7 +302,9 @@ class AscendEagleProposer(EagleProposer):
"Since PP > 1 or other reasons the model head loaded its own vocab embedding" "Since PP > 1 or other reasons the model head loaded its own vocab embedding"
" weights instead of sharing them with the target model." " weights instead of sharing them with the target model."
) )
# share lm_head with the target model if needed
# share lm_head with the target model if needed
def _maybe_share_lm_head(self, model: nn.Module) -> None:
# some model definition do not define lm_head explicitly # some model definition do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.method == "eagle" and hasattr(model, "lm_head"): if self.method == "eagle" and hasattr(model, "lm_head"):
@@ -389,7 +437,7 @@ class AscendEagleProposer(EagleProposer):
        self._runnable(
            num_input_tokens=num_tokens,
            batch_size=batch_size,
            token_indices_to_sample=self.token_indices_to_sample[: batch_size * self.extra_slots_per_request],
            # The target_position's address is the same as the model_positions's
            target_positions=model_positions,
            inputs_embeds=None,
@@ -411,7 +459,7 @@ class AscendEagleProposer(EagleProposer):
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
@@ -421,31 +469,34 @@ class AscendEagleProposer(EagleProposer):
        num_decode_reqs=0,
        scheduler_output: SchedulerOutput = None,
        num_scheduled_tokens: int = 0,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size = common_attn_metadata.batch_size()
        if token_indices_to_sample is None:
            token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
        assert target_hidden_states.shape[-1] == self.hidden_size
        num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass(
            target_token_ids=target_token_ids,
            next_token_ids=next_token_ids,
            target_positions=target_positions,
            target_hidden_states=target_hidden_states,
            token_indices_to_sample=token_indices_to_sample,
            cad=common_attn_metadata,
            num_rejected_tokens_gpu=num_rejected_tokens_gpu,
        )
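        # (What set_inputs_first_pass is expected to do here, judging from the
        # inline code it replaces: shift target_token_ids left by one, e.g.
        # [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3], then overwrite
        # the sampled positions with next_token_ids; num_rejected_tokens_gpu
        # additionally lets the padded path account for rejected slots on device.)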
        assert self.runner is not None
        # update pcp related params
        if self.pcp_size * self.dcp_size > 1:
            assert long_seq_metadata is not None
            common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata

        ori_token_indices_to_sample = token_indices_to_sample.clone()
        query_lens_d = self.runner.query_lens[:num_decode_reqs]
        if self.pcp_size > 1:
            # 1. preprocess decode/prefill input_ids & target_hidden_states
@@ -484,9 +535,11 @@ class AscendEagleProposer(EagleProposer):
            target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)

            # 2. update sample_indices according to main model
            if num_decode_reqs:
                token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
                    token_indices_to_sample[:num_decode_reqs]
                ]
            if num_prefill_reqs:
                token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]

            # 3. update attn_metadata params that may be influenced by pcp
            common_attn_metadata.num_actual_tokens = num_tokens
            common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
@@ -530,10 +583,6 @@ class AscendEagleProposer(EagleProposer):
            aclgraph_runtime_mode = CUDAGraphMode.NONE
            batch_descriptor = None

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
            inputs_embeds = self.model.embed_input_ids(
@@ -559,15 +608,16 @@ class AscendEagleProposer(EagleProposer):
        attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
        if self.uses_mrope:
            used_update_positions = self.mrope_positions[:, token_indices_to_sample]
        else:
            used_update_positions = self.positions[token_indices_to_sample]

        per_layer_attn_metadata = dict()
        # The first step of speculative.
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        multi_steps_attn_metadata = [per_layer_attn_metadata]

        attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]]
        if self.pcp_size * self.dcp_size > 1:
            if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills:
@@ -578,7 +628,7 @@ class AscendEagleProposer(EagleProposer):
                # to get corresponding slot_mapping in each step.
                num_reject_tokens = (
                    torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device)
                    - ori_token_indices_to_sample
                    - 1
                )
                num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
@@ -616,6 +666,27 @@ class AscendEagleProposer(EagleProposer):
             common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor[:batch_size]
             # Copy the old attn_metadata and update
+            if not self.parallel_drafting:
+                for draft_step in range(1, self.num_speculative_tokens):
+                    common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
+                        draft_step,
+                        attn_metadata,
+                        common_attn_metadata,
+                        batch_size,
+                        num_input_tokens,
+                        used_update_positions,
+                        aclgraph_runtime_mode,
+                        ori_seq_len,
+                        slot_indices,
+                        mtp_slot_mapping,
+                    )
+                    per_layer_attn_metadata = dict()
+                    for layer_name in self.attn_layer_names:
+                        per_layer_attn_metadata[layer_name] = attn_metadata
+                    multi_steps_attn_metadata.append(per_layer_attn_metadata)
+        else:
+            # Copy the old attn_metadata and update
+            if not self.parallel_drafting:
                 for draft_step in range(1, self.num_speculative_tokens):
                     common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
                         draft_step,
@@ -625,33 +696,14 @@ class AscendEagleProposer(EagleProposer):
                         num_input_tokens,
                         used_update_positions,
                         aclgraph_runtime_mode,
+                        ori_seq_len,
+                        slot_indices,
+                        mtp_slot_mapping,
                     )
                     per_layer_attn_metadata = dict()
                     for layer_name in self.attn_layer_names:
                         per_layer_attn_metadata[layer_name] = attn_metadata
                     multi_steps_attn_metadata.append(per_layer_attn_metadata)
-        else:
-            # Copy the old attn_metadata and update
-            for draft_step in range(1, self.num_speculative_tokens):
-                common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
-                    draft_step,
-                    attn_metadata,
-                    common_attn_metadata,
-                    batch_size,
-                    num_input_tokens,
-                    used_update_positions,
-                    aclgraph_runtime_mode,
-                )
-                per_layer_attn_metadata = dict()
-                for layer_name in self.attn_layer_names:
-                    per_layer_attn_metadata[layer_name] = attn_metadata
-                multi_steps_attn_metadata.append(per_layer_attn_metadata)
-        last_token_indices_len = last_token_indices.shape[0]
-        self.last_token_indices[:last_token_indices_len].copy_(last_token_indices)
+        token_indices_to_sample_len = token_indices_to_sample.shape[0]
+        self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample)
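Aside (a simplified sketch, not the committed code): the per-step metadata rebuild above is skipped under parallel drafting because a single forward pass scores every speculative slot, so only the first step's attention metadata is ever consumed:

    # Sequential drafting builds one attention-metadata dict per draft step;
    # parallel drafting keeps only the first.
    def build_multi_step_metadata(first_step_md, num_spec_tokens, parallel_drafting, update_fn):
        multi_step = [first_step_md]
        if not parallel_drafting:
            md = first_step_md
            for step in range(1, num_spec_tokens):
                md = update_fn(md, step)  # e.g. advance slot mapping by one token
                multi_step.append(md)
        return multi_step

    # With parallel drafting enabled, only the first step's metadata survives.
    assert len(build_multi_step_metadata({"step": 0}, 8, True, None)) == 1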
         with set_ascend_forward_context(
             multi_steps_attn_metadata[0],
@@ -672,7 +724,7 @@ class AscendEagleProposer(EagleProposer):
             draft_token_ids = self._runnable(
                 num_input_tokens=num_input_tokens,
                 batch_size=batch_size,
-                last_token_indices=self.last_token_indices[:last_token_indices_len],
+                token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len],
                 target_positions=target_positions,
                 inputs_embeds=inputs_embeds,
                 multi_steps_attn_metadata=multi_steps_attn_metadata,
@@ -689,7 +741,7 @@ class AscendEagleProposer(EagleProposer):
         self,
         num_input_tokens,
         batch_size,
-        last_token_indices,
+        token_indices_to_sample,
         target_positions,
         inputs_embeds,
         multi_steps_attn_metadata,
@@ -702,17 +754,22 @@ class AscendEagleProposer(EagleProposer):
         # `model_hidden_states` represent the speculative model inputs.
         model_input_ids = self.input_ids[:num_input_tokens]
         model_positions = self._get_positions(num_input_tokens)
-        model_hidden_states = self.hidden_states[:num_input_tokens]
-        model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions)
-        ret_hidden_states = self.model(
-            input_ids=model_input_ids,
-            positions=model_positions,
-            hidden_states=model_hidden_states,
-            inputs_embeds=inputs_embeds,
-        )
-        if self.method == "mtp":
+        model_kwargs = {
+            "input_ids": model_input_ids,
+            "positions": model_positions,
+            "inputs_embeds": inputs_embeds,
+        }
+        if self.pass_hidden_states_to_model:
+            model_hidden_states = self.hidden_states[:num_input_tokens]
+            model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions)
+            model_kwargs["hidden_states"] = model_hidden_states
+            if self.method == "mtp":
+                model_kwargs["positions"] = model_positions
+        ret_hidden_states = self.model(**model_kwargs)
+        if not self.model_returns_tuple():
             last_hidden_states = ret_hidden_states
             hidden_states = last_hidden_states
         else:
@@ -722,6 +779,7 @@ class AscendEagleProposer(EagleProposer):
                 last_hidden_states, model_positions, hidden_states
             )
+        num_indices = token_indices_to_sample.shape[0]
         if self.pcp_size > 1:
             # remove graph padding before all_gather
             hidden_states = hidden_states[:num_tokens]
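The model_kwargs construction above is what lets one proposer drive both EAGLE-style heads, which fuse the target's hidden states, and standalone draft models such as PARD, which take token ids only. A self-contained sketch of that dispatch, with illustrative class and parameter names (not the committed classes):

    from typing import Any

    import torch


    class _TinyDrafter(torch.nn.Module):
        """Illustrative stand-in for a drafter whose forward may or may not
        take the target model's hidden states."""

        def __init__(self, pass_hidden_states_to_model: bool):
            super().__init__()
            self.pass_hidden_states_to_model = pass_hidden_states_to_model
            self.proj = torch.nn.Linear(8, 8)

        def forward(self, input_ids, positions, hidden_states=None):
            x = torch.nn.functional.one_hot(input_ids, 8).float()
            if hidden_states is not None:
                x = x + hidden_states  # EAGLE-style: fuse target hidden states
            return self.proj(x)


    def run_drafter(drafter, input_ids, positions, target_hidden_states):
        # Assemble kwargs conditionally, mirroring the model_kwargs pattern.
        kwargs: dict[str, Any] = {"input_ids": input_ids, "positions": positions}
        if drafter.pass_hidden_states_to_model:
            kwargs["hidden_states"] = target_hidden_states
        return drafter(**kwargs)


    ids, pos = torch.tensor([1, 2, 3]), torch.arange(3)
    out = run_drafter(_TinyDrafter(True), ids, pos, torch.zeros(3, 8))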
@@ -741,26 +799,27 @@ class AscendEagleProposer(EagleProposer):
                 self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]],
             )
-        num_indices = last_token_indices.shape[0]
         if lmhead_tp_enable() and not is_dummy:
             max_num_reqs_across_dp = (
                 self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
             )
-            last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices))
-        sample_hidden_states = last_hidden_states[last_token_indices]
+            token_indices_to_sample = nn.functional.pad(
+                token_indices_to_sample, (0, max_num_reqs_across_dp - num_indices)
+            )
+        sample_hidden_states = last_hidden_states[token_indices_to_sample]
         logits = self.model.compute_logits(sample_hidden_states)
         if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
             logits = logits[:num_indices]
-            last_token_indices = last_token_indices[:num_indices]
+            token_indices_to_sample = token_indices_to_sample[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)
         # Early exit if there is only one draft token to be generated.
-        if self.num_speculative_tokens == 1:
+        if self.num_speculative_tokens == 1 or self.parallel_drafting:
             # [batch_size, 1]
-            return draft_token_ids.view(-1, 1)
+            return draft_token_ids.view(-1, self.num_speculative_tokens)
         if self.pcp_size * self.dcp_size > 1 and is_prefill:
             draft_token_ids = logits.argmax(dim=-1)
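Worked shape example for the early exit above (toy numbers, not from the benchmark): with parallel drafting, the single forward pass yields one logits row per speculative slot, so argmax already holds batch_size * num_speculative_tokens draft tokens and the per-step loop below never runs:

    import torch

    batch_size, num_speculative_tokens, vocab = 2, 4, 11
    logits = torch.randn(batch_size * num_speculative_tokens, vocab)
    draft_token_ids = logits.argmax(dim=-1)                             # flat: [batch * k]
    draft_token_ids = draft_token_ids.view(-1, num_speculative_tokens)  # [batch, k]
    assert draft_token_ids.shape == (batch_size, num_speculative_tokens)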
@@ -775,11 +834,11 @@ class AscendEagleProposer(EagleProposer):
             )
             draft_token_ids_tensor[0] = draft_token_ids
             if self.uses_mrope:
-                positions = target_positions[:, last_token_indices]
+                positions = self.mrope_positions[:, token_indices_to_sample]
             else:
-                positions = target_positions[last_token_indices]
-            hidden_states = hidden_states[last_token_indices]
-            last_token_indices = self.arange[:batch_size]
+                positions = self.positions[token_indices_to_sample]
+            hidden_states = hidden_states[token_indices_to_sample]
+            token_indices_to_sample = self.arange[:batch_size]
             input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size
@@ -843,13 +902,17 @@ class AscendEagleProposer(EagleProposer):
                 forward_context.attn_metadata = (
                     multi_steps_attn_metadata[draft_step + 1] if multi_steps_attn_metadata else None
                 )
-                ret_hidden_states = self.model(
-                    input_ids=model_input_ids,
-                    positions=model_positions,
-                    hidden_states=model_hidden_states,
-                    inputs_embeds=inputs_embeds,
-                )
-                if self.method == "mtp":
+                model_kwargs = {
+                    "input_ids": model_input_ids,
+                    "positions": model_positions,
+                    "inputs_embeds": inputs_embeds,
+                }
+                if self.pass_hidden_states_to_model:
+                    model_kwargs["hidden_states"] = model_hidden_states
+                ret_hidden_states = self.model(**model_kwargs)
+                if not self.model_returns_tuple():
                     last_hidden_states = ret_hidden_states
                     hidden_states = last_hidden_states
                 else:
@@ -859,22 +922,22 @@ class AscendEagleProposer(EagleProposer):
                     last_hidden_states, model_positions, hidden_states
                 )
-            num_indices = last_token_indices.shape[0]
+            num_indices = token_indices_to_sample.shape[0]
             if lmhead_tp_enable() and not is_dummy:
                 max_num_reqs_across_dp = (
                     self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
                 )
-                last_token_indices = nn.functional.pad(
-                    last_token_indices,
+                token_indices_to_sample = nn.functional.pad(
+                    token_indices_to_sample,
                     (0, max_num_reqs_across_dp - num_indices),
                 )
-            sample_hidden_states = last_hidden_states[last_token_indices]
+            sample_hidden_states = last_hidden_states[token_indices_to_sample]
             logits = self.model.compute_logits(sample_hidden_states)
             if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
                 logits = logits[:num_indices]
-                last_token_indices = last_token_indices[:num_indices]
+                token_indices_to_sample = token_indices_to_sample[:num_indices]
             # TODO(wenlong): get more than one token for tree attention
             hidden_states = hidden_states[:batch_size]
@@ -885,6 +948,122 @@ class AscendEagleProposer(EagleProposer):
         draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
         return draft_token_ids
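For contrast with the parallel early exit, a toy illustration of how the sequential path assembles its output: one row per draft step, shape [k, batch], transposed to [batch, k] at the end:

    import torch

    k, batch = 3, 2
    per_step_rows = [torch.tensor([10 * s + b for b in range(batch)]) for s in range(k)]
    draft_token_ids_tensor = torch.stack(per_step_rows)      # [k, batch]
    draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)  # [batch, k]
    assert draft_token_ids.shape == (batch, k)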
+    def set_inputs_first_pass(
+        self,
+        target_token_ids: torch.Tensor,
+        next_token_ids: torch.Tensor,
+        target_positions: torch.Tensor,
+        target_hidden_states: torch.Tensor,
+        token_indices_to_sample: torch.Tensor | None,
+        cad: CommonAttentionMetadata,
+        num_rejected_tokens_gpu: torch.Tensor | None,
+    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
+        if not self.needs_extra_input_slots:
+            # Default EAGLE pathway: no reshaping of input tensors needed.
+            # Simply rotate the input ids and leave the positions unchanged,
+            # Inserting the next token ids at the last slot in each request.
+            if token_indices_to_sample is None:
+                token_indices_to_sample = cad.query_start_loc[1:] - 1
+            num_tokens = target_token_ids.shape[0]
+            # Shift the input ids by one token.
+            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+            self.input_ids[: num_tokens - 1] = target_token_ids[1:]
+            # Replace the last token with the next token.
+            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+            self.input_ids[token_indices_to_sample] = next_token_ids
+            # copy inputs to buffer for cudagraph
+            if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
+                target_positions = target_positions[0]
+            self._set_positions(num_tokens, target_positions)
+            self.hidden_states[:num_tokens] = target_hidden_states
+            return num_tokens, token_indices_to_sample, cad
+        else:
+            assert self.is_rejected_token_mask is not None
+            assert self.is_masked_token_mask is not None
+            # 1.
+            # Call the CopyAndExpandEagleInputs AscendC operator to copy
+            # input_ids and positions into the correct slots in the
+            # preallocated buffers self.input_ids, self.positions.
+            batch_size = cad.batch_size()
+            total_num_input_tokens = target_token_ids.shape[0]
+            total_num_output_tokens = total_num_input_tokens + (self.net_num_new_slots_per_request * batch_size)
+            query_start_loc = cad.query_start_loc
+            query_end_loc = cad.query_start_loc[1:] - 1
+            if num_rejected_tokens_gpu is not None:
+                query_end_loc = query_end_loc - num_rejected_tokens_gpu
+            (
+                out_input_ids,
+                out_positions,
+                out_is_rejected_token_mask,
+                out_is_masked_token_mask,
+                token_indices_to_sample,
+                out_hidden_state_mapping,
+            ) = torch.ops._C_ascend.npu_copy_and_expand_eagle_inputs(
+                target_token_ids,
+                target_positions.to(torch.int32),
+                next_token_ids,
+                query_start_loc,
+                query_end_loc,
+                0,  # padding_token_id
+                self.parallel_drafting_token_id,
+                self.extra_slots_per_request,
+                self.pass_hidden_states_to_model,
+                total_num_output_tokens,
+            )
+            # Copy returned tensors into pre-allocated buffers
+            self.input_ids[:total_num_output_tokens].copy_(out_input_ids)
+            self.positions[:total_num_output_tokens].copy_(out_positions)
+            self.is_rejected_token_mask[:total_num_output_tokens].copy_(out_is_rejected_token_mask)
+            self.is_masked_token_mask[:total_num_output_tokens].copy_(out_is_masked_token_mask)
+            if self.pass_hidden_states_to_model:
+                assert self.parallel_drafting_hidden_state_tensor is not None
+                self.hidden_states[out_hidden_state_mapping] = target_hidden_states
+                # Use torch.where to avoid DtoH sync from boolean indexing
+                mask = self.is_masked_token_mask[:total_num_output_tokens]
+                torch.where(
+                    mask.unsqueeze(1),
+                    self.parallel_drafting_hidden_state_tensor,
+                    self.hidden_states[:total_num_output_tokens],
+                    out=self.hidden_states[:total_num_output_tokens],
+                )
+            # 2.
+            # Recompute the slot mapping based on the new positions and
+            # rejection mask.
+            builder = (
+                self._get_attention_metadata_builder()
+                if self.attn_metadata_builder is None
+                else self.attn_metadata_builder
+            )
+            new_slot_mapping = compute_new_slot_mapping(
+                cad=cad,
+                new_positions=self.positions[:total_num_output_tokens],
+                is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
+                block_size=builder.kv_cache_spec.block_size,
+                num_new_tokens=self.net_num_new_slots_per_request,
+                max_model_len=self.max_model_len,
+            )
+            # 3. Update the common attention metadata with the new (meta)data
+            new_cad = extend_all_queries_by_N(
+                cad,
+                N=self.net_num_new_slots_per_request,
+                arange=self.arange,
+                new_slot_mapping=new_slot_mapping,
+            )
+            return total_num_output_tokens, token_indices_to_sample, new_cad
+
+    def model_returns_tuple(self) -> bool:
+        return self.method not in ("mtp", "draft_model")
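A rough pure-PyTorch reading of what the expansion achieves for a single request (the committed implementation is the npu_copy_and_expand_eagle_inputs AscendC kernel; the helper below and its slot layout are simplifications, not the kernel's exact semantics):

    import torch

    # Given the accepted prefix plus the next sampled token, parallel drafting
    # appends `extra_slots` special mask tokens so that one forward pass can
    # score every speculative position.
    def expand_one_request(token_ids, next_token, positions, extra_slots, mask_token_id):
        ids = torch.cat([token_ids[1:], next_token.view(1),
                         torch.full((extra_slots,), mask_token_id, dtype=token_ids.dtype)])
        pos = torch.cat([positions[1:], positions[-1:] + 1,
                         positions[-1] + 2 + torch.arange(extra_slots)])
        return ids, pos

    ids, pos = expand_one_request(
        token_ids=torch.tensor([1, 2, 3]),
        next_token=torch.tensor(4),
        positions=torch.tensor([0, 1, 2]),
        extra_slots=2,
        mask_token_id=999,  # stands in for parallel_drafting_token_id
    )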
     def attn_update_stack_num_spec_norm(
         self,
         # `draft_step` must start from `1`, no `0`
@@ -1201,7 +1380,7 @@ class AscendEagleProposer(EagleProposer):
         common_attn_metadata: CommonAttentionMetadata,
         spec_decode_metadata: SpecDecodeMetadata,
         valid_sampled_tokens_count: torch.Tensor,
-    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
+    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         This function is used to prepare the inputs for speculative decoding
         It updates the common_attn_metadata for speculative decoding,
@@ -1215,7 +1394,7 @@ class AscendEagleProposer(EagleProposer):
         device = valid_sampled_tokens_count.device
         token_indices_to_sample = torch.empty((num_reqs,), dtype=torch.int32, device=device)
+        num_rejected_tokens_gpu = torch.empty((num_reqs,), dtype=torch.int32, device=device)
         num_blocks_needed = triton.cdiv(num_reqs, _PREPARE_INPUTS_BLOCK_SIZE)
         num_vector_core = get_vectorcore_num()
         grid_size = min(num_blocks_needed, num_vector_core)
@@ -1226,6 +1405,7 @@ class AscendEagleProposer(EagleProposer):
             valid_sampled_tokens_count,
             common_attn_metadata.query_start_loc,
             token_indices_to_sample,
+            num_rejected_tokens_gpu,
             num_reqs,
             BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE,
         )
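An illustrative torch equivalent of the per-request bookkeeping the kernel writes into the two output buffers (simplified; the committed kernel runs on the NPU vector cores, and the exact layout is an assumption here):

    import torch

    query_start_loc = torch.tensor([0, 4, 8, 12])          # padded queries of len 4
    valid_sampled_tokens_count = torch.tensor([3, 1, 4])   # accepted + bonus per request
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    num_rejected_tokens_gpu = query_lens - valid_sampled_tokens_count
    # Sample from the last accepted position of each request.
    token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected_tokens_gpu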
@@ -1274,7 +1454,7 @@ class AscendEagleProposer(EagleProposer):
             max_seq_len=0,
         )
-        return spec_common_attn_metadata, token_indices, token_indices_to_sample
+        return spec_common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu

     def _split_pcp_input(self, req_scheduled_tokens, input_ids, target_hidden_states):
         """
@@ -1394,3 +1574,18 @@ class AscendEagleProposer(EagleProposer):
         if hidden_states is not None:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), True)
         return last_hidden_states, positions, hidden_states
+
+class AscendEagleProposer(SpecDecodeBaseProposer):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        runner=None,
+    ):
+        super().__init__(
+            vllm_config,
+            device,
+            pass_hidden_states_to_model=True,
+            runner=runner,
+        )
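A minimal sketch of the subclass split, assuming (it is not shown in this diff) that AscendDraftModelProposer in draft_proposer.py passes pass_hidden_states_to_model=False:

    # Illustrative stand-ins only; the real base class carries far more state.
    class SpecDecodeBaseProposerSketch:
        def __init__(self, pass_hidden_states_to_model: bool):
            self.pass_hidden_states_to_model = pass_hidden_states_to_model

    class EagleLikeProposer(SpecDecodeBaseProposerSketch):
        def __init__(self):
            super().__init__(pass_hidden_states_to_model=True)   # EAGLE fuses target hidden states

    class DraftModelLikeProposer(SpecDecodeBaseProposerSketch):
        def __init__(self):
            super().__init__(pass_hidden_states_to_model=False)  # standalone draft model (e.g. PARD)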

View File

@@ -108,6 +108,7 @@ from vllm_ascend.patch.worker.patch_draft_quarot import patch_load_weights
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
+from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer
 from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
 from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
 from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
@@ -406,7 +407,12 @@ class NPUModelRunner(GPUModelRunner):
     def _set_up_drafter(self):
         # Set up speculative decoding.
         self.drafter: (
-            AscendNgramProposer | AscendEagleProposer | AscendSuffixDecodingProposer | AscendMedusaProposer | None
+            AscendNgramProposer
+            | AscendEagleProposer
+            | AscendDraftModelProposer
+            | AscendSuffixDecodingProposer
+            | AscendMedusaProposer
+            | None
         ) = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
@@ -971,7 +977,7 @@ class NPUModelRunner(GPUModelRunner):
                 draft_token_ids = self.drafter.propose(
                     valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
                 )
-            elif self.speculative_config.use_eagle():
+            elif self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model():
                 common_attn_metadata = spec_decode_common_attn_metadata
                 sampled_token_ids = valid_sampled_token_ids
@@ -1018,6 +1024,8 @@ class NPUModelRunner(GPUModelRunner):
         long_seq_metadata = None  # type: ignore
         num_prefill_reqs = 0
         num_decode_reqs = 0
+        num_rejected_tokens_gpu = None
         if spec_decode_metadata is None:
             # update pcp related params
             if self.pcp_size > 1:
@@ -1053,8 +1061,10 @@ class NPUModelRunner(GPUModelRunner):
                 )
             else:
                 assert self.drafter is not None
-                common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded(
-                    common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = (
+                    self.drafter.prepare_inputs_padded(
+                        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                    )
                 )
             if self.pcp_size > 1:
                 target_token_ids = input_ids_pcp_full[token_indices]
@@ -1075,7 +1085,7 @@ class NPUModelRunner(GPUModelRunner):
                     target_positions=target_positions,
                     target_hidden_states=target_hidden_states,
                     next_token_ids=next_token_ids,
-                    last_token_indices=token_indices_to_sample,
+                    token_indices_to_sample=token_indices_to_sample,
                     common_attn_metadata=common_attn_metadata,
                     sampling_metadata=sampling_metadata,
                     req_scheduled_tokens=req_scheduled_tokens,
@@ -1084,6 +1094,7 @@ class NPUModelRunner(GPUModelRunner):
                     num_decode_reqs=num_decode_reqs,
                     scheduler_output=scheduler_output,
                     num_scheduled_tokens=num_scheduled_tokens,
+                    num_rejected_tokens_gpu=num_rejected_tokens_gpu,
                 )
         else:
             raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
@@ -1516,16 +1527,16 @@ class NPUModelRunner(GPUModelRunner):
             with record_function_or_nullcontext("draft_token"):
                 if self.speculative_config:
-                    use_padded_batch_for_eagle = (
+                    use_padded_batch = (
                         self.speculative_config
-                        and self.speculative_config.use_eagle()
+                        and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model())
                         and not self.speculative_config.disable_padded_drafter_batch
                     )
-                    if use_padded_batch_for_eagle:
+                    if use_padded_batch:
                         # EAGLE speculative decoding can use the GPU sampled tokens
                         # as inputs, and does not need to wait for bookkeeping to finish.
                         propose_draft_token_ids(sampler_output.sampled_token_ids)
-                if self.speculative_config and not use_padded_batch_for_eagle:
+                if self.speculative_config and not use_padded_batch:
                     # ngram and other speculative decoding methods use the sampled
                     # tokens on the CPU, so they are run after bookkeeping.
                     propose_draft_token_ids(valid_sampled_token_ids)
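The gating above, condensed into a self-contained predicate (the stub config class is illustrative only): padded-batch drafters, i.e. EAGLE-style heads and draft models, consume GPU-resident sampled tokens right away, while other methods such as ngram wait for CPU bookkeeping.

    from dataclasses import dataclass

    @dataclass
    class _SpecConfigStub:
        eagle: bool = False
        draft_model: bool = True
        disable_padded_drafter_batch: bool = False

        def use_eagle(self):
            return self.eagle

        def uses_draft_model(self):
            return self.draft_model

    def should_use_padded_batch(cfg) -> bool:
        return bool(cfg and (cfg.use_eagle() or cfg.uses_draft_model())
                    and not cfg.disable_padded_drafter_batch)

    assert should_use_padded_batch(_SpecConfigStub())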
@@ -2165,7 +2176,7 @@ class NPUModelRunner(GPUModelRunner):
             if kv_cache_gid > 0:
                 cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
             if self.speculative_config and spec_decode_common_attn_metadata is None:
-                if isinstance(self.drafter, AscendEagleProposer):
+                if isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer):
                     if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
                         spec_decode_common_attn_metadata = cm
                     else:
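Illustrative stand-in for the selection above: the drafter borrows the common attention metadata of whichever KV-cache group contains its first attention layer (data structures simplified, not the runner's real ones).

    def pick_spec_decode_metadata(group_layer_names, metadatas, drafter_layer_names):
        # Return the metadata of the group holding the drafter's first layer.
        for layer_names, cm in zip(group_layer_names, metadatas):
            if drafter_layer_names[0] in layer_names:
                return cm
        return None

    cm = pick_spec_decode_metadata(
        group_layer_names=[["model.layers.0.attn"], ["drafter.layers.0.attn"]],
        metadatas=["cm_group0", "cm_group1"],
        drafter_layer_names=["drafter.layers.0.attn"],
    )
    assert cm == "cm_group1"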