[Performance]: Custom AscendC Kernel of Multi-Step Prepare Input (#814)
### What this PR does / why we need it? - According to https://github.com/vllm-project/vllm-ascend/issues/807, we pull request for customer ascendc kernel of multi-step. - also a bug we found in multi_step_runner.py is fixed when we use multi-step on V0 Engine. ### Does this PR introduce _any_ user-facing change? no user-facing change ### How was this patch tested? we add Unit Test file and offline inference file to test the custom ascendc kernel. See test/ops/test_multi_step.py and examples/offline_multi_step.py --------- Signed-off-by: wan_danfeng <wonderful199082@126.com>
This commit is contained in:
2
.github/workflows/codespell.yml
vendored
2
.github/workflows/codespell.yml
vendored
@@ -42,6 +42,6 @@ jobs:
|
|||||||
- name: Run codespell check
|
- name: Run codespell check
|
||||||
run: |
|
run: |
|
||||||
CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
|
CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
|
||||||
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue')
|
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn')
|
||||||
|
|
||||||
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
|
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ endif()
|
|||||||
|
|
||||||
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
|
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
|
||||||
file(GLOB KERNEL_FILES
|
file(GLOB KERNEL_FILES
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/pos_encoding_kernels.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp)
|
||||||
|
|
||||||
ascendc_library(vllm_ascend_kernels SHARED
|
ascendc_library(vllm_ascend_kernels SHARED
|
||||||
${KERNEL_FILES}
|
${KERNEL_FILES}
|
||||||
|
|||||||
241
csrc/kernels/advance_step.cpp
Normal file
241
csrc/kernels/advance_step.cpp
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) China Merchants Bank Co., Ltd. 2025. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "kernel_operator.h"
|
||||||
|
constexpr int32_t BUFFER_NUM = 1;
|
||||||
|
class KernelAdvanceStep{
|
||||||
|
public:
|
||||||
|
__aicore__ inline KernelAdvanceStep() {}
|
||||||
|
__aicore__ inline void Init(int32_t tasks_per_core,
|
||||||
|
int32_t num_queries,
|
||||||
|
__gm__ int64_t* input_tokens_ptr,
|
||||||
|
__gm__ int64_t* sampled_token_ids_ptr,
|
||||||
|
__gm__ int64_t* input_positions_ptr,
|
||||||
|
__gm__ int32_t* seq_lens_ptr,
|
||||||
|
__gm__ int32_t* slot_mapping_ptr)
|
||||||
|
{
|
||||||
|
this->tasks_per_core = tasks_per_core;
|
||||||
|
|
||||||
|
this->start_id = this->tasks_per_core * AscendC::GetBlockIdx();
|
||||||
|
this->end_id = this->tasks_per_core * (AscendC::GetBlockIdx() + 1) - 1;
|
||||||
|
|
||||||
|
// actual task nums of each core
|
||||||
|
this->actual_task_per_core = tasks_per_core;
|
||||||
|
if(this->end_id >= num_queries) {
|
||||||
|
this->actual_task_per_core = num_queries - this->start_id;
|
||||||
|
this->end_id = num_queries - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t offset_this_core = this->tasks_per_core * AscendC::GetBlockIdx();
|
||||||
|
|
||||||
|
// init outQues
|
||||||
|
pipe.InitBuffer(outQueInputTokens, BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t));
|
||||||
|
pipe.InitBuffer(outQueInputPos, BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t));
|
||||||
|
pipe.InitBuffer(outQueSeqLen, BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t));
|
||||||
|
pipe.InitBuffer(outQueSlotMapping, BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t));
|
||||||
|
|
||||||
|
// init inQues
|
||||||
|
pipe.InitBuffer(inQueSeqLen,BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t));
|
||||||
|
pipe.InitBuffer(inQueSampledTokenIds,BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t));
|
||||||
|
|
||||||
|
// init GlobalMemory
|
||||||
|
inputTokensGm.SetGlobalBuffer((__gm__ int64_t *)input_tokens_ptr + offset_this_core, this->actual_task_per_core);
|
||||||
|
sampledTokenIdsGm.SetGlobalBuffer((__gm__ int64_t *)sampled_token_ids_ptr + offset_this_core, this->actual_task_per_core);
|
||||||
|
inputPositionsGm.SetGlobalBuffer((__gm__ int64_t *)input_positions_ptr + offset_this_core, this->actual_task_per_core);
|
||||||
|
seqLensGm.SetGlobalBuffer((__gm__ int32_t *)seq_lens_ptr + offset_this_core, this->actual_task_per_core);
|
||||||
|
slotMappingGm.SetGlobalBuffer((__gm__ int32_t *)slot_mapping_ptr + offset_this_core, this->actual_task_per_core);
|
||||||
|
}
|
||||||
|
__aicore__ inline void Process(int64_t block_size, __gm__ int32_t* block_tables_ptr, int64_t block_tables_stride)
|
||||||
|
{
|
||||||
|
// no need for tilling or pipeline parallel within each core, as the amount of data processed is very small
|
||||||
|
CopyIn();
|
||||||
|
Update(block_size, block_tables_ptr, block_tables_stride);
|
||||||
|
CopyOut();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
__aicore__ inline void CopyIn()
|
||||||
|
{
|
||||||
|
AscendC::LocalTensor<int32_t> seqLenLocalIn = inQueSeqLen.AllocTensor<int32_t>();
|
||||||
|
AscendC::LocalTensor<int64_t> sampledTokenIdsLocal = inQueSampledTokenIds.AllocTensor<int64_t>();
|
||||||
|
|
||||||
|
AscendC::DataCopyExtParams copyParams32{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int32_t)), 0, 0, 0}; // blockLen = tasks_per_core * 32 / 8 个字节(int32为4字节)
|
||||||
|
AscendC::DataCopyExtParams copyParams64{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int64_t)), 0, 0, 0}; // blockLen = tasks_per_core * 64 / 8 个字节(int64为8字节)
|
||||||
|
|
||||||
|
// calculate the nums that need padded
|
||||||
|
// so that the total length becomes a multiple of 32 bytes which is a requirement of DataCopy Function.
|
||||||
|
uint8_t remainNum32 =this->actual_task_per_core * sizeof(int32_t) % 32;
|
||||||
|
uint8_t needPadElements32 = remainNum32 == 0 ? remainNum32 : (32 - remainNum32) / sizeof(int32_t);
|
||||||
|
|
||||||
|
AscendC::DataCopyPadExtParams<int32_t> padParams32{true, 0, needPadElements32, 0};
|
||||||
|
|
||||||
|
// calculate the nums that need padded
|
||||||
|
// so that the total length becomes a multiple of 32 bytes which is a requirement of DataCopy Function.
|
||||||
|
uint8_t remainNum64 =this->actual_task_per_core * sizeof(int64_t) % 32;
|
||||||
|
uint8_t needPadElements64 = remainNum64 == 0 ? remainNum64 : (32 - remainNum64) / sizeof(int64_t);
|
||||||
|
AscendC::DataCopyPadExtParams<int64_t> padParams64{true, 0, needPadElements64, 0};
|
||||||
|
|
||||||
|
AscendC::DataCopyPad(seqLenLocalIn, seqLensGm, copyParams32, padParams32);
|
||||||
|
AscendC::DataCopyPad(sampledTokenIdsLocal, sampledTokenIdsGm, copyParams64, padParams64);
|
||||||
|
|
||||||
|
inQueSeqLen.EnQue(seqLenLocalIn);
|
||||||
|
inQueSampledTokenIds.EnQue(sampledTokenIdsLocal);
|
||||||
|
}
|
||||||
|
__aicore__ inline void Update(int64_t block_size, __gm__ int32_t* block_tables_ptr, int64_t block_tables_stride)
|
||||||
|
{
|
||||||
|
// input
|
||||||
|
AscendC::LocalTensor<int32_t> seqLenLocalIn = inQueSeqLen.DeQue<int32_t>();
|
||||||
|
AscendC::LocalTensor<int64_t> sampledTokenIdsLocal = inQueSampledTokenIds.DeQue<int64_t>();
|
||||||
|
|
||||||
|
// output
|
||||||
|
AscendC::LocalTensor<int64_t> inputTokensLocal = outQueInputTokens.AllocTensor<int64_t>();
|
||||||
|
AscendC::LocalTensor<int64_t> inputPosLocal = outQueInputPos.AllocTensor<int64_t>();
|
||||||
|
AscendC::LocalTensor<int32_t> seqLenLocalOut = outQueSeqLen.AllocTensor<int32_t>();
|
||||||
|
AscendC::LocalTensor<int32_t> slotMappingLocal = outQueSlotMapping.AllocTensor<int32_t>();
|
||||||
|
|
||||||
|
auto unary_params = AscendC::UnaryRepeatParams(1, 1, 8, 8);
|
||||||
|
|
||||||
|
//Use "for" instead of AscendC::Adds function because AscendC::Adds does not work
|
||||||
|
//when srcLocalMemory has different datatype from dstLocalMemory
|
||||||
|
for(int i=0; i < this->actual_task_per_core; i++) {
|
||||||
|
inputTokensLocal.SetValue(i, sampledTokenIdsLocal.GetValue(i));
|
||||||
|
inputPosLocal.SetValue(i, seqLenLocalIn.GetValue(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
AscendC::Adds<int32_t, false>(seqLenLocalOut, seqLenLocalIn, 1, (uint64_t)0, 1, unary_params);
|
||||||
|
|
||||||
|
// Gather blockTables with dim=1, block_index. No Ascend Function available, use "for" instead.
|
||||||
|
for(int cur_query_id = this->start_id, i = 0; i < this->actual_task_per_core; cur_query_id++, i++) {
|
||||||
|
__gm__ int32_t const* seq_block_tables_ptr = block_tables_ptr + block_tables_stride * cur_query_id;
|
||||||
|
|
||||||
|
int block_index = inputPosLocal.GetValue(i) / block_size;
|
||||||
|
int block_offset = inputPosLocal.GetValue(i) % block_size;
|
||||||
|
|
||||||
|
int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
|
||||||
|
// Update slot_mapping
|
||||||
|
slotMappingLocal.SetValue(i,slot_num);
|
||||||
|
}
|
||||||
|
|
||||||
|
outQueInputTokens.EnQue(inputTokensLocal);
|
||||||
|
outQueInputPos.EnQue(inputPosLocal);
|
||||||
|
outQueSeqLen.EnQue(seqLenLocalOut);
|
||||||
|
outQueSlotMapping.EnQue(slotMappingLocal);
|
||||||
|
|
||||||
|
inQueSampledTokenIds.FreeTensor(sampledTokenIdsLocal);
|
||||||
|
inQueSeqLen.FreeTensor(seqLenLocalIn);
|
||||||
|
|
||||||
|
}
|
||||||
|
__aicore__ inline void CopyOut()
|
||||||
|
{
|
||||||
|
AscendC::DataCopyExtParams copyParams32{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int32_t)),0,0,0};
|
||||||
|
AscendC::DataCopyExtParams copyParams64{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int64_t)),0,0,0};
|
||||||
|
|
||||||
|
AscendC::LocalTensor<int64_t> inputTokensLocal = outQueInputTokens.DeQue<int64_t>();
|
||||||
|
AscendC::DataCopyPad(inputTokensGm, inputTokensLocal, copyParams64);
|
||||||
|
outQueInputTokens.FreeTensor(inputTokensLocal);
|
||||||
|
|
||||||
|
AscendC::LocalTensor<int64_t> inputPosLocal = outQueInputPos.DeQue<int64_t>();
|
||||||
|
AscendC::DataCopyPad(inputPositionsGm, inputPosLocal, copyParams64);
|
||||||
|
outQueInputPos.FreeTensor(inputPosLocal);
|
||||||
|
|
||||||
|
AscendC::LocalTensor<int32_t> seqLenLocalOut = outQueSeqLen.DeQue<int32_t>();
|
||||||
|
AscendC::DataCopyPad(seqLensGm, seqLenLocalOut, copyParams32);
|
||||||
|
outQueSeqLen.FreeTensor(seqLenLocalOut);
|
||||||
|
|
||||||
|
AscendC::LocalTensor<int32_t> slotMappingLocal = outQueSlotMapping.DeQue<int32_t>();
|
||||||
|
AscendC::DataCopyPad(slotMappingGm, slotMappingLocal, copyParams32);
|
||||||
|
outQueSlotMapping.FreeTensor(slotMappingLocal);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AscendC::TPipe pipe;
|
||||||
|
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueInputTokens, outQueInputPos,
|
||||||
|
outQueSeqLen, outQueSlotMapping;
|
||||||
|
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueSeqLen,
|
||||||
|
inQueSampledTokenIds,
|
||||||
|
inQueBlockTables;
|
||||||
|
|
||||||
|
AscendC::GlobalTensor<int64_t> inputTokensGm, sampledTokenIdsGm, inputPositionsGm ;
|
||||||
|
|
||||||
|
AscendC::GlobalTensor<int32_t> seqLensGm, slotMappingGm, blockTablesGm;
|
||||||
|
|
||||||
|
int32_t tasks_per_core, start_id, end_id, actual_task_per_core;
|
||||||
|
};
|
||||||
|
|
||||||
|
extern "C" __global__ __aicore__ void AdvanceStepFlashAttnKernel(
|
||||||
|
int64_t num_seqs,
|
||||||
|
int64_t num_queries,
|
||||||
|
int64_t block_size,
|
||||||
|
__gm__ int64_t* input_tokens_ptr,
|
||||||
|
__gm__ int64_t* sampled_token_ids_ptr,
|
||||||
|
__gm__ int64_t* input_positions_ptr,
|
||||||
|
__gm__ int32_t* seq_lens_ptr,
|
||||||
|
__gm__ int32_t* slot_mapping_ptr,
|
||||||
|
__gm__ int32_t* block_tables_ptr,
|
||||||
|
int64_t block_tables_stride,
|
||||||
|
int32_t tasks_per_core
|
||||||
|
)
|
||||||
|
{
|
||||||
|
int start_id = tasks_per_core * AscendC::GetBlockIdx();
|
||||||
|
// no task for this core.
|
||||||
|
if(start_id >= num_queries) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
KernelAdvanceStep advanceStep;
|
||||||
|
advanceStep.Init(tasks_per_core, num_queries, input_tokens_ptr, sampled_token_ids_ptr, input_positions_ptr, seq_lens_ptr, slot_mapping_ptr);
|
||||||
|
advanceStep.Process(block_size,block_tables_ptr,block_tables_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace vllm_ascend
|
||||||
|
{
|
||||||
|
|
||||||
|
extern void launch_advance_step_flashattn(
|
||||||
|
void* stream,
|
||||||
|
int64_t num_seqs,
|
||||||
|
int64_t num_queries,
|
||||||
|
int64_t block_size,
|
||||||
|
int64_t* input_tokens_ptr,
|
||||||
|
int64_t* sampled_token_ids_ptr,
|
||||||
|
int64_t* input_positions_ptr,
|
||||||
|
int32_t* seq_lens_ptr,
|
||||||
|
int32_t* slot_mapping_ptr,
|
||||||
|
int32_t* block_tables_ptr,
|
||||||
|
int64_t block_tables_stride)
|
||||||
|
{
|
||||||
|
int32_t num_cores = 20;
|
||||||
|
|
||||||
|
if(num_cores > num_queries) {
|
||||||
|
num_cores = num_queries;
|
||||||
|
}
|
||||||
|
|
||||||
|
// task num processed of each core
|
||||||
|
int32_t tasks_per_core = (num_queries + num_cores - 1) / num_cores;
|
||||||
|
|
||||||
|
AdvanceStepFlashAttnKernel<<<num_cores, nullptr, stream>>>(
|
||||||
|
num_seqs,
|
||||||
|
num_queries,
|
||||||
|
block_size,
|
||||||
|
input_tokens_ptr,
|
||||||
|
sampled_token_ids_ptr,
|
||||||
|
input_positions_ptr,
|
||||||
|
seq_lens_ptr,
|
||||||
|
slot_mapping_ptr,
|
||||||
|
block_tables_ptr,
|
||||||
|
block_tables_stride,
|
||||||
|
tasks_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
12
csrc/ops.h
12
csrc/ops.h
@@ -46,4 +46,16 @@ namespace vllm_ascend {
|
|||||||
auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
|
auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
|
||||||
return new_tensor;
|
return new_tensor;
|
||||||
}
|
}
|
||||||
|
extern void launch_advance_step_flashattn(
|
||||||
|
void* stream,
|
||||||
|
int64_t num_seqs,
|
||||||
|
int64_t num_queries,
|
||||||
|
int64_t block_size,
|
||||||
|
int64_t* input_tokens_ptr,
|
||||||
|
int64_t* sampled_token_ids_ptr,
|
||||||
|
int64_t* input_positions_ptr,
|
||||||
|
int32_t* seq_lens_ptr,
|
||||||
|
int32_t* slot_mapping_ptr,
|
||||||
|
int32_t* block_tables_ptr,
|
||||||
|
int64_t block_tables_stride);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,6 +98,87 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
|
|||||||
cmd.Run();
|
cmd.Run();
|
||||||
return {query_dst, key_dst};
|
return {query_dst, key_dst};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void verify_tensor(std::string const& name, at::Tensor const& t,
|
||||||
|
int64_t const size_0, int64_t const size_1,
|
||||||
|
c10::ScalarType const type) {
|
||||||
|
bool size_0_cond = true;
|
||||||
|
if (size_0 != -1) {
|
||||||
|
size_0_cond = t.size(0) == size_0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool size_1_cond = true;
|
||||||
|
if (size_1 != -1) {
|
||||||
|
size_1_cond = t.size(1) == size_1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_contiguous = t.is_contiguous();
|
||||||
|
bool same_type = t.dtype() == type;
|
||||||
|
|
||||||
|
bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
|
||||||
|
if (!pass) {
|
||||||
|
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
|
||||||
|
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
|
||||||
|
" is not as expected: shape = [", size_0, ", ", size_1,
|
||||||
|
"], type = ", type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void advance_step_flashattn_ascendc(
|
||||||
|
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||||
|
at::Tensor& input_tokens,
|
||||||
|
at::Tensor& sampled_token_ids,
|
||||||
|
at::Tensor& input_positions,
|
||||||
|
at::Tensor& seq_lens,
|
||||||
|
at::Tensor& slot_mapping,
|
||||||
|
at::Tensor& block_tables
|
||||||
|
){
|
||||||
|
// Verify all tensors
|
||||||
|
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
||||||
|
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,at::kLong);
|
||||||
|
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
||||||
|
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
||||||
|
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kInt);
|
||||||
|
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
||||||
|
|
||||||
|
|
||||||
|
int64_t* input_tokens_ptr = input_tokens.data_ptr<int64_t>();
|
||||||
|
int64_t* sampled_token_ids_ptr = sampled_token_ids.data_ptr<int64_t>();
|
||||||
|
int64_t* input_positions_ptr = input_positions.data_ptr<int64_t>();
|
||||||
|
int32_t* seq_lens_ptr = seq_lens.data_ptr<int32_t>();
|
||||||
|
int32_t* slot_mapping_ptr = slot_mapping.data_ptr<int32_t>();
|
||||||
|
int32_t* block_tables_ptr = block_tables.data_ptr<int32_t>();
|
||||||
|
|
||||||
|
|
||||||
|
int32_t device_id;
|
||||||
|
aclrtGetDevice(&device_id);
|
||||||
|
auto npu_stream = c10_npu::getCurrentNPUStream(device_id);
|
||||||
|
aclrtStream stream = npu_stream.stream();
|
||||||
|
|
||||||
|
// aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||||
|
at_npu::native::OpCommand cmd;
|
||||||
|
cmd.Name("advance_step_flashattn_ascendc");
|
||||||
|
cmd.SetCustomHandler([stream, num_seqs, num_queries,
|
||||||
|
block_size, input_tokens_ptr, sampled_token_ids_ptr,
|
||||||
|
input_positions_ptr, seq_lens_ptr, slot_mapping_ptr,
|
||||||
|
block_tables_ptr, block_tables]() -> int {
|
||||||
|
launch_advance_step_flashattn(stream,
|
||||||
|
num_seqs,
|
||||||
|
num_queries,
|
||||||
|
block_size,
|
||||||
|
input_tokens_ptr,
|
||||||
|
sampled_token_ids_ptr,
|
||||||
|
input_positions_ptr,
|
||||||
|
seq_lens_ptr,
|
||||||
|
slot_mapping_ptr,
|
||||||
|
block_tables_ptr,
|
||||||
|
block_tables.stride(0));
|
||||||
|
return 0;
|
||||||
|
});
|
||||||
|
cmd.Run();
|
||||||
|
return ;
|
||||||
|
}
|
||||||
} // namespace vllm_ascend
|
} // namespace vllm_ascend
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(_C, ops)
|
TORCH_LIBRARY_EXPAND(_C, ops)
|
||||||
@@ -113,6 +194,11 @@ TORCH_LIBRARY_EXPAND(_C, ops)
|
|||||||
" Tensor! key, int head_size,"
|
" Tensor! key, int head_size,"
|
||||||
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
|
" Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
|
||||||
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
|
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
|
||||||
|
ops.def(
|
||||||
|
"advance_step_flashattn_ascendc(int num_seqs, int num_queries, int block_size,"
|
||||||
|
" Tensor! input_tokens, Tensor! sampled_token_ids, Tensor! input_positions,"
|
||||||
|
" Tensor! seq_lens, Tensor! slot_mapping, Tensor! block_tables) -> ()");
|
||||||
|
ops.impl("advance_step_flashattn_ascendc", torch::kPrivateUse1, &vllm_ascend::advance_step_flashattn_ascendc);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(_C)
|
REGISTER_EXTENSION(_C)
|
||||||
|
|||||||
53
examples/offline_multi_step_custom_ops.py
Normal file
53
examples/offline_multi_step_custom_ops.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 China Merchants Bank Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
import vllm_ascend.platform as pf
|
||||||
|
|
||||||
|
pf.CUSTOM_OP_ENABLED = True # set True for custom Ops of Multi-Step.
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
"China is",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a sampling params object.
|
||||||
|
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
|
||||||
|
# Create an LLM.
|
||||||
|
llm = LLM(
|
||||||
|
model="Qwen/Qwen2.5-0.5B",
|
||||||
|
block_size=128,
|
||||||
|
max_model_len=1024, # max length of prompt
|
||||||
|
tensor_parallel_size=1, # number of NPUs to be used
|
||||||
|
max_num_seqs=26, # max batch number
|
||||||
|
enforce_eager=
|
||||||
|
True, # Force PyTorch eager execution to debug intermediate tensors (disables graph optimizations)
|
||||||
|
trust_remote_code=
|
||||||
|
True, # If the model is a cuscd tom model not yet available in the HuggingFace transformers library
|
||||||
|
num_scheduler_steps=8,
|
||||||
|
gpu_memory_utilization=0.5)
|
||||||
|
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
@@ -145,7 +145,7 @@ CODESPELL_EXCLUDES=(
|
|||||||
)
|
)
|
||||||
|
|
||||||
CODESPELL_IGNORE_WORDS=(
|
CODESPELL_IGNORE_WORDS=(
|
||||||
'-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue'
|
'-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn'
|
||||||
)
|
)
|
||||||
|
|
||||||
# check spelling of specified files
|
# check spelling of specified files
|
||||||
|
|||||||
190
tests/ops/test_multi_step.py
Normal file
190
tests/ops/test_multi_step.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
# Copyright (c) China Merchants Bank Co., Ltd. 2025. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#/
|
||||||
|
|
||||||
|
# to run this test, you need to cd to the upper package which is 'tests',
|
||||||
|
# and run with command 'pytest -s ops/test_multi_step.py'
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_npu # noqa: F401
|
||||||
|
|
||||||
|
DTYPES = [torch.int32, torch.int64]
|
||||||
|
DEVICES = [f"npu:{0}"]
|
||||||
|
# Set tolerance to 0 for equals
|
||||||
|
DEFAULT_ATOL = 0
|
||||||
|
DEFAULT_RTOL = 0
|
||||||
|
|
||||||
|
# test custom ops of https://github.com/vllm-project/vllm-ascend/tree/main/csrc/kernels/advance_step.cpp
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_single_generation_multi_step() -> None:
|
||||||
|
input_tokens_data = [2926]
|
||||||
|
input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0')
|
||||||
|
input_tokens_python = torch.tensor(input_tokens_data, device='npu:0')
|
||||||
|
|
||||||
|
sampled_token_ids_data = [[13]]
|
||||||
|
sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0')
|
||||||
|
|
||||||
|
input_positions_data = [5]
|
||||||
|
input_positions_ascendc = torch.tensor(input_positions_data,
|
||||||
|
device='npu:0')
|
||||||
|
input_positions_python = torch.tensor(input_positions_data, device='npu:0')
|
||||||
|
|
||||||
|
seq_lens_data = [6]
|
||||||
|
seq_lens_ascendc = torch.tensor(seq_lens_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
seq_lens_python = torch.tensor(seq_lens_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
slot_mapping_data = [5]
|
||||||
|
slot_mapping_ascendc = torch.tensor(slot_mapping_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
slot_mapping_python = torch.tensor(slot_mapping_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
block_tables_data = [[0]]
|
||||||
|
|
||||||
|
block_tables = torch.tensor(block_tables_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
torch.ops._C.advance_step_flashattn_ascendc(
|
||||||
|
1, 1, 128, input_tokens_ascendc, sampled_token_ids,
|
||||||
|
input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc,
|
||||||
|
block_tables)
|
||||||
|
|
||||||
|
normal(1, 1, 128, input_tokens_python, sampled_token_ids,
|
||||||
|
input_positions_python, seq_lens_python, slot_mapping_python,
|
||||||
|
block_tables)
|
||||||
|
|
||||||
|
# Compare the results.
|
||||||
|
torch.testing.assert_close(input_tokens_ascendc,
|
||||||
|
input_tokens_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
torch.testing.assert_close(input_positions_ascendc,
|
||||||
|
input_positions_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
torch.testing.assert_close(seq_lens_ascendc,
|
||||||
|
seq_lens_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
torch.testing.assert_close(slot_mapping_ascendc,
|
||||||
|
slot_mapping_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_multi_result_generation_multi_step() -> None:
|
||||||
|
input_tokens_data = [2926, 279, 12095, 1588]
|
||||||
|
input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0')
|
||||||
|
input_tokens_python = torch.tensor(input_tokens_data, device='npu:0')
|
||||||
|
|
||||||
|
sampled_token_ids_data = [[13], [1968], [13], [13]]
|
||||||
|
sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0')
|
||||||
|
|
||||||
|
input_positions_data = [5, 7, 5, 5]
|
||||||
|
input_positions_ascendc = torch.tensor(input_positions_data,
|
||||||
|
device='npu:0')
|
||||||
|
input_positions_python = torch.tensor(input_positions_data, device='npu:0')
|
||||||
|
|
||||||
|
seq_lens_data = [6, 8, 6, 6]
|
||||||
|
seq_lens_ascendc = torch.tensor(seq_lens_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
seq_lens_python = torch.tensor(seq_lens_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
slot_mapping_data = [5, 135, 261, 389]
|
||||||
|
slot_mapping_ascendc = torch.tensor(slot_mapping_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
slot_mapping_python = torch.tensor(slot_mapping_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
block_tables_data = [[0], [1], [2], [3]]
|
||||||
|
|
||||||
|
block_tables = torch.tensor(block_tables_data,
|
||||||
|
device='npu:0',
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
torch.ops._C.advance_step_flashattn_ascendc(
|
||||||
|
4, 4, 128, input_tokens_ascendc, sampled_token_ids,
|
||||||
|
input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc,
|
||||||
|
block_tables)
|
||||||
|
|
||||||
|
normal(4, 4, 128, input_tokens_python, sampled_token_ids,
|
||||||
|
input_positions_python, seq_lens_python, slot_mapping_python,
|
||||||
|
block_tables)
|
||||||
|
|
||||||
|
# Compare the results.
|
||||||
|
torch.testing.assert_close(input_tokens_ascendc,
|
||||||
|
input_tokens_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
torch.testing.assert_close(input_positions_ascendc,
|
||||||
|
input_positions_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
torch.testing.assert_close(seq_lens_ascendc,
|
||||||
|
seq_lens_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
torch.testing.assert_close(slot_mapping_ascendc,
|
||||||
|
slot_mapping_python,
|
||||||
|
atol=DEFAULT_ATOL,
|
||||||
|
rtol=DEFAULT_RTOL)
|
||||||
|
|
||||||
|
|
||||||
|
def normal(num_seqs: int, num_queries: int, block_size: int,
|
||||||
|
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
|
||||||
|
input_positions: torch.Tensor, seq_lens_tensor: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor, block_tables: torch.Tensor) -> None:
|
||||||
|
sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1)
|
||||||
|
input_tokens[:num_queries] = sampled_token_ids_list
|
||||||
|
|
||||||
|
# get seq_lens and input_positions
|
||||||
|
seq_lens = seq_lens_tensor[:num_queries]
|
||||||
|
next_seq_lens = seq_lens + 1
|
||||||
|
next_input_pos = next_seq_lens - 1
|
||||||
|
|
||||||
|
# update seq_lens and input_positions
|
||||||
|
seq_lens_tensor[:num_queries] = next_seq_lens
|
||||||
|
input_positions[:num_queries] = next_input_pos # type: ignore
|
||||||
|
|
||||||
|
# get block index and offset
|
||||||
|
block_idx = next_input_pos // block_size
|
||||||
|
block_offset = next_input_pos % block_size
|
||||||
|
|
||||||
|
current_block_table = block_tables.gather(
|
||||||
|
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
||||||
|
slot_num = current_block_table * block_size + block_offset
|
||||||
|
|
||||||
|
# update slot_mapping
|
||||||
|
slot_mapping[:num_queries] = slot_num
|
||||||
@@ -36,6 +36,7 @@ from vllm.config import get_current_vllm_config
|
|||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||||
|
|
||||||
from vllm_ascend.ops.cache import concat_and_cache_mla
|
from vllm_ascend.ops.cache import concat_and_cache_mla
|
||||||
|
from vllm_ascend.platform import CUSTOM_OP_ENABLED
|
||||||
from vllm_ascend.worker.model_runner import (
|
from vllm_ascend.worker.model_runner import (
|
||||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||||
|
|
||||||
@@ -459,36 +460,47 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
for i in range(num_queries):
|
for i in range(num_queries):
|
||||||
self.seq_lens[i] += 1
|
self.seq_lens[i] += 1
|
||||||
self.max_decode_seq_len = max(self.seq_lens)
|
self.max_decode_seq_len = max(self.seq_lens)
|
||||||
|
if CUSTOM_OP_ENABLED:
|
||||||
|
#advance a step on NPU for existing inputs for a multi-step runner if custom ops is enabled
|
||||||
|
torch.ops._C.advance_step_flashattn_ascendc(
|
||||||
|
num_seqs=num_seqs,
|
||||||
|
num_queries=num_queries,
|
||||||
|
block_size=block_size,
|
||||||
|
input_tokens=model_input.input_tokens,
|
||||||
|
sampled_token_ids=sampled_token_ids,
|
||||||
|
input_positions=model_input.input_positions,
|
||||||
|
seq_lens=self.seq_lens_tensor,
|
||||||
|
slot_mapping=self.slot_mapping,
|
||||||
|
block_tables=self.block_tables)
|
||||||
|
else:
|
||||||
|
# use traditional Pytorch method for updating these tensors.
|
||||||
|
# update input_tokens
|
||||||
|
sampled_token_ids_list = sampled_token_ids[:
|
||||||
|
num_queries].squeeze( # type: ignore
|
||||||
|
-1)
|
||||||
|
model_input.input_tokens[:
|
||||||
|
num_queries] = sampled_token_ids_list # type: ignore
|
||||||
|
|
||||||
# TODO optimize these codes using ascendc just like flash attention backend using cuda
|
# get seq_lens and input_positions
|
||||||
|
seq_lens = self.seq_lens_tensor[:num_queries]
|
||||||
|
next_seq_lens = seq_lens + 1
|
||||||
|
next_input_pos = next_seq_lens - 1
|
||||||
|
|
||||||
# update input_tokens
|
# update seq_lens and input_positions
|
||||||
sampled_token_ids_list = sampled_token_ids[:
|
self.seq_lens_tensor[:num_queries] = next_seq_lens
|
||||||
num_queries].squeeze( # type: ignore
|
model_input.input_positions[:
|
||||||
-1)
|
num_queries] = next_input_pos # type: ignore
|
||||||
model_input.input_tokens[:
|
|
||||||
num_queries] = sampled_token_ids_list # type: ignore
|
|
||||||
|
|
||||||
# get seq_lens and input_positions
|
# 计算 block index 和 offset
|
||||||
seq_lens = self.seq_lens_tensor[:num_queries]
|
block_idx = next_input_pos // block_size
|
||||||
next_seq_lens = seq_lens + 1
|
block_offset = next_input_pos % block_size
|
||||||
next_input_pos = next_seq_lens - 1
|
|
||||||
|
|
||||||
# update seq_lens and input_positions
|
current_block_table = self.block_tables.gather(
|
||||||
self.seq_lens_tensor[:num_queries] = next_seq_lens
|
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
||||||
model_input.input_positions[:
|
slot_num = current_block_table * block_size + block_offset
|
||||||
num_queries] = next_input_pos # type: ignore
|
|
||||||
|
|
||||||
# 计算 block index 和 offset
|
# update slot_mapping
|
||||||
block_idx = next_input_pos // block_size
|
self.slot_mapping[:num_queries] = slot_num
|
||||||
block_offset = next_input_pos % block_size
|
|
||||||
|
|
||||||
current_block_table = self.block_tables.gather(
|
|
||||||
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
|
||||||
slot_num = current_block_table * block_size + block_offset
|
|
||||||
|
|
||||||
# update slot_mapping
|
|
||||||
self.slot_mapping[:num_queries] = slot_num
|
|
||||||
|
|
||||||
|
|
||||||
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||||
@@ -749,11 +761,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
kv_cache: shape = [2, num_blocks, block_size,
|
kv_cache: shape = [2, num_blocks, block_size,
|
||||||
num_kv_heads * head_size]
|
num_kv_heads, head_size]
|
||||||
key_cache = [num_blocks, block_size,
|
key_cache = [num_blocks, block_size,
|
||||||
num_kv_heads * head_size]
|
num_kv_heads, head_size]
|
||||||
value_cache = [num_blocks, block_size,
|
value_cache = [num_blocks, block_size,
|
||||||
num_kv_heads * head_size]
|
num_kv_heads, head_size]
|
||||||
attn_metadata: Metadata for attention.
|
attn_metadata: Metadata for attention.
|
||||||
Returns:
|
Returns:
|
||||||
shape = [batch_size, seq_len * num_heads * head_size]
|
shape = [batch_size, seq_len * num_heads * head_size]
|
||||||
|
|||||||
@@ -220,11 +220,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||||
kv_cache: shape = [2, num_blocks, block_size,
|
kv_cache: shape = [2, num_blocks, block_size,
|
||||||
num_kv_heads * head_size]
|
num_kv_heads, head_size]
|
||||||
key_cache = [num_blocks, block_size,
|
key_cache = [num_blocks, block_size,
|
||||||
num_kv_heads * head_size]
|
num_kv_heads, head_size]
|
||||||
value_cache = [num_blocks, block_size,
|
value_cache = [num_blocks, block_size,
|
||||||
num_kv_heads * head_size]
|
num_kv_heads, head_size]
|
||||||
attn_metadata: Metadata for attention.
|
attn_metadata: Metadata for attention.
|
||||||
Returns:
|
Returns:
|
||||||
shape = [batch_size * seq_len, num_heads, head_size]
|
shape = [batch_size * seq_len, num_heads, head_size]
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
|||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||||
from vllm.utils import current_stream
|
|
||||||
from vllm.worker.model_runner_base import (
|
from vllm.worker.model_runner_base import (
|
||||||
_init_attn_metadata_from_tensor_dict,
|
_init_attn_metadata_from_tensor_dict,
|
||||||
_init_frozen_model_input_from_tensor_dict,
|
_init_frozen_model_input_from_tensor_dict,
|
||||||
@@ -23,6 +22,7 @@ from vllm.worker.multi_step_model_runner import (ModelOutput,
|
|||||||
PythonizationCache,
|
PythonizationCache,
|
||||||
StatefulModelInput)
|
StatefulModelInput)
|
||||||
|
|
||||||
|
from vllm_ascend.utils import current_stream
|
||||||
from vllm_ascend.worker.model_runner import (
|
from vllm_ascend.worker.model_runner import (
|
||||||
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user