[feat][spec decode]Unified draft parallel (#6766)
### What this PR does / why we need it?
Implement a unified parallelized speculative decoding in VLLM
Ascend,which can simultaneously support parallel speculative inference
schemes such as Pard, P-Eagle, etc. refer to
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078
### How was this patch tested?
run with parallel drafting script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811 \
--speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
base script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811
benchmark script:
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
--temperature 0 \
--model /model/Llama-3.1-8B-Instruct \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path philschmid/mt-bench \
--num-prompts ${NUM_PROMPTS} \
--max-concurrency ${MAX_CONCURRENCY} \
--seed 1234
test results :
base(without spec decode): TTFT 79.46ms TPOT 26.99ms
output_tokens_throughput 36.75 tok/s
this pr(with parallel drafting): TTFT 72.24ms TPOT 13.45ms
output_tokens_throughput 72.98 tok/s
per-position acceptance(from position 0 to 7):
79.48%、56.93%、40%、27.90%、19.79%、14.25%、10.57%、7.61%.
----------------------------------------------------------------------
run on qwen3 model script :
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811 \
--speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
cc @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
This commit is contained in:
@@ -25,7 +25,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||
|
||||
|
||||
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;causal_conv1d;"
|
||||
CUSTOM_OPS="moe_grouped_matmul;grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;copy_and_expand_eagle_inputs;causal_conv1d;"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
@@ -64,6 +64,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
"add_rms_norm_bias"
|
||||
"apply_top_k_top_p_custom"
|
||||
"transpose_kv_cache_by_block"
|
||||
"copy_and_expand_eagle_inputs"
|
||||
"causal_conv1d"
|
||||
"moe_grouped_matmul"
|
||||
)
|
||||
|
||||
22
csrc/copy_and_expand_eagle_inputs/op_host/CMakeLists.txt
Normal file
22
csrc/copy_and_expand_eagle_inputs/op_host/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
add_ops_compile_options(
|
||||
OP_NAME CopyAndExpandEagleInputs
|
||||
OPTIONS --cce-auto-sync=on
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
copy_and_expand_eagle_inputs_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
copy_and_expand_eagle_inputs_tiling.cpp
|
||||
)
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
copy_and_expand_eagle_inputs_infershape.cpp
|
||||
)
|
||||
@@ -0,0 +1,87 @@
|
||||
/**
|
||||
* @file copy_and_expand_eagle_inputs_def.cpp
|
||||
* @brief CopyAndExpandEagleInputs OpDef registration
|
||||
*/
|
||||
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
|
||||
class CopyAndExpandEagleInputs : public OpDef {
|
||||
public:
|
||||
explicit CopyAndExpandEagleInputs(const char* name) : OpDef(name)
|
||||
{
|
||||
// -------------------- Inputs --------------------
|
||||
this->Input("target_token_ids")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Input("target_positions")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Input("next_token_ids")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Input("query_start_loc")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Input("query_end_loc")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
|
||||
// -------------------- Outputs --------------------
|
||||
this->Output("out_input_ids")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Output("out_positions")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Output("out_is_rejected_token_mask")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Output("out_is_masked_token_mask")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT8})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Output("out_new_token_indices")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
this->Output("out_hidden_state_mapping")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT32})
|
||||
.Format({ge::FORMAT_ND})
|
||||
.UnknownShapeFormat({ge::FORMAT_ND});
|
||||
|
||||
// -------------------- Attributes --------------------
|
||||
this->Attr("padding_token_id").Int();
|
||||
this->Attr("parallel_drafting_token_id").Int();
|
||||
this->Attr("num_padding_slots_per_request").Int();
|
||||
this->Attr("shift_input_ids").Bool();
|
||||
this->Attr("total_input_tokens").Int();
|
||||
|
||||
// -------------------- Platform --------------------
|
||||
this->AICore().AddConfig("ascend910b");
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(CopyAndExpandEagleInputs);
|
||||
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,107 @@
|
||||
/**
|
||||
* @file copy_and_expand_eagle_inputs_infershape.cpp
|
||||
* @brief InferShape and InferDataType for CopyAndExpandEagleInputs
|
||||
*/
|
||||
|
||||
#include "register/op_def_registry.h"
|
||||
#include "log/ops_log.h"
|
||||
|
||||
#define unlikely(x) __builtin_expect((x), 0)
|
||||
#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \
|
||||
do { \
|
||||
if (unlikely((ptr) == nullptr)) { \
|
||||
const char* name = (unlikely(((context) == nullptr) || (context)->GetNodeName() == nullptr)) ? \
|
||||
"nil" : \
|
||||
(context)->GetNodeName(); \
|
||||
OPS_LOG_E(name, "%s is nullptr!", #ptr); \
|
||||
return ge::GRAPH_FAILED; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
static constexpr int IDX_TARGET_TOKEN_IDS = 0;
|
||||
static constexpr int IDX_TARGET_POSITIONS = 1;
|
||||
static constexpr int IDX_NEXT_TOKEN_IDS = 2;
|
||||
static constexpr int IDX_QUERY_START_LOC = 3;
|
||||
static constexpr int IDX_QUERY_END_LOC = 4;
|
||||
|
||||
static constexpr int OUT_INPUT_IDS = 0;
|
||||
static constexpr int OUT_POSITIONS = 1;
|
||||
static constexpr int OUT_REJECTED_MASK = 2;
|
||||
static constexpr int OUT_MASKED_MASK = 3;
|
||||
static constexpr int OUT_NEW_TOKEN_INDICES = 4;
|
||||
static constexpr int OUT_HIDDEN_STATE_MAPPING = 5;
|
||||
static constexpr int OUTPUT_NUM = 6;
|
||||
|
||||
static constexpr int ATTR_NUM_PADDING_SLOTS = 2;
|
||||
static constexpr int ATTR_TOTAL_INPUT_TOKENS = 4;
|
||||
|
||||
using namespace ge;
|
||||
|
||||
namespace ops {
|
||||
|
||||
static ge::graphStatus InferShape4CopyAndExpandEagleInputs(gert::InferShapeContext* context)
|
||||
{
|
||||
// Get input shapes
|
||||
const gert::Shape* targetTokenIdsShape = context->GetInputShape(IDX_TARGET_TOKEN_IDS);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, targetTokenIdsShape);
|
||||
const gert::Shape* queryStartLocShape = context->GetInputShape(IDX_QUERY_START_LOC);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, queryStartLocShape);
|
||||
|
||||
// Derive dimensions from input shapes
|
||||
int64_t totalInputTokens = targetTokenIdsShape->GetDim(0);
|
||||
int64_t numReqs = queryStartLocShape->GetDim(0) - 1;
|
||||
|
||||
// Get num_padding_slots_per_request from attributes
|
||||
auto attrs = context->GetAttrs();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
||||
int64_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int64_t>(ATTR_NUM_PADDING_SLOTS));
|
||||
|
||||
// Compute total_draft_tokens = total_input_tokens + (num_padding_slots_per_request - 1) * num_reqs
|
||||
int64_t totalDraftTokens = totalInputTokens + (numPaddingSlotsPerReq - 1) * numReqs;
|
||||
|
||||
// Get and validate all output shapes
|
||||
gert::Shape* outShapes[OUTPUT_NUM];
|
||||
for (int i = 0; i < OUTPUT_NUM; ++i) {
|
||||
outShapes[i] = context->GetOutputShape(i);
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, outShapes[i]);
|
||||
outShapes[i]->SetDimNum(1);
|
||||
}
|
||||
|
||||
// out_input_ids, out_positions, out_rejected_mask, out_masked_mask: [total_draft_tokens]
|
||||
outShapes[OUT_INPUT_IDS]->SetDim(0, totalDraftTokens);
|
||||
outShapes[OUT_POSITIONS]->SetDim(0, totalDraftTokens);
|
||||
outShapes[OUT_REJECTED_MASK]->SetDim(0, totalDraftTokens);
|
||||
outShapes[OUT_MASKED_MASK]->SetDim(0, totalDraftTokens);
|
||||
|
||||
// out_new_token_indices: [num_reqs * num_padding_slots_per_request]
|
||||
outShapes[OUT_NEW_TOKEN_INDICES]->SetDim(0, numReqs * numPaddingSlotsPerReq);
|
||||
|
||||
// out_hidden_state_mapping: [total_input_tokens]
|
||||
outShapes[OUT_HIDDEN_STATE_MAPPING]->SetDim(0, totalInputTokens);
|
||||
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus InferDataType4CopyAndExpandEagleInputs(gert::InferDataTypeContext* context)
|
||||
{
|
||||
// out_input_ids: INT32
|
||||
context->SetOutputDataType(OUT_INPUT_IDS, DT_INT32);
|
||||
// out_positions: INT32
|
||||
context->SetOutputDataType(OUT_POSITIONS, DT_INT32);
|
||||
// out_is_rejected_token_mask: INT8
|
||||
context->SetOutputDataType(OUT_REJECTED_MASK, DT_INT8);
|
||||
// out_is_masked_token_mask: INT8
|
||||
context->SetOutputDataType(OUT_MASKED_MASK, DT_INT8);
|
||||
// out_new_token_indices: INT32
|
||||
context->SetOutputDataType(OUT_NEW_TOKEN_INDICES, DT_INT32);
|
||||
// out_hidden_state_mapping: INT32
|
||||
context->SetOutputDataType(OUT_HIDDEN_STATE_MAPPING, DT_INT32);
|
||||
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(CopyAndExpandEagleInputs)
|
||||
.InferShape(InferShape4CopyAndExpandEagleInputs)
|
||||
.InferDataType(InferDataType4CopyAndExpandEagleInputs);
|
||||
|
||||
} // namespace ops
|
||||
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* @file copy_and_expand_eagle_inputs_tiling.cpp
|
||||
* @brief CopyAndExpandEagleInputs TilingFunc implementation
|
||||
*/
|
||||
|
||||
#include "copy_and_expand_eagle_inputs_tiling.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "log/ops_log.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace optiling {
|
||||
|
||||
static void GetCompileParameters(
|
||||
gert::TilingContext* context, uint32_t& coreNum)
|
||||
{
|
||||
auto ptrCompileInfo = reinterpret_cast<const CopyAndExpandEagleInputsCompileInfo*>(context->GetCompileInfo());
|
||||
if (ptrCompileInfo == nullptr) {
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
coreNum = ascendcPlatform.GetCoreNum();
|
||||
} else {
|
||||
coreNum = ptrCompileInfo->totalCoreNum;
|
||||
}
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingFunc(gert::TilingContext* context)
|
||||
{
|
||||
OPS_LOG_I(context, "Enter TilingFunc for CopyAndExpandEagleInputs");
|
||||
OPS_LOG_D(context, "TilingFunc running.");
|
||||
|
||||
// ========== 1. Get hardware core count ==========
|
||||
uint32_t coreNum;
|
||||
GetCompileParameters(context, coreNum);
|
||||
|
||||
// ========== 2. Derive num_reqs from query_start_loc shape ==========
|
||||
// query_start_loc is the 4th input (index 3), shape [num_reqs + 1]
|
||||
auto queryStartLocShape = context->GetInputShape(3);
|
||||
uint32_t numReqs = 0;
|
||||
if (queryStartLocShape != nullptr &&
|
||||
queryStartLocShape->GetStorageShape().GetDimNum() > 0) {
|
||||
int64_t dim0 = queryStartLocShape->GetStorageShape().GetDim(0);
|
||||
numReqs = (dim0 > 1) ? static_cast<uint32_t>(dim0 - 1) : 0;
|
||||
}
|
||||
|
||||
// ========== 3. Get operator attributes ==========
|
||||
auto attrs = context->GetAttrs();
|
||||
|
||||
int32_t paddingTokenId = *(attrs->GetAttrPointer<int32_t>(0));
|
||||
int32_t parallelDraftingTokenId = *(attrs->GetAttrPointer<int32_t>(1));
|
||||
int32_t numPaddingSlotsPerReq = *(attrs->GetAttrPointer<int32_t>(2));
|
||||
bool shiftInputIds = *(attrs->GetAttrPointer<bool>(3));
|
||||
int32_t totalInputTokens = *(attrs->GetAttrPointer<int32_t>(4));
|
||||
|
||||
// ========== 4. Compute core distribution ==========
|
||||
uint32_t usedCoreNum = std::min(coreNum, numReqs);
|
||||
if (usedCoreNum == 0) {
|
||||
usedCoreNum = 1;
|
||||
}
|
||||
uint32_t reqsPerCore = numReqs / usedCoreNum;
|
||||
uint32_t remainderReqs = numReqs % usedCoreNum;
|
||||
|
||||
// ========== 5. Set tiling_key ==========
|
||||
context->SetTilingKey(1);
|
||||
|
||||
// ========== 6. Get output shape ==========
|
||||
uint32_t totalDraftTokens = 0;
|
||||
auto outShape = context->GetOutputShape(0);
|
||||
if (outShape != nullptr &&
|
||||
outShape->GetStorageShape().GetDimNum() > 0) {
|
||||
totalDraftTokens = static_cast<uint32_t>(outShape->GetStorageShape().GetDim(0));
|
||||
}
|
||||
|
||||
// ========== 7. Fill TilingData ==========
|
||||
CopyAndExpandEagleInputsTilingData tiling;
|
||||
tiling.set_usedCoreNum(usedCoreNum);
|
||||
tiling.set_numReqs(numReqs);
|
||||
tiling.set_reqsPerCore(reqsPerCore);
|
||||
tiling.set_remainderReqs(remainderReqs);
|
||||
tiling.set_paddingTokenId(paddingTokenId);
|
||||
tiling.set_parallelDraftingTokenId(parallelDraftingTokenId);
|
||||
tiling.set_numPaddingSlotsPerReq(static_cast<uint32_t>(numPaddingSlotsPerReq));
|
||||
tiling.set_totalInputTokens(static_cast<uint32_t>(totalInputTokens));
|
||||
tiling.set_shiftInputIds(shiftInputIds ? 1u : 0u);
|
||||
tiling.set_totalDraftTokens(totalDraftTokens);
|
||||
|
||||
tiling.SaveToBuffer(
|
||||
context->GetRawTilingData()->GetData(),
|
||||
context->GetRawTilingData()->GetCapacity());
|
||||
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
|
||||
|
||||
// ========== 8. Set block_dim ==========
|
||||
context->SetBlockDim(usedCoreNum);
|
||||
|
||||
OPS_LOG_I(context, "Block Dim: %u", usedCoreNum);
|
||||
OPS_LOG_I(context,
|
||||
"numReqs: %u, reqsPerCore: %u, remainderReqs: %u, totalInputTokens: %d, totalDraftTokens: %u",
|
||||
numReqs, reqsPerCore, remainderReqs, totalInputTokens, totalDraftTokens);
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus TilingPrepare4CopyAndExpandEagleInputs(gert::TilingParseContext* context)
|
||||
{
|
||||
OPS_LOG_D(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
|
||||
OPS_LOG_I(context, "TilingPrepare4CopyAndExpandEagleInputs running.");
|
||||
auto compileInfo = context->GetCompiledInfo<CopyAndExpandEagleInputsCompileInfo>();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo);
|
||||
auto platformInfo = context->GetPlatformInfo();
|
||||
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo);
|
||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
|
||||
|
||||
compileInfo->totalCoreNum = ascendcPlatform.GetCoreNum();
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(CopyAndExpandEagleInputs)
|
||||
.Tiling(TilingFunc)
|
||||
.TilingParse<CopyAndExpandEagleInputsCompileInfo>(TilingPrepare4CopyAndExpandEagleInputs);
|
||||
|
||||
} // namespace optiling
|
||||
@@ -0,0 +1,37 @@
|
||||
#ifndef COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
|
||||
#define COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
|
||||
|
||||
#include "register/tilingdata_base.h"
|
||||
#include "error_log.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
|
||||
namespace optiling {
|
||||
|
||||
BEGIN_TILING_DATA_DEF(CopyAndExpandEagleInputsTilingData)
|
||||
// ---- 分核参数 ----
|
||||
TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum); // 实际使用的核数
|
||||
TILING_DATA_FIELD_DEF(uint32_t, numReqs); // 总请求数
|
||||
TILING_DATA_FIELD_DEF(uint32_t, reqsPerCore); // 每核基础请求数
|
||||
TILING_DATA_FIELD_DEF(uint32_t, remainderReqs); // 余数(前 remainder 个核多处理 1 个请求)
|
||||
|
||||
// ---- 算子属性 ----
|
||||
TILING_DATA_FIELD_DEF(int32_t, paddingTokenId); // 填充 token ID
|
||||
TILING_DATA_FIELD_DEF(int32_t, parallelDraftingTokenId); // 并行推测解码 token ID
|
||||
TILING_DATA_FIELD_DEF(uint32_t, numPaddingSlotsPerReq); // 每个请求的 padding 槽位数
|
||||
TILING_DATA_FIELD_DEF(uint32_t, totalInputTokens); // 输入 token 总数(用于 clamp)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, shiftInputIds); // 0 = false, 1 = true
|
||||
|
||||
// ---- 输出尺寸 ----
|
||||
TILING_DATA_FIELD_DEF(uint32_t, totalDraftTokens); // 输出 token 总数
|
||||
END_TILING_DATA_DEF;
|
||||
|
||||
struct CopyAndExpandEagleInputsCompileInfo {
|
||||
uint32_t totalCoreNum = 0;
|
||||
};
|
||||
|
||||
REGISTER_TILING_DATA_CLASS(CopyAndExpandEagleInputs, CopyAndExpandEagleInputsTilingData)
|
||||
|
||||
} // namespace optiling
|
||||
|
||||
#endif // COPY_AND_EXPAND_EAGLE_INPUTS_TILING_H
|
||||
@@ -0,0 +1,386 @@
|
||||
/**
|
||||
* CopyAndExpandEagleInputs 算子 Kernel 实现 (DataCopy 版)
|
||||
*
|
||||
* 多核策略:
|
||||
* 所有 GM 读写通过 DataCopy 完成(不使用 GlobalTensor::SetValue/GetValue 访问 GM)。
|
||||
* UB (LocalTensor) 上使用 SetValue/GetValue 构建数据,再 DataCopy 到 GM。
|
||||
* 对齐处理参考 CANN 内置算子的 DataCopyCustom 模式。
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
// ONE_BLK_SIZE comes from AscendC namespace (32 bytes per block)
|
||||
|
||||
class CopyAndExpandEagleInputsKernel {
|
||||
public:
|
||||
__aicore__ inline CopyAndExpandEagleInputsKernel() {}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR targetTokenIds, GM_ADDR targetPositions,
|
||||
GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
|
||||
GM_ADDR queryEndLoc,
|
||||
GM_ADDR outInputIds, GM_ADDR outPositions,
|
||||
GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
|
||||
GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
|
||||
const CopyAndExpandEagleInputsTilingData* tilingData)
|
||||
{
|
||||
usedCoreNum = tilingData->usedCoreNum;
|
||||
numReqs = tilingData->numReqs;
|
||||
reqsPerCore = tilingData->reqsPerCore;
|
||||
remainderReqs = tilingData->remainderReqs;
|
||||
paddingTokenId = tilingData->paddingTokenId;
|
||||
parallelDraftingTokenId = tilingData->parallelDraftingTokenId;
|
||||
numPaddingSlotsPerReq = tilingData->numPaddingSlotsPerReq;
|
||||
totalInputTokens = tilingData->totalInputTokens;
|
||||
totalDraftTokens = tilingData->totalDraftTokens;
|
||||
|
||||
uint32_t coreId = GetBlockIdx();
|
||||
if (coreId < remainderReqs) {
|
||||
myStartReq = coreId * (reqsPerCore + 1);
|
||||
myNumReqs = reqsPerCore + 1;
|
||||
} else {
|
||||
myStartReq = remainderReqs * (reqsPerCore + 1) + (coreId - remainderReqs) * reqsPerCore;
|
||||
myNumReqs = reqsPerCore;
|
||||
}
|
||||
|
||||
// 绑定 GM Tensor
|
||||
gmTargetTokenIds.SetGlobalBuffer((__gm__ int32_t*)targetTokenIds, totalInputTokens);
|
||||
gmTargetPositions.SetGlobalBuffer((__gm__ int32_t*)targetPositions, totalInputTokens);
|
||||
gmNextTokenIds.SetGlobalBuffer((__gm__ int32_t*)nextTokenIds, numReqs);
|
||||
gmQueryStartLoc.SetGlobalBuffer((__gm__ int32_t*)queryStartLoc, numReqs + 1);
|
||||
gmQueryEndLoc.SetGlobalBuffer((__gm__ int32_t*)queryEndLoc, numReqs);
|
||||
gmOutInputIds.SetGlobalBuffer((__gm__ int32_t*)outInputIds, totalDraftTokens);
|
||||
gmOutPositions.SetGlobalBuffer((__gm__ int32_t*)outPositions, totalDraftTokens);
|
||||
gmOutIsRejectedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsRejectedTokenMask, totalDraftTokens);
|
||||
gmOutIsMaskedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsMaskedTokenMask, totalDraftTokens);
|
||||
gmOutNewTokenIndices.SetGlobalBuffer((__gm__ int32_t*)outNewTokenIndices, numPaddingSlotsPerReq * numReqs);
|
||||
gmOutHiddenStateMapping.SetGlobalBuffer((__gm__ int32_t*)outHiddenStateMapping, totalInputTokens);
|
||||
|
||||
// 分配 UB 缓冲区 —— 每个 TBuf 的基地址自动 32 字节对齐
|
||||
// 元数据各自独立 TBuf,避免 UB 地址不对齐
|
||||
uint32_t metaAligned = AlignUp((myNumReqs + 1) * sizeof(int32_t), ONE_BLK_SIZE);
|
||||
pipe.InitBuffer(qsBuf, metaAligned);
|
||||
pipe.InitBuffer(qeBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(ntBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
|
||||
// I/O 缓冲区
|
||||
constexpr uint32_t MAX_PER_REQ = 4096;
|
||||
pipe.InitBuffer(inputBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outIdsBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outPosBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outRejBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outMskBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(ntiBuf, AlignUp(64 * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(hsmBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
|
||||
// DataCopy 元数据到各自 UB
|
||||
if (myNumReqs > 0) {
|
||||
LocalTensor<int32_t> lqs = qsBuf.Get<int32_t>();
|
||||
DataCopyIn(lqs, gmQueryStartLoc, (int32_t)myStartReq, (int32_t)(myNumReqs + 1));
|
||||
|
||||
LocalTensor<int32_t> lqe = qeBuf.Get<int32_t>();
|
||||
DataCopyIn(lqe, gmQueryEndLoc, (int32_t)myStartReq, (int32_t)myNumReqs);
|
||||
|
||||
LocalTensor<int32_t> lnt = ntBuf.Get<int32_t>();
|
||||
DataCopyIn(lnt, gmNextTokenIds, (int32_t)myStartReq, (int32_t)myNumReqs);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ProcessShiftFalse()
|
||||
{
|
||||
for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
|
||||
ProcessOneRequestShiftFalse(myStartReq + rLocal, rLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ProcessShiftTrue()
|
||||
{
|
||||
for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
|
||||
ProcessOneRequestShiftTrue(myStartReq + rLocal, rLocal);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// ============================================================
|
||||
// AlignUp 辅助
|
||||
// ============================================================
|
||||
static __aicore__ inline uint32_t AlignUp(uint32_t x, uint32_t a)
|
||||
{
|
||||
return (x + a - 1) / a * a;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GM → UB: 标准 DataCopy,count 自动 round-up 到 block 对齐
|
||||
// 多读的元素在 UB 中不会被使用,安全无害
|
||||
// ============================================================
|
||||
__aicore__ inline void DataCopyIn(LocalTensor<int32_t>& dst,
|
||||
GlobalTensor<int32_t>& src,
|
||||
int32_t gmOffset, int32_t count)
|
||||
{
|
||||
if (count <= 0) return;
|
||||
constexpr int32_t ELEMS_PER_BLK = ONE_BLK_SIZE / (int32_t)sizeof(int32_t); // 8
|
||||
int32_t aligned = (count + ELEMS_PER_BLK - 1) / ELEMS_PER_BLK * ELEMS_PER_BLK;
|
||||
DataCopy(dst, src[gmOffset], aligned);
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// UB → GM: DataCopyPad + DataCopyExtParams(C220 支持任意字节数)
|
||||
// 精确写入 count 个元素,不越界覆盖相邻数据
|
||||
// ============================================================
|
||||
__aicore__ inline void DataCopyOut_int32(GlobalTensor<int32_t>& dst,
|
||||
LocalTensor<int32_t>& src,
|
||||
int32_t gmOffset, int32_t count)
|
||||
{
|
||||
if (count <= 0) return;
|
||||
uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int32_t));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
|
||||
__aicore__ inline void DataCopyOut_int8(GlobalTensor<int8_t>& dst,
|
||||
LocalTensor<int8_t>& src,
|
||||
int32_t gmOffset, int32_t count)
|
||||
{
|
||||
if (count <= 0) return;
|
||||
uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int8_t));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 元数据读取 (从各自 UB 缓冲区)
|
||||
// ============================================================
|
||||
__aicore__ inline int32_t ReadQS(uint32_t rLocal) {
|
||||
return qsBuf.Get<int32_t>().GetValue(rLocal);
|
||||
}
|
||||
__aicore__ inline int32_t ReadNextQS(uint32_t rLocal) {
|
||||
return qsBuf.Get<int32_t>().GetValue(rLocal + 1);
|
||||
}
|
||||
__aicore__ inline int32_t ReadQE(uint32_t rLocal) {
|
||||
return qeBuf.Get<int32_t>().GetValue(rLocal);
|
||||
}
|
||||
__aicore__ inline int32_t ReadNT(uint32_t rLocal) {
|
||||
return ntBuf.Get<int32_t>().GetValue(rLocal);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// shift_input_ids = false
|
||||
// ============================================================
|
||||
__aicore__ inline void ProcessOneRequestShiftFalse(uint32_t r, uint32_t rLocal)
|
||||
{
|
||||
int32_t queryStart = ReadQS(rLocal);
|
||||
int32_t nextQueryStart = ReadNextQS(rLocal);
|
||||
int32_t queryEnd = ReadQE(rLocal);
|
||||
|
||||
int32_t numRejected = nextQueryStart - queryEnd - 1;
|
||||
if (numRejected < 0) numRejected = 0;
|
||||
int32_t numValid = queryEnd - queryStart + 1;
|
||||
if (numValid < 0) numValid = 0;
|
||||
|
||||
int32_t outputStart = queryStart + (int32_t)r * (int32_t)numPaddingSlotsPerReq;
|
||||
int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
|
||||
|
||||
// 读取输入 token 到 UB
|
||||
int32_t numInputTokensForReq = nextQueryStart - queryStart;
|
||||
LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
|
||||
if (numInputTokensForReq > 0) {
|
||||
DataCopyIn(localInput, gmTargetTokenIds, queryStart, numInputTokensForReq);
|
||||
}
|
||||
|
||||
// 读取起始 position
|
||||
LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
|
||||
DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
|
||||
int32_t startPos = localTmpPos.GetValue(0);
|
||||
|
||||
int32_t nextTokenId = ReadNT(rLocal);
|
||||
|
||||
// 构建输出到 UB
|
||||
LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
|
||||
LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
|
||||
LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
|
||||
LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
|
||||
|
||||
for (int32_t j = 0; j < numValid; j++) {
|
||||
int32_t inIdx = j;
|
||||
if (inIdx >= numInputTokensForReq) inIdx = numInputTokensForReq - 1;
|
||||
lIds.SetValue(j, localInput.GetValue(inIdx));
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
// Bonus
|
||||
lIds.SetValue(numValid, nextTokenId);
|
||||
lPos.SetValue(numValid, startPos + numValid);
|
||||
lRej.SetValue(numValid, (int8_t)0);
|
||||
lMsk.SetValue(numValid, (int8_t)0);
|
||||
// Parallel Draft
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
int32_t j = numValid + k;
|
||||
lIds.SetValue(j, parallelDraftingTokenId);
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)1);
|
||||
}
|
||||
// Rejected
|
||||
for (int32_t k = 0; k < numRejected; k++) {
|
||||
int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
|
||||
lIds.SetValue(j, paddingTokenId);
|
||||
lPos.SetValue(j, (int32_t)0);
|
||||
lRej.SetValue(j, (int8_t)1);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
|
||||
// UB → GM
|
||||
DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
|
||||
DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
|
||||
|
||||
// NTI
|
||||
LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
|
||||
lNti.SetValue(0, outputStart + numValid);
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
lNti.SetValue(k, outputStart + numValid + k);
|
||||
}
|
||||
int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
|
||||
DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// shift_input_ids = true
|
||||
// ============================================================
|
||||
__aicore__ inline void ProcessOneRequestShiftTrue(uint32_t r, uint32_t rLocal)
|
||||
{
|
||||
int32_t queryStart = ReadQS(rLocal);
|
||||
int32_t nextQueryStart = ReadNextQS(rLocal);
|
||||
int32_t queryEnd = ReadQE(rLocal);
|
||||
|
||||
int32_t numRejected = nextQueryStart - queryEnd - 1;
|
||||
if (numRejected < 0) numRejected = 0;
|
||||
int32_t numValid = queryEnd - queryStart;
|
||||
if (numValid < 0) numValid = 0;
|
||||
|
||||
int32_t outputStart = queryStart + (int32_t)r * ((int32_t)numPaddingSlotsPerReq - 1);
|
||||
int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
|
||||
|
||||
int32_t numInputTokensForReq = nextQueryStart - queryStart;
|
||||
LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
|
||||
int32_t readStart = queryStart + 1;
|
||||
int32_t readCount = numValid;
|
||||
if (readStart + readCount > (int32_t)totalInputTokens) {
|
||||
readCount = (int32_t)totalInputTokens - readStart;
|
||||
if (readCount < 0) readCount = 0;
|
||||
}
|
||||
if (readCount > 0) {
|
||||
DataCopyIn(localInput, gmTargetTokenIds, readStart, readCount);
|
||||
}
|
||||
|
||||
LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
|
||||
DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
|
||||
int32_t startPos = localTmpPos.GetValue(0);
|
||||
|
||||
int32_t nextTokenId = ReadNT(rLocal);
|
||||
|
||||
LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
|
||||
LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
|
||||
LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
|
||||
LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
|
||||
|
||||
for (int32_t j = 0; j < numValid; j++) {
|
||||
int32_t inIdx = j;
|
||||
if (inIdx >= readCount && readCount > 0) inIdx = readCount - 1;
|
||||
lIds.SetValue(j, readCount > 0 ? localInput.GetValue(inIdx) : (int32_t)0);
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
lIds.SetValue(numValid, nextTokenId);
|
||||
lPos.SetValue(numValid, startPos + numValid);
|
||||
lRej.SetValue(numValid, (int8_t)0);
|
||||
lMsk.SetValue(numValid, (int8_t)0);
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
int32_t j = numValid + k;
|
||||
lIds.SetValue(j, parallelDraftingTokenId);
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)1);
|
||||
}
|
||||
for (int32_t k = 0; k < numRejected; k++) {
|
||||
int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
|
||||
lIds.SetValue(j, paddingTokenId);
|
||||
lPos.SetValue(j, (int32_t)0);
|
||||
lRej.SetValue(j, (int8_t)1);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
|
||||
DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
|
||||
DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
|
||||
|
||||
LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
|
||||
lNti.SetValue(0, outputStart + numValid);
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
lNti.SetValue(k, outputStart + numValid + k);
|
||||
}
|
||||
int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
|
||||
DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
|
||||
|
||||
// hidden_state_mapping
|
||||
LocalTensor<int32_t> lHsm = hsmBuf.Get<int32_t>();
|
||||
for (int32_t j = 0; j < numInputTokensForReq; j++) {
|
||||
lHsm.SetValue(j, outputStart + j);
|
||||
}
|
||||
DataCopyOut_int32(gmOutHiddenStateMapping, lHsm, queryStart, numInputTokensForReq);
|
||||
}
|
||||
|
||||
private:
|
||||
GlobalTensor<int32_t> gmTargetTokenIds, gmTargetPositions, gmNextTokenIds;
|
||||
GlobalTensor<int32_t> gmQueryStartLoc, gmQueryEndLoc;
|
||||
GlobalTensor<int32_t> gmOutInputIds, gmOutPositions;
|
||||
GlobalTensor<int8_t> gmOutIsRejectedTokenMask, gmOutIsMaskedTokenMask;
|
||||
GlobalTensor<int32_t> gmOutNewTokenIndices, gmOutHiddenStateMapping;
|
||||
|
||||
uint32_t usedCoreNum, numReqs, reqsPerCore, remainderReqs;
|
||||
int32_t paddingTokenId, parallelDraftingTokenId;
|
||||
uint32_t numPaddingSlotsPerReq, totalInputTokens, totalDraftTokens;
|
||||
uint32_t myStartReq, myNumReqs;
|
||||
|
||||
TPipe pipe;
|
||||
TBuf<QuePosition::VECCALC> qsBuf, qeBuf, ntBuf;
|
||||
TBuf<QuePosition::VECCALC> inputBuf, outIdsBuf, outPosBuf;
|
||||
TBuf<QuePosition::VECCALC> outRejBuf, outMskBuf, ntiBuf, hsmBuf;
|
||||
};
|
||||
|
||||
extern "C" __global__ __aicore__ void copy_and_expand_eagle_inputs(
|
||||
GM_ADDR targetTokenIds, GM_ADDR targetPositions,
|
||||
GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
|
||||
GM_ADDR queryEndLoc,
|
||||
GM_ADDR outInputIds, GM_ADDR outPositions,
|
||||
GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
|
||||
GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
|
||||
GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
GET_TILING_DATA(tilingData, tiling);
|
||||
|
||||
if (GetBlockIdx() >= tilingData.usedCoreNum) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (TILING_KEY_IS(1)) {
|
||||
CopyAndExpandEagleInputsKernel op;
|
||||
op.Init(targetTokenIds, targetPositions, nextTokenIds, queryStartLoc, queryEndLoc,
|
||||
outInputIds, outPositions, outIsRejectedTokenMask, outIsMaskedTokenMask,
|
||||
outNewTokenIndices, outHiddenStateMapping, &tilingData);
|
||||
|
||||
if (tilingData.shiftInputIds == 0) {
|
||||
op.ProcessShiftFalse();
|
||||
} else {
|
||||
op.ProcessShiftTrue();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -597,6 +597,41 @@ void transpose_kv_cache_by_block(
|
||||
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
npu_copy_and_expand_eagle_inputs(
|
||||
const at::Tensor &target_token_ids,
|
||||
const at::Tensor &target_positions,
|
||||
const at::Tensor &next_token_ids,
|
||||
const at::Tensor &query_start_loc,
|
||||
const at::Tensor &query_end_loc,
|
||||
int64_t padding_token_id,
|
||||
int64_t parallel_drafting_token_id,
|
||||
int64_t num_padding_slots_per_request,
|
||||
bool shift_input_ids,
|
||||
int64_t total_draft_tokens)
|
||||
{
|
||||
int64_t total_input_tokens = target_token_ids.size(0);
|
||||
int64_t num_reqs = query_start_loc.size(0) - 1;
|
||||
|
||||
auto device = target_token_ids.device();
|
||||
at::Tensor out_input_ids = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
|
||||
at::Tensor out_positions = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
|
||||
at::Tensor out_is_rejected_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
|
||||
at::Tensor out_is_masked_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
|
||||
at::Tensor out_new_token_indices = at::empty({num_reqs * num_padding_slots_per_request}, at::dtype(at::kInt).device(device));
|
||||
at::Tensor out_hidden_state_mapping = at::empty({total_input_tokens}, at::dtype(at::kInt).device(device));
|
||||
|
||||
EXEC_NPU_CMD(aclnnCopyAndExpandEagleInputs,
|
||||
target_token_ids, target_positions, next_token_ids, query_start_loc, query_end_loc,
|
||||
padding_token_id, parallel_drafting_token_id, num_padding_slots_per_request,
|
||||
shift_input_ids, total_input_tokens,
|
||||
out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
|
||||
out_new_token_indices, out_hidden_state_mapping);
|
||||
|
||||
return {out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
|
||||
out_new_token_indices, out_hidden_state_mapping};
|
||||
}
|
||||
|
||||
at::Tensor causal_conv1d_fn(
|
||||
const at::Tensor& mixed_qkv_non_spec_T,
|
||||
const at::Tensor& conv_weights,
|
||||
@@ -849,6 +884,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
|
||||
);
|
||||
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
|
||||
|
||||
ops.def(
|
||||
"npu_copy_and_expand_eagle_inputs(Tensor target_token_ids, Tensor target_positions, "
|
||||
"Tensor next_token_ids, Tensor query_start_loc, Tensor query_end_loc, "
|
||||
"int padding_token_id, int parallel_drafting_token_id, int num_padding_slots_per_request, "
|
||||
"bool shift_input_ids, int total_draft_tokens) -> "
|
||||
"(Tensor out_input_ids, Tensor out_positions, Tensor out_is_rejected_token_mask, "
|
||||
"Tensor out_is_masked_token_mask, Tensor out_new_token_indices, Tensor out_hidden_state_mapping)"
|
||||
);
|
||||
ops.impl("npu_copy_and_expand_eagle_inputs", torch::kPrivateUse1, &vllm_ascend::npu_copy_and_expand_eagle_inputs);
|
||||
// causal_conv1d_fn
|
||||
ops.def(
|
||||
"causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "
|
||||
|
||||
@@ -458,6 +458,33 @@ void transpose_kv_cache_by_block_meta(
|
||||
return;
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
npu_copy_and_expand_eagle_inputs_meta(
|
||||
const at::Tensor &target_token_ids,
|
||||
const at::Tensor &target_positions,
|
||||
const at::Tensor &next_token_ids,
|
||||
const at::Tensor &query_start_loc,
|
||||
const at::Tensor &query_end_loc,
|
||||
int64_t padding_token_id,
|
||||
int64_t parallel_drafting_token_id,
|
||||
int64_t num_padding_slots_per_request,
|
||||
bool shift_input_ids,
|
||||
int64_t total_draft_tokens)
|
||||
{
|
||||
int64_t total_input_tokens = target_token_ids.size(0);
|
||||
int64_t num_reqs = query_start_loc.size(0) - 1;
|
||||
|
||||
at::Tensor out_input_ids = at::empty({total_draft_tokens}, target_token_ids.options());
|
||||
at::Tensor out_positions = at::empty({total_draft_tokens}, target_token_ids.options());
|
||||
at::Tensor out_is_rejected_token_mask = at::empty({total_draft_tokens}, target_token_ids.options().dtype(at::kChar));
|
||||
at::Tensor out_is_masked_token_mask = at::empty({total_draft_tokens}, target_token_ids.options().dtype(at::kChar));
|
||||
at::Tensor out_new_token_indices = at::empty({num_reqs * num_padding_slots_per_request}, target_token_ids.options());
|
||||
at::Tensor out_hidden_state_mapping = at::empty({total_input_tokens}, target_token_ids.options());
|
||||
|
||||
return {out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
|
||||
out_new_token_indices, out_hidden_state_mapping};
|
||||
}
|
||||
|
||||
at::Tensor causal_conv1d_fn_meta(
|
||||
const at::Tensor& mixed_qkv_non_spec_T,
|
||||
const at::Tensor& conv_weights,
|
||||
@@ -543,6 +570,8 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
|
||||
// transpose_kv_cache_by_block
|
||||
ops.impl("transpose_kv_cache_by_block", &vllm_ascend::meta::transpose_kv_cache_by_block_meta);
|
||||
// CopyAndExpandEagleInputs
|
||||
ops.impl("npu_copy_and_expand_eagle_inputs", &vllm_ascend::meta::npu_copy_and_expand_eagle_inputs_meta);
|
||||
// causal_conv1d_fn
|
||||
ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
|
||||
// moe_grouped_matmul
|
||||
|
||||
Reference in New Issue
Block a user