diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml index c97fc00..a239d38 100644 --- a/.github/workflows/codespell.yml +++ b/.github/workflows/codespell.yml @@ -42,6 +42,6 @@ jobs: - name: Run codespell check run: | CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**') - CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue') + CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn') codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}" diff --git a/CMakeLists.txt b/CMakeLists.txt index 2db15d9..a2c3ad2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ endif() include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) file(GLOB KERNEL_FILES -${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/pos_encoding_kernels.cpp) +${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp) ascendc_library(vllm_ascend_kernels SHARED ${KERNEL_FILES} diff --git a/csrc/kernels/advance_step.cpp b/csrc/kernels/advance_step.cpp new file mode 100644 index 0000000..87a30bd --- /dev/null +++ b/csrc/kernels/advance_step.cpp @@ -0,0 +1,241 @@ +/* + * Copyright (c) China Merchants Bank Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernel_operator.h" +constexpr int32_t BUFFER_NUM = 1; +class KernelAdvanceStep{ +public: + __aicore__ inline KernelAdvanceStep() {} + __aicore__ inline void Init(int32_t tasks_per_core, + int32_t num_queries, + __gm__ int64_t* input_tokens_ptr, + __gm__ int64_t* sampled_token_ids_ptr, + __gm__ int64_t* input_positions_ptr, + __gm__ int32_t* seq_lens_ptr, + __gm__ int32_t* slot_mapping_ptr) + { + this->tasks_per_core = tasks_per_core; + + this->start_id = this->tasks_per_core * AscendC::GetBlockIdx(); + this->end_id = this->tasks_per_core * (AscendC::GetBlockIdx() + 1) - 1; + + // actual task nums of each core + this->actual_task_per_core = tasks_per_core; + if(this->end_id >= num_queries) { + this->actual_task_per_core = num_queries - this->start_id; + this->end_id = num_queries - 1; + } + + int32_t offset_this_core = this->tasks_per_core * AscendC::GetBlockIdx(); + + // init outQues + pipe.InitBuffer(outQueInputTokens, BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t)); + pipe.InitBuffer(outQueInputPos, BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t)); + pipe.InitBuffer(outQueSeqLen, BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t)); + pipe.InitBuffer(outQueSlotMapping, BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t)); + + // init inQues + pipe.InitBuffer(inQueSeqLen,BUFFER_NUM, this->actual_task_per_core * sizeof(int32_t)); + pipe.InitBuffer(inQueSampledTokenIds,BUFFER_NUM, this->actual_task_per_core * sizeof(int64_t)); + + // init GlobalMemory + inputTokensGm.SetGlobalBuffer((__gm__ int64_t *)input_tokens_ptr + offset_this_core, this->actual_task_per_core); + sampledTokenIdsGm.SetGlobalBuffer((__gm__ int64_t *)sampled_token_ids_ptr + offset_this_core, this->actual_task_per_core); + inputPositionsGm.SetGlobalBuffer((__gm__ int64_t *)input_positions_ptr + offset_this_core, this->actual_task_per_core); + seqLensGm.SetGlobalBuffer((__gm__ int32_t *)seq_lens_ptr + offset_this_core, this->actual_task_per_core); + 
slotMappingGm.SetGlobalBuffer((__gm__ int32_t *)slot_mapping_ptr + offset_this_core, this->actual_task_per_core);
+    }
+    __aicore__ inline void Process(int64_t block_size, __gm__ int32_t* block_tables_ptr, int64_t block_tables_stride)
+    {
+        // no need for tiling or pipeline parallel within each core, as the amount of data processed is very small
+        CopyIn();
+        Update(block_size, block_tables_ptr, block_tables_stride);
+        CopyOut();
+    }
+
+private:
+    __aicore__ inline void CopyIn()
+    {
+        AscendC::LocalTensor<int32_t> seqLenLocalIn = inQueSeqLen.AllocTensor<int32_t>();
+        AscendC::LocalTensor<int64_t> sampledTokenIdsLocal = inQueSampledTokenIds.AllocTensor<int64_t>();
+
+        AscendC::DataCopyExtParams copyParams32{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int32_t)), 0, 0, 0}; // blockLen = tasks_per_core * 32 / 8 bytes (int32 is 4 bytes)
+        AscendC::DataCopyExtParams copyParams64{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int64_t)), 0, 0, 0}; // blockLen = tasks_per_core * 64 / 8 bytes (int64 is 8 bytes)
+
+        // calculate the nums that need padded
+        // so that the total length becomes a multiple of 32 bytes which is a requirement of DataCopy Function.
+        uint8_t remainNum32 = this->actual_task_per_core * sizeof(int32_t) % 32;
+        uint8_t needPadElements32 = remainNum32 == 0 ? remainNum32 : (32 - remainNum32) / sizeof(int32_t);
+
+        AscendC::DataCopyPadExtParams<int32_t> padParams32{true, 0, needPadElements32, 0};
+
+        // calculate the nums that need padded
+        // so that the total length becomes a multiple of 32 bytes which is a requirement of DataCopy Function.
+        uint8_t remainNum64 = this->actual_task_per_core * sizeof(int64_t) % 32;
+        uint8_t needPadElements64 = remainNum64 == 0 ?
remainNum64 : (32 - remainNum64) / sizeof(int64_t);
+        AscendC::DataCopyPadExtParams<int64_t> padParams64{true, 0, needPadElements64, 0};
+
+        AscendC::DataCopyPad(seqLenLocalIn, seqLensGm, copyParams32, padParams32);
+        AscendC::DataCopyPad(sampledTokenIdsLocal, sampledTokenIdsGm, copyParams64, padParams64);
+
+        inQueSeqLen.EnQue(seqLenLocalIn);
+        inQueSampledTokenIds.EnQue(sampledTokenIdsLocal);
+    }
+    __aicore__ inline void Update(int64_t block_size, __gm__ int32_t* block_tables_ptr, int64_t block_tables_stride)
+    {
+        // input
+        AscendC::LocalTensor<int32_t> seqLenLocalIn = inQueSeqLen.DeQue<int32_t>();
+        AscendC::LocalTensor<int64_t> sampledTokenIdsLocal = inQueSampledTokenIds.DeQue<int64_t>();
+
+        // output
+        AscendC::LocalTensor<int64_t> inputTokensLocal = outQueInputTokens.AllocTensor<int64_t>();
+        AscendC::LocalTensor<int64_t> inputPosLocal = outQueInputPos.AllocTensor<int64_t>();
+        AscendC::LocalTensor<int32_t> seqLenLocalOut = outQueSeqLen.AllocTensor<int32_t>();
+        AscendC::LocalTensor<int32_t> slotMappingLocal = outQueSlotMapping.AllocTensor<int32_t>();
+
+        auto unary_params = AscendC::UnaryRepeatParams(1, 1, 8, 8);
+
+        //Use "for" instead of AscendC::Adds function because AscendC::Adds does not work
+        //when srcLocalMemory has different datatype from dstLocalMemory
+        for(int i=0; i < this->actual_task_per_core; i++) {
+            inputTokensLocal.SetValue(i, sampledTokenIdsLocal.GetValue(i));
+            inputPosLocal.SetValue(i, seqLenLocalIn.GetValue(i));
+        }
+
+        AscendC::Adds(seqLenLocalOut, seqLenLocalIn, 1, (uint64_t)0, 1, unary_params);
+
+        // Gather blockTables with dim=1, block_index. No Ascend Function available, use "for" instead.
+        for(int cur_query_id = this->start_id, i = 0; i < this->actual_task_per_core; cur_query_id++, i++) {
+            __gm__ int32_t const* seq_block_tables_ptr = block_tables_ptr + block_tables_stride * cur_query_id;
+
+            int block_index = inputPosLocal.GetValue(i) / block_size;
+            int block_offset = inputPosLocal.GetValue(i) % block_size;
+
+            int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
+            // Update slot_mapping
+            slotMappingLocal.SetValue(i,slot_num);
+        }
+
+        outQueInputTokens.EnQue(inputTokensLocal);
+        outQueInputPos.EnQue(inputPosLocal);
+        outQueSeqLen.EnQue(seqLenLocalOut);
+        outQueSlotMapping.EnQue(slotMappingLocal);
+
+        inQueSampledTokenIds.FreeTensor(sampledTokenIdsLocal);
+        inQueSeqLen.FreeTensor(seqLenLocalIn);
+
+    }
+    __aicore__ inline void CopyOut()
+    {
+        AscendC::DataCopyExtParams copyParams32{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int32_t)),0,0,0};
+        AscendC::DataCopyExtParams copyParams64{1, static_cast<uint32_t>(this->actual_task_per_core * sizeof(int64_t)),0,0,0};
+
+        AscendC::LocalTensor<int64_t> inputTokensLocal = outQueInputTokens.DeQue<int64_t>();
+        AscendC::DataCopyPad(inputTokensGm, inputTokensLocal, copyParams64);
+        outQueInputTokens.FreeTensor(inputTokensLocal);
+
+        AscendC::LocalTensor<int64_t> inputPosLocal = outQueInputPos.DeQue<int64_t>();
+        AscendC::DataCopyPad(inputPositionsGm, inputPosLocal, copyParams64);
+        outQueInputPos.FreeTensor(inputPosLocal);
+
+        AscendC::LocalTensor<int32_t> seqLenLocalOut = outQueSeqLen.DeQue<int32_t>();
+        AscendC::DataCopyPad(seqLensGm, seqLenLocalOut, copyParams32);
+        outQueSeqLen.FreeTensor(seqLenLocalOut);
+
+        AscendC::LocalTensor<int32_t> slotMappingLocal = outQueSlotMapping.DeQue<int32_t>();
+        AscendC::DataCopyPad(slotMappingGm, slotMappingLocal, copyParams32);
+        outQueSlotMapping.FreeTensor(slotMappingLocal);
+    }
+
+private:
+    AscendC::TPipe pipe;
+    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueInputTokens, outQueInputPos,
+        outQueSeqLen, outQueSlotMapping;
+    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueSeqLen,
+        inQueSampledTokenIds,
+        inQueBlockTables;
+
+    AscendC::GlobalTensor<int64_t> inputTokensGm,
sampledTokenIdsGm, inputPositionsGm ;
+
+    AscendC::GlobalTensor<int32_t> seqLensGm, slotMappingGm, blockTablesGm;
+
+    int32_t tasks_per_core, start_id, end_id, actual_task_per_core;
+};
+
+extern "C" __global__ __aicore__ void AdvanceStepFlashAttnKernel(
+    int64_t num_seqs,
+    int64_t num_queries,
+    int64_t block_size,
+    __gm__ int64_t* input_tokens_ptr,
+    __gm__ int64_t* sampled_token_ids_ptr,
+    __gm__ int64_t* input_positions_ptr,
+    __gm__ int32_t* seq_lens_ptr,
+    __gm__ int32_t* slot_mapping_ptr,
+    __gm__ int32_t* block_tables_ptr,
+    int64_t block_tables_stride,
+    int32_t tasks_per_core
+)
+{
+    int start_id = tasks_per_core * AscendC::GetBlockIdx();
+    // no task for this core.
+    if(start_id >= num_queries) {
+        return;
+    }
+    KernelAdvanceStep advanceStep;
+    advanceStep.Init(tasks_per_core, num_queries, input_tokens_ptr, sampled_token_ids_ptr, input_positions_ptr, seq_lens_ptr, slot_mapping_ptr);
+    advanceStep.Process(block_size,block_tables_ptr,block_tables_stride);
+}
+
+namespace vllm_ascend
+{
+
+extern void launch_advance_step_flashattn(
+    void* stream,
+    int64_t num_seqs,
+    int64_t num_queries,
+    int64_t block_size,
+    int64_t* input_tokens_ptr,
+    int64_t* sampled_token_ids_ptr,
+    int64_t* input_positions_ptr,
+    int32_t* seq_lens_ptr,
+    int32_t* slot_mapping_ptr,
+    int32_t* block_tables_ptr,
+    int64_t block_tables_stride)
+{
+    int32_t num_cores = 20;
+
+    if(num_cores > num_queries) {
+        num_cores = num_queries;
+    }
+
+    // task num processed of each core
+    int32_t tasks_per_core = (num_queries + num_cores - 1) / num_cores;
+
+    AdvanceStepFlashAttnKernel<<<num_cores, nullptr, stream>>>(
+        num_seqs,
+        num_queries,
+        block_size,
+        input_tokens_ptr,
+        sampled_token_ids_ptr,
+        input_positions_ptr,
+        seq_lens_ptr,
+        slot_mapping_ptr,
+        block_tables_ptr,
+        block_tables_stride,
+        tasks_per_core);
+}
+
+}
diff --git a/csrc/ops.h b/csrc/ops.h
index aaac630..b921b2b 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -46,4 +46,16 @@ namespace vllm_ascend {
     auto new_tensor = at_npu::native::from_blob(data_ptr, sizes,
strides, options); return new_tensor; } + extern void launch_advance_step_flashattn( + void* stream, + int64_t num_seqs, + int64_t num_queries, + int64_t block_size, + int64_t* input_tokens_ptr, + int64_t* sampled_token_ids_ptr, + int64_t* input_positions_ptr, + int32_t* seq_lens_ptr, + int32_t* slot_mapping_ptr, + int32_t* block_tables_ptr, + int64_t block_tables_stride); } diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 94e1fd6..c415438 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -98,6 +98,87 @@ std::tuple rotary_embedding(at::Tensor &positions, at::T cmd.Run(); return {query_dst, key_dst}; } + +void verify_tensor(std::string const& name, at::Tensor const& t, + int64_t const size_0, int64_t const size_1, + c10::ScalarType const type) { + bool size_0_cond = true; + if (size_0 != -1) { + size_0_cond = t.size(0) == size_0; + } + + bool size_1_cond = true; + if (size_1 != -1) { + size_1_cond = t.size(1) == size_1; + } + + bool is_contiguous = t.is_contiguous(); + bool same_type = t.dtype() == type; + + bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; + if (!pass) { + TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), + " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), + " is not as expected: shape = [", size_0, ", ", size_1, + "], type = ", type); + } +} + + +void advance_step_flashattn_ascendc( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + at::Tensor& input_tokens, + at::Tensor& sampled_token_ids, + at::Tensor& input_positions, + at::Tensor& seq_lens, + at::Tensor& slot_mapping, + at::Tensor& block_tables +){ + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", 
slot_mapping, num_seqs, -1, at::kInt);
+    verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+
+    int64_t* input_tokens_ptr = input_tokens.data_ptr<int64_t>();
+    int64_t* sampled_token_ids_ptr = sampled_token_ids.data_ptr<int64_t>();
+    int64_t* input_positions_ptr = input_positions.data_ptr<int64_t>();
+    int32_t* seq_lens_ptr = seq_lens.data_ptr<int32_t>();
+    int32_t* slot_mapping_ptr = slot_mapping.data_ptr<int32_t>();
+    int32_t* block_tables_ptr = block_tables.data_ptr<int32_t>();
+
+
+    int32_t device_id;
+    aclrtGetDevice(&device_id);
+    auto npu_stream = c10_npu::getCurrentNPUStream(device_id);
+    aclrtStream stream = npu_stream.stream();
+
+    // aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+    at_npu::native::OpCommand cmd;
+    cmd.Name("advance_step_flashattn_ascendc");
+    cmd.SetCustomHandler([stream, num_seqs, num_queries,
+                          block_size, input_tokens_ptr, sampled_token_ids_ptr,
+                          input_positions_ptr, seq_lens_ptr, slot_mapping_ptr,
+                          block_tables_ptr, block_tables]() -> int {
+        launch_advance_step_flashattn(stream,
+                                      num_seqs,
+                                      num_queries,
+                                      block_size,
+                                      input_tokens_ptr,
+                                      sampled_token_ids_ptr,
+                                      input_positions_ptr,
+                                      seq_lens_ptr,
+                                      slot_mapping_ptr,
+                                      block_tables_ptr,
+                                      block_tables.stride(0));
+        return 0;
+    });
+    cmd.Run();
+    return ;
+}
 } // namespace vllm_ascend
 
 TORCH_LIBRARY_EXPAND(_C, ops)
@@ -113,6 +194,11 @@
       " Tensor! key, int head_size,"
       " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
   ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+  ops.def(
+      "advance_step_flashattn_ascendc(int num_seqs, int num_queries, int block_size,"
+      " Tensor! input_tokens, Tensor! sampled_token_ids, Tensor! input_positions,"
+      " Tensor! seq_lens, Tensor! slot_mapping, Tensor!
block_tables) -> ()"); + ops.impl("advance_step_flashattn_ascendc", torch::kPrivateUse1, &vllm_ascend::advance_step_flashattn_ascendc); } REGISTER_EXTENSION(_C) diff --git a/examples/offline_multi_step_custom_ops.py b/examples/offline_multi_step_custom_ops.py new file mode 100644 index 0000000..82a1bf5 --- /dev/null +++ b/examples/offline_multi_step_custom_ops.py @@ -0,0 +1,53 @@ +# +# Copyright (c) 2025 China Merchants Bank Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/examples/offline_inference/basic.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vllm import LLM, SamplingParams + +import vllm_ascend.platform as pf + +pf.CUSTOM_OP_ENABLED = True # set True for custom Ops of Multi-Step. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "China is", +] + +# Create a sampling params object. +sampling_params = SamplingParams(max_tokens=100, temperature=0.0) +# Create an LLM. 
+llm = LLM(
+    model="Qwen/Qwen2.5-0.5B",
+    block_size=128,
+    max_model_len=1024,  # max length of prompt
+    tensor_parallel_size=1,  # number of NPUs to be used
+    max_num_seqs=26,  # max batch number
+    enforce_eager=
+    True,  # Force PyTorch eager execution to debug intermediate tensors (disables graph optimizations)
+    trust_remote_code=
+    True,  # If the model is a custom model not yet available in the HuggingFace transformers library
+    num_scheduler_steps=8,
+    gpu_memory_utilization=0.5)
+
+outputs = llm.generate(prompts, sampling_params)
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/format.sh b/format.sh
index 595bf2f..608c700 100755
--- a/format.sh
+++ b/format.sh
@@ -145,7 +145,7 @@ CODESPELL_EXCLUDES=(
 )
 
 CODESPELL_IGNORE_WORDS=(
-    '-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue'
+    '-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn'
 )
 
 # check spelling of specified files
diff --git a/tests/ops/test_multi_step.py b/tests/ops/test_multi_step.py
new file mode 100644
index 0000000..5eea7ad
--- /dev/null
+++ b/tests/ops/test_multi_step.py
@@ -0,0 +1,190 @@
+# Copyright (c) China Merchants Bank Co., Ltd. 2025. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#/ + +# to run this test, you need to cd to the upper package which is 'tests', +# and run with command 'pytest -s ops/test_multi_step.py' + +import torch +import torch_npu # noqa: F401 + +DTYPES = [torch.int32, torch.int64] +DEVICES = [f"npu:{0}"] +# Set tolerance to 0 for equals +DEFAULT_ATOL = 0 +DEFAULT_RTOL = 0 + +# test custom ops of https://github.com/vllm-project/vllm-ascend/tree/main/csrc/kernels/advance_step.cpp + + +@torch.inference_mode() +def test_single_generation_multi_step() -> None: + input_tokens_data = [2926] + input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0') + input_tokens_python = torch.tensor(input_tokens_data, device='npu:0') + + sampled_token_ids_data = [[13]] + sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0') + + input_positions_data = [5] + input_positions_ascendc = torch.tensor(input_positions_data, + device='npu:0') + input_positions_python = torch.tensor(input_positions_data, device='npu:0') + + seq_lens_data = [6] + seq_lens_ascendc = torch.tensor(seq_lens_data, + device='npu:0', + dtype=torch.int32) + seq_lens_python = torch.tensor(seq_lens_data, + device='npu:0', + dtype=torch.int32) + + slot_mapping_data = [5] + slot_mapping_ascendc = torch.tensor(slot_mapping_data, + device='npu:0', + dtype=torch.int32) + slot_mapping_python = torch.tensor(slot_mapping_data, + device='npu:0', + dtype=torch.int32) + + block_tables_data = [[0]] + + block_tables = torch.tensor(block_tables_data, + device='npu:0', + dtype=torch.int32) + + torch.ops._C.advance_step_flashattn_ascendc( + 1, 1, 128, input_tokens_ascendc, sampled_token_ids, + input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc, + block_tables) + + normal(1, 1, 128, input_tokens_python, sampled_token_ids, + input_positions_python, seq_lens_python, slot_mapping_python, + block_tables) + + # Compare the results. 
+ torch.testing.assert_close(input_tokens_ascendc, + input_tokens_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(input_positions_ascendc, + input_positions_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(seq_lens_ascendc, + seq_lens_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(slot_mapping_ascendc, + slot_mapping_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + +@torch.inference_mode() +def test_multi_result_generation_multi_step() -> None: + input_tokens_data = [2926, 279, 12095, 1588] + input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0') + input_tokens_python = torch.tensor(input_tokens_data, device='npu:0') + + sampled_token_ids_data = [[13], [1968], [13], [13]] + sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0') + + input_positions_data = [5, 7, 5, 5] + input_positions_ascendc = torch.tensor(input_positions_data, + device='npu:0') + input_positions_python = torch.tensor(input_positions_data, device='npu:0') + + seq_lens_data = [6, 8, 6, 6] + seq_lens_ascendc = torch.tensor(seq_lens_data, + device='npu:0', + dtype=torch.int32) + seq_lens_python = torch.tensor(seq_lens_data, + device='npu:0', + dtype=torch.int32) + + slot_mapping_data = [5, 135, 261, 389] + slot_mapping_ascendc = torch.tensor(slot_mapping_data, + device='npu:0', + dtype=torch.int32) + slot_mapping_python = torch.tensor(slot_mapping_data, + device='npu:0', + dtype=torch.int32) + + block_tables_data = [[0], [1], [2], [3]] + + block_tables = torch.tensor(block_tables_data, + device='npu:0', + dtype=torch.int32) + + torch.ops._C.advance_step_flashattn_ascendc( + 4, 4, 128, input_tokens_ascendc, sampled_token_ids, + input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc, + block_tables) + + normal(4, 4, 128, input_tokens_python, sampled_token_ids, + input_positions_python, seq_lens_python, slot_mapping_python, + block_tables) + + # Compare the 
results. + torch.testing.assert_close(input_tokens_ascendc, + input_tokens_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(input_positions_ascendc, + input_positions_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(seq_lens_ascendc, + seq_lens_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(slot_mapping_ascendc, + slot_mapping_python, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + +def normal(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, seq_lens_tensor: torch.Tensor, + slot_mapping: torch.Tensor, block_tables: torch.Tensor) -> None: + sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1) + input_tokens[:num_queries] = sampled_token_ids_list + + # get seq_lens and input_positions + seq_lens = seq_lens_tensor[:num_queries] + next_seq_lens = seq_lens + 1 + next_input_pos = next_seq_lens - 1 + + # update seq_lens and input_positions + seq_lens_tensor[:num_queries] = next_seq_lens + input_positions[:num_queries] = next_input_pos # type: ignore + + # get block index and offset + block_idx = next_input_pos // block_size + block_offset = next_input_pos % block_size + + current_block_table = block_tables.gather( + 1, block_idx.unsqueeze(-1)).squeeze(-1) + slot_num = current_block_table * block_size + block_offset + + # update slot_mapping + slot_mapping[:num_queries] = slot_num diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index d598822..659aa60 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -36,6 +36,7 @@ from vllm.config import get_current_vllm_config from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm_ascend.ops.cache import concat_and_cache_mla +from vllm_ascend.platform import CUSTOM_OP_ENABLED from vllm_ascend.worker.model_runner import ( 
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) @@ -459,36 +460,47 @@ class AscendMetadata(AttentionMetadata): for i in range(num_queries): self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) + if CUSTOM_OP_ENABLED: + #advance a step on NPU for existing inputs for a multi-step runner if custom ops is enabled + torch.ops._C.advance_step_flashattn_ascendc( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + else: + # use traditional Pytorch method for updating these tensors. + # update input_tokens + sampled_token_ids_list = sampled_token_ids[: + num_queries].squeeze( # type: ignore + -1) + model_input.input_tokens[: + num_queries] = sampled_token_ids_list # type: ignore - # TODO optimize these codes using ascendc just like flash attention backend using cuda + # get seq_lens and input_positions + seq_lens = self.seq_lens_tensor[:num_queries] + next_seq_lens = seq_lens + 1 + next_input_pos = next_seq_lens - 1 - # update input_tokens - sampled_token_ids_list = sampled_token_ids[: - num_queries].squeeze( # type: ignore - -1) - model_input.input_tokens[: - num_queries] = sampled_token_ids_list # type: ignore + # update seq_lens and input_positions + self.seq_lens_tensor[:num_queries] = next_seq_lens + model_input.input_positions[: + num_queries] = next_input_pos # type: ignore - # get seq_lens and input_positions - seq_lens = self.seq_lens_tensor[:num_queries] - next_seq_lens = seq_lens + 1 - next_input_pos = next_seq_lens - 1 + # 计算 block index 和 offset + block_idx = next_input_pos // block_size + block_offset = next_input_pos % block_size - # update seq_lens and input_positions - self.seq_lens_tensor[:num_queries] = next_seq_lens - model_input.input_positions[: - num_queries] = next_input_pos 
# type: ignore + current_block_table = self.block_tables.gather( + 1, block_idx.unsqueeze(-1)).squeeze(-1) + slot_num = current_block_table * block_size + block_offset - # 计算 block index 和 offset - block_idx = next_input_pos // block_size - block_offset = next_input_pos % block_size - - current_block_table = self.block_tables.gather( - 1, block_idx.unsqueeze(-1)).squeeze(-1) - slot_num = current_block_table * block_size + block_offset - - # update slot_mapping - self.slot_mapping[:num_queries] = slot_num + # update slot_mapping + self.slot_mapping[:num_queries] = slot_num class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): @@ -749,11 +761,11 @@ class AscendAttentionBackendImpl(AttentionImpl): key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache: shape = [2, num_blocks, block_size, - num_kv_heads * head_size] + num_kv_heads, head_size] key_cache = [num_blocks, block_size, - num_kv_heads * head_size] + num_kv_heads, head_size] value_cache = [num_blocks, block_size, - num_kv_heads * head_size] + num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [batch_size, seq_len * num_heads * head_size] diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e8ee5ae..5f122f6 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -220,11 +220,11 @@ class AscendAttentionBackendImpl(AttentionImpl): key: shape = [batch_size, seq_len, num_kv_heads * head_size] value: shape = [batch_size, seq_len, num_kv_heads * head_size] kv_cache: shape = [2, num_blocks, block_size, - num_kv_heads * head_size] + num_kv_heads, head_size] key_cache = [num_blocks, block_size, - num_kv_heads * head_size] + num_kv_heads, head_size] value_cache = [num_blocks, block_size, - num_kv_heads * head_size] + num_kv_heads, head_size] attn_metadata: Metadata for attention. 
Returns: shape = [batch_size * seq_len, num_heads, head_size] diff --git a/vllm_ascend/worker/multi_step_runner.py b/vllm_ascend/worker/multi_step_runner.py index ac2b685..028bcd0 100644 --- a/vllm_ascend/worker/multi_step_runner.py +++ b/vllm_ascend/worker/multi_step_runner.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import current_stream from vllm.worker.model_runner_base import ( _init_attn_metadata_from_tensor_dict, _init_frozen_model_input_from_tensor_dict, @@ -23,6 +22,7 @@ from vllm.worker.multi_step_model_runner import (ModelOutput, PythonizationCache, StatefulModelInput) +from vllm_ascend.utils import current_stream from vllm_ascend.worker.model_runner import ( ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)