[feat][spec decode]Unified draft parallel (#6766)
### What this PR does / why we need it?
Implement a unified parallelized speculative decoding in VLLM
Ascend,which can simultaneously support parallel speculative inference
schemes such as Pard, P-Eagle, etc. refer to
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078
### How was this patch tested?
run with parallel drafting script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811 \
--speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
base script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811
benchmark script:
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
--temperature 0 \
--model /model/Llama-3.1-8B-Instruct \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path philschmid/mt-bench \
--num-prompts ${NUM_PROMPTS} \
--max-concurrency ${MAX_CONCURRENCY} \
--seed 1234
test results :
base(without spec decode): TTFT 79.46ms TPOT 26.99ms
output_tokens_throughput 36.75 tok/s
this pr(with parallel drafting): TTFT 72.24ms TPOT 13.45ms
output_tokens_throughput 72.98 tok/s
per-position acceptance(from position 0 to 7):
79.48%、56.93%、40%、27.90%、19.79%、14.25%、10.57%、7.61%.
----------------------------------------------------------------------
run on qwen3 model script :
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811 \
--speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
cc @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
This commit is contained in:
@@ -0,0 +1,386 @@
|
||||
/**
|
||||
* CopyAndExpandEagleInputs 算子 Kernel 实现 (DataCopy 版)
|
||||
*
|
||||
* 多核策略:
|
||||
* 所有 GM 读写通过 DataCopy 完成(不使用 GlobalTensor::SetValue/GetValue 访问 GM)。
|
||||
* UB (LocalTensor) 上使用 SetValue/GetValue 构建数据,再 DataCopy 到 GM。
|
||||
* 对齐处理参考 CANN 内置算子的 DataCopyCustom 模式。
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
|
||||
using namespace AscendC;
|
||||
|
||||
// ONE_BLK_SIZE comes from AscendC namespace (32 bytes per block)
|
||||
|
||||
class CopyAndExpandEagleInputsKernel {
|
||||
public:
|
||||
__aicore__ inline CopyAndExpandEagleInputsKernel() {}
|
||||
|
||||
__aicore__ inline void Init(GM_ADDR targetTokenIds, GM_ADDR targetPositions,
|
||||
GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
|
||||
GM_ADDR queryEndLoc,
|
||||
GM_ADDR outInputIds, GM_ADDR outPositions,
|
||||
GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
|
||||
GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
|
||||
const CopyAndExpandEagleInputsTilingData* tilingData)
|
||||
{
|
||||
usedCoreNum = tilingData->usedCoreNum;
|
||||
numReqs = tilingData->numReqs;
|
||||
reqsPerCore = tilingData->reqsPerCore;
|
||||
remainderReqs = tilingData->remainderReqs;
|
||||
paddingTokenId = tilingData->paddingTokenId;
|
||||
parallelDraftingTokenId = tilingData->parallelDraftingTokenId;
|
||||
numPaddingSlotsPerReq = tilingData->numPaddingSlotsPerReq;
|
||||
totalInputTokens = tilingData->totalInputTokens;
|
||||
totalDraftTokens = tilingData->totalDraftTokens;
|
||||
|
||||
uint32_t coreId = GetBlockIdx();
|
||||
if (coreId < remainderReqs) {
|
||||
myStartReq = coreId * (reqsPerCore + 1);
|
||||
myNumReqs = reqsPerCore + 1;
|
||||
} else {
|
||||
myStartReq = remainderReqs * (reqsPerCore + 1) + (coreId - remainderReqs) * reqsPerCore;
|
||||
myNumReqs = reqsPerCore;
|
||||
}
|
||||
|
||||
// 绑定 GM Tensor
|
||||
gmTargetTokenIds.SetGlobalBuffer((__gm__ int32_t*)targetTokenIds, totalInputTokens);
|
||||
gmTargetPositions.SetGlobalBuffer((__gm__ int32_t*)targetPositions, totalInputTokens);
|
||||
gmNextTokenIds.SetGlobalBuffer((__gm__ int32_t*)nextTokenIds, numReqs);
|
||||
gmQueryStartLoc.SetGlobalBuffer((__gm__ int32_t*)queryStartLoc, numReqs + 1);
|
||||
gmQueryEndLoc.SetGlobalBuffer((__gm__ int32_t*)queryEndLoc, numReqs);
|
||||
gmOutInputIds.SetGlobalBuffer((__gm__ int32_t*)outInputIds, totalDraftTokens);
|
||||
gmOutPositions.SetGlobalBuffer((__gm__ int32_t*)outPositions, totalDraftTokens);
|
||||
gmOutIsRejectedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsRejectedTokenMask, totalDraftTokens);
|
||||
gmOutIsMaskedTokenMask.SetGlobalBuffer((__gm__ int8_t*)outIsMaskedTokenMask, totalDraftTokens);
|
||||
gmOutNewTokenIndices.SetGlobalBuffer((__gm__ int32_t*)outNewTokenIndices, numPaddingSlotsPerReq * numReqs);
|
||||
gmOutHiddenStateMapping.SetGlobalBuffer((__gm__ int32_t*)outHiddenStateMapping, totalInputTokens);
|
||||
|
||||
// 分配 UB 缓冲区 —— 每个 TBuf 的基地址自动 32 字节对齐
|
||||
// 元数据各自独立 TBuf,避免 UB 地址不对齐
|
||||
uint32_t metaAligned = AlignUp((myNumReqs + 1) * sizeof(int32_t), ONE_BLK_SIZE);
|
||||
pipe.InitBuffer(qsBuf, metaAligned);
|
||||
pipe.InitBuffer(qeBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(ntBuf, AlignUp(myNumReqs * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
|
||||
// I/O 缓冲区
|
||||
constexpr uint32_t MAX_PER_REQ = 4096;
|
||||
pipe.InitBuffer(inputBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outIdsBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outPosBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outRejBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(outMskBuf, AlignUp(MAX_PER_REQ * sizeof(int8_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(ntiBuf, AlignUp(64 * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
pipe.InitBuffer(hsmBuf, AlignUp(MAX_PER_REQ * sizeof(int32_t), ONE_BLK_SIZE));
|
||||
|
||||
// DataCopy 元数据到各自 UB
|
||||
if (myNumReqs > 0) {
|
||||
LocalTensor<int32_t> lqs = qsBuf.Get<int32_t>();
|
||||
DataCopyIn(lqs, gmQueryStartLoc, (int32_t)myStartReq, (int32_t)(myNumReqs + 1));
|
||||
|
||||
LocalTensor<int32_t> lqe = qeBuf.Get<int32_t>();
|
||||
DataCopyIn(lqe, gmQueryEndLoc, (int32_t)myStartReq, (int32_t)myNumReqs);
|
||||
|
||||
LocalTensor<int32_t> lnt = ntBuf.Get<int32_t>();
|
||||
DataCopyIn(lnt, gmNextTokenIds, (int32_t)myStartReq, (int32_t)myNumReqs);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ProcessShiftFalse()
|
||||
{
|
||||
for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
|
||||
ProcessOneRequestShiftFalse(myStartReq + rLocal, rLocal);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void ProcessShiftTrue()
|
||||
{
|
||||
for (uint32_t rLocal = 0; rLocal < myNumReqs; rLocal++) {
|
||||
ProcessOneRequestShiftTrue(myStartReq + rLocal, rLocal);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// ============================================================
|
||||
// AlignUp 辅助
|
||||
// ============================================================
|
||||
static __aicore__ inline uint32_t AlignUp(uint32_t x, uint32_t a)
|
||||
{
|
||||
return (x + a - 1) / a * a;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GM → UB: 标准 DataCopy,count 自动 round-up 到 block 对齐
|
||||
// 多读的元素在 UB 中不会被使用,安全无害
|
||||
// ============================================================
|
||||
__aicore__ inline void DataCopyIn(LocalTensor<int32_t>& dst,
|
||||
GlobalTensor<int32_t>& src,
|
||||
int32_t gmOffset, int32_t count)
|
||||
{
|
||||
if (count <= 0) return;
|
||||
constexpr int32_t ELEMS_PER_BLK = ONE_BLK_SIZE / (int32_t)sizeof(int32_t); // 8
|
||||
int32_t aligned = (count + ELEMS_PER_BLK - 1) / ELEMS_PER_BLK * ELEMS_PER_BLK;
|
||||
DataCopy(dst, src[gmOffset], aligned);
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// UB → GM: DataCopyPad + DataCopyExtParams(C220 支持任意字节数)
|
||||
// 精确写入 count 个元素,不越界覆盖相邻数据
|
||||
// ============================================================
|
||||
__aicore__ inline void DataCopyOut_int32(GlobalTensor<int32_t>& dst,
|
||||
LocalTensor<int32_t>& src,
|
||||
int32_t gmOffset, int32_t count)
|
||||
{
|
||||
if (count <= 0) return;
|
||||
uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int32_t));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
|
||||
__aicore__ inline void DataCopyOut_int8(GlobalTensor<int8_t>& dst,
|
||||
LocalTensor<int8_t>& src,
|
||||
int32_t gmOffset, int32_t count)
|
||||
{
|
||||
if (count <= 0) return;
|
||||
uint32_t totalBytes = static_cast<uint32_t>(count) * static_cast<uint32_t>(sizeof(int8_t));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
DataCopyPad(dst[gmOffset], src, DataCopyExtParams(1, totalBytes, 0, 0, 0));
|
||||
pipe_barrier(PIPE_ALL);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 元数据读取 (从各自 UB 缓冲区)
|
||||
// ============================================================
|
||||
__aicore__ inline int32_t ReadQS(uint32_t rLocal) {
|
||||
return qsBuf.Get<int32_t>().GetValue(rLocal);
|
||||
}
|
||||
__aicore__ inline int32_t ReadNextQS(uint32_t rLocal) {
|
||||
return qsBuf.Get<int32_t>().GetValue(rLocal + 1);
|
||||
}
|
||||
__aicore__ inline int32_t ReadQE(uint32_t rLocal) {
|
||||
return qeBuf.Get<int32_t>().GetValue(rLocal);
|
||||
}
|
||||
__aicore__ inline int32_t ReadNT(uint32_t rLocal) {
|
||||
return ntBuf.Get<int32_t>().GetValue(rLocal);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// shift_input_ids = false
|
||||
// ============================================================
|
||||
__aicore__ inline void ProcessOneRequestShiftFalse(uint32_t r, uint32_t rLocal)
|
||||
{
|
||||
int32_t queryStart = ReadQS(rLocal);
|
||||
int32_t nextQueryStart = ReadNextQS(rLocal);
|
||||
int32_t queryEnd = ReadQE(rLocal);
|
||||
|
||||
int32_t numRejected = nextQueryStart - queryEnd - 1;
|
||||
if (numRejected < 0) numRejected = 0;
|
||||
int32_t numValid = queryEnd - queryStart + 1;
|
||||
if (numValid < 0) numValid = 0;
|
||||
|
||||
int32_t outputStart = queryStart + (int32_t)r * (int32_t)numPaddingSlotsPerReq;
|
||||
int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
|
||||
|
||||
// 读取输入 token 到 UB
|
||||
int32_t numInputTokensForReq = nextQueryStart - queryStart;
|
||||
LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
|
||||
if (numInputTokensForReq > 0) {
|
||||
DataCopyIn(localInput, gmTargetTokenIds, queryStart, numInputTokensForReq);
|
||||
}
|
||||
|
||||
// 读取起始 position
|
||||
LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
|
||||
DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
|
||||
int32_t startPos = localTmpPos.GetValue(0);
|
||||
|
||||
int32_t nextTokenId = ReadNT(rLocal);
|
||||
|
||||
// 构建输出到 UB
|
||||
LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
|
||||
LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
|
||||
LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
|
||||
LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
|
||||
|
||||
for (int32_t j = 0; j < numValid; j++) {
|
||||
int32_t inIdx = j;
|
||||
if (inIdx >= numInputTokensForReq) inIdx = numInputTokensForReq - 1;
|
||||
lIds.SetValue(j, localInput.GetValue(inIdx));
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
// Bonus
|
||||
lIds.SetValue(numValid, nextTokenId);
|
||||
lPos.SetValue(numValid, startPos + numValid);
|
||||
lRej.SetValue(numValid, (int8_t)0);
|
||||
lMsk.SetValue(numValid, (int8_t)0);
|
||||
// Parallel Draft
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
int32_t j = numValid + k;
|
||||
lIds.SetValue(j, parallelDraftingTokenId);
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)1);
|
||||
}
|
||||
// Rejected
|
||||
for (int32_t k = 0; k < numRejected; k++) {
|
||||
int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
|
||||
lIds.SetValue(j, paddingTokenId);
|
||||
lPos.SetValue(j, (int32_t)0);
|
||||
lRej.SetValue(j, (int8_t)1);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
|
||||
// UB → GM
|
||||
DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
|
||||
DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
|
||||
|
||||
// NTI
|
||||
LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
|
||||
lNti.SetValue(0, outputStart + numValid);
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
lNti.SetValue(k, outputStart + numValid + k);
|
||||
}
|
||||
int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
|
||||
DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// shift_input_ids = true
|
||||
// ============================================================
|
||||
__aicore__ inline void ProcessOneRequestShiftTrue(uint32_t r, uint32_t rLocal)
|
||||
{
|
||||
int32_t queryStart = ReadQS(rLocal);
|
||||
int32_t nextQueryStart = ReadNextQS(rLocal);
|
||||
int32_t queryEnd = ReadQE(rLocal);
|
||||
|
||||
int32_t numRejected = nextQueryStart - queryEnd - 1;
|
||||
if (numRejected < 0) numRejected = 0;
|
||||
int32_t numValid = queryEnd - queryStart;
|
||||
if (numValid < 0) numValid = 0;
|
||||
|
||||
int32_t outputStart = queryStart + (int32_t)r * ((int32_t)numPaddingSlotsPerReq - 1);
|
||||
int32_t outputLen = numValid + (int32_t)numPaddingSlotsPerReq + numRejected;
|
||||
|
||||
int32_t numInputTokensForReq = nextQueryStart - queryStart;
|
||||
LocalTensor<int32_t> localInput = inputBuf.Get<int32_t>();
|
||||
int32_t readStart = queryStart + 1;
|
||||
int32_t readCount = numValid;
|
||||
if (readStart + readCount > (int32_t)totalInputTokens) {
|
||||
readCount = (int32_t)totalInputTokens - readStart;
|
||||
if (readCount < 0) readCount = 0;
|
||||
}
|
||||
if (readCount > 0) {
|
||||
DataCopyIn(localInput, gmTargetTokenIds, readStart, readCount);
|
||||
}
|
||||
|
||||
LocalTensor<int32_t> localTmpPos = hsmBuf.Get<int32_t>();
|
||||
DataCopyIn(localTmpPos, gmTargetPositions, queryStart, 1);
|
||||
int32_t startPos = localTmpPos.GetValue(0);
|
||||
|
||||
int32_t nextTokenId = ReadNT(rLocal);
|
||||
|
||||
LocalTensor<int32_t> lIds = outIdsBuf.Get<int32_t>();
|
||||
LocalTensor<int32_t> lPos = outPosBuf.Get<int32_t>();
|
||||
LocalTensor<int8_t> lRej = outRejBuf.Get<int8_t>();
|
||||
LocalTensor<int8_t> lMsk = outMskBuf.Get<int8_t>();
|
||||
|
||||
for (int32_t j = 0; j < numValid; j++) {
|
||||
int32_t inIdx = j;
|
||||
if (inIdx >= readCount && readCount > 0) inIdx = readCount - 1;
|
||||
lIds.SetValue(j, readCount > 0 ? localInput.GetValue(inIdx) : (int32_t)0);
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
lIds.SetValue(numValid, nextTokenId);
|
||||
lPos.SetValue(numValid, startPos + numValid);
|
||||
lRej.SetValue(numValid, (int8_t)0);
|
||||
lMsk.SetValue(numValid, (int8_t)0);
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
int32_t j = numValid + k;
|
||||
lIds.SetValue(j, parallelDraftingTokenId);
|
||||
lPos.SetValue(j, startPos + j);
|
||||
lRej.SetValue(j, (int8_t)0);
|
||||
lMsk.SetValue(j, (int8_t)1);
|
||||
}
|
||||
for (int32_t k = 0; k < numRejected; k++) {
|
||||
int32_t j = numValid + (int32_t)numPaddingSlotsPerReq + k;
|
||||
lIds.SetValue(j, paddingTokenId);
|
||||
lPos.SetValue(j, (int32_t)0);
|
||||
lRej.SetValue(j, (int8_t)1);
|
||||
lMsk.SetValue(j, (int8_t)0);
|
||||
}
|
||||
|
||||
DataCopyOut_int32(gmOutInputIds, lIds, outputStart, outputLen);
|
||||
DataCopyOut_int32(gmOutPositions, lPos, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsRejectedTokenMask, lRej, outputStart, outputLen);
|
||||
DataCopyOut_int8(gmOutIsMaskedTokenMask, lMsk, outputStart, outputLen);
|
||||
|
||||
LocalTensor<int32_t> lNti = ntiBuf.Get<int32_t>();
|
||||
lNti.SetValue(0, outputStart + numValid);
|
||||
for (int32_t k = 1; k < (int32_t)numPaddingSlotsPerReq; k++) {
|
||||
lNti.SetValue(k, outputStart + numValid + k);
|
||||
}
|
||||
int32_t ntiOff = (int32_t)r * (int32_t)numPaddingSlotsPerReq;
|
||||
DataCopyOut_int32(gmOutNewTokenIndices, lNti, ntiOff, (int32_t)numPaddingSlotsPerReq);
|
||||
|
||||
// hidden_state_mapping
|
||||
LocalTensor<int32_t> lHsm = hsmBuf.Get<int32_t>();
|
||||
for (int32_t j = 0; j < numInputTokensForReq; j++) {
|
||||
lHsm.SetValue(j, outputStart + j);
|
||||
}
|
||||
DataCopyOut_int32(gmOutHiddenStateMapping, lHsm, queryStart, numInputTokensForReq);
|
||||
}
|
||||
|
||||
private:
|
||||
GlobalTensor<int32_t> gmTargetTokenIds, gmTargetPositions, gmNextTokenIds;
|
||||
GlobalTensor<int32_t> gmQueryStartLoc, gmQueryEndLoc;
|
||||
GlobalTensor<int32_t> gmOutInputIds, gmOutPositions;
|
||||
GlobalTensor<int8_t> gmOutIsRejectedTokenMask, gmOutIsMaskedTokenMask;
|
||||
GlobalTensor<int32_t> gmOutNewTokenIndices, gmOutHiddenStateMapping;
|
||||
|
||||
uint32_t usedCoreNum, numReqs, reqsPerCore, remainderReqs;
|
||||
int32_t paddingTokenId, parallelDraftingTokenId;
|
||||
uint32_t numPaddingSlotsPerReq, totalInputTokens, totalDraftTokens;
|
||||
uint32_t myStartReq, myNumReqs;
|
||||
|
||||
TPipe pipe;
|
||||
TBuf<QuePosition::VECCALC> qsBuf, qeBuf, ntBuf;
|
||||
TBuf<QuePosition::VECCALC> inputBuf, outIdsBuf, outPosBuf;
|
||||
TBuf<QuePosition::VECCALC> outRejBuf, outMskBuf, ntiBuf, hsmBuf;
|
||||
};
|
||||
|
||||
extern "C" __global__ __aicore__ void copy_and_expand_eagle_inputs(
|
||||
GM_ADDR targetTokenIds, GM_ADDR targetPositions,
|
||||
GM_ADDR nextTokenIds, GM_ADDR queryStartLoc,
|
||||
GM_ADDR queryEndLoc,
|
||||
GM_ADDR outInputIds, GM_ADDR outPositions,
|
||||
GM_ADDR outIsRejectedTokenMask, GM_ADDR outIsMaskedTokenMask,
|
||||
GM_ADDR outNewTokenIndices, GM_ADDR outHiddenStateMapping,
|
||||
GM_ADDR workspace, GM_ADDR tiling)
|
||||
{
|
||||
GET_TILING_DATA(tilingData, tiling);
|
||||
|
||||
if (GetBlockIdx() >= tilingData.usedCoreNum) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (TILING_KEY_IS(1)) {
|
||||
CopyAndExpandEagleInputsKernel op;
|
||||
op.Init(targetTokenIds, targetPositions, nextTokenIds, queryStartLoc, queryEndLoc,
|
||||
outInputIds, outPositions, outIsRejectedTokenMask, outIsMaskedTokenMask,
|
||||
outNewTokenIndices, outHiddenStateMapping, &tilingData);
|
||||
|
||||
if (tilingData.shiftInputIds == 0) {
|
||||
op.ProcessShiftFalse();
|
||||
} else {
|
||||
op.ProcessShiftTrue();
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user