[Performance]: Custom AscendC Kernel of Multi-Step Prepare Input (#814)
### What this PR does / why we need it?
- Following https://github.com/vllm-project/vllm-ascend/issues/807, this PR adds a custom AscendC kernel for the multi-step prepare-input path.
- It also fixes a bug we found in multi_step_runner.py when running multi-step on the V0 engine.

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
We added a unit-test file and an offline-inference example that exercise the custom AscendC kernel; see test/ops/test_multi_step.py and examples/offline_multi_step.py.

---------

Signed-off-by: wan_danfeng <wonderful199082@126.com>
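For context, a minimal offline multi-step run on the V0 engine might look like the sketch below. This is illustrative only, not the contents of examples/offline_multi_step.py; the model name and step count are invented, and `num_scheduler_steps` is the standard vLLM engine argument that enables multi-step scheduling, which is the path this kernel accelerates.

```python
from vllm import LLM, SamplingParams

# Hypothetical minimal multi-step run (model and step count invented).
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model
          num_scheduler_steps=8)               # enable multi-step decode

for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```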
@@ -36,6 +36,7 @@ from vllm.config import get_current_vllm_config
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 from vllm_ascend.ops.cache import concat_and_cache_mla
+from vllm_ascend.platform import CUSTOM_OP_ENABLED
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
@@ -459,36 +460,47 @@ class AscendMetadata(AttentionMetadata):
         for i in range(num_queries):
             self.seq_lens[i] += 1
         self.max_decode_seq_len = max(self.seq_lens)
-        # TODO optimize these codes using ascendc just like flash attention backend using cuda
-        # get seq_lens and input_positions
-        seq_lens = self.seq_lens_tensor[:num_queries]
-        next_seq_lens = seq_lens + 1
-        next_input_pos = next_seq_lens - 1
-
-        # update input_tokens
-        sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1)  # type: ignore
-        model_input.input_tokens[:num_queries] = sampled_token_ids_list  # type: ignore
-
-        # update seq_lens and input_positions
-        self.seq_lens_tensor[:num_queries] = next_seq_lens
-        model_input.input_positions[:num_queries] = next_input_pos  # type: ignore
-
-        # compute block index and offset
-        block_idx = next_input_pos // block_size
-        block_offset = next_input_pos % block_size
-
-        current_block_table = self.block_tables.gather(
-            1, block_idx.unsqueeze(-1)).squeeze(-1)
-        slot_num = current_block_table * block_size + block_offset
-
-        # update slot_mapping
-        self.slot_mapping[:num_queries] = slot_num
+        if CUSTOM_OP_ENABLED:
+            # advance a step on NPU for existing inputs for a multi-step
+            # runner if custom ops are enabled
+            torch.ops._C.advance_step_flashattn_ascendc(
+                num_seqs=num_seqs,
+                num_queries=num_queries,
+                block_size=block_size,
+                input_tokens=model_input.input_tokens,
+                sampled_token_ids=sampled_token_ids,
+                input_positions=model_input.input_positions,
+                seq_lens=self.seq_lens_tensor,
+                slot_mapping=self.slot_mapping,
+                block_tables=self.block_tables)
+        else:
+            # use the traditional PyTorch method for updating these tensors
+            # update input_tokens
+            sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1)  # type: ignore
+            model_input.input_tokens[:num_queries] = sampled_token_ids_list  # type: ignore
+
+            # get seq_lens and input_positions
+            seq_lens = self.seq_lens_tensor[:num_queries]
+            next_seq_lens = seq_lens + 1
+            next_input_pos = next_seq_lens - 1
+
+            # update seq_lens and input_positions
+            self.seq_lens_tensor[:num_queries] = next_seq_lens
+            model_input.input_positions[:num_queries] = next_input_pos  # type: ignore
+
+            # compute block index and offset
+            block_idx = next_input_pos // block_size
+            block_offset = next_input_pos % block_size
+
+            current_block_table = self.block_tables.gather(
+                1, block_idx.unsqueeze(-1)).squeeze(-1)
+            slot_num = current_block_table * block_size + block_offset
+
+            # update slot_mapping
+            self.slot_mapping[:num_queries] = slot_num
 
 
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
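To make the fallback arithmetic above concrete, here is a small self-contained sketch of how the next token's physical KV-cache slot is derived from the block table. All values are invented for illustration (block_size=4, two sequences whose next positions are 5 and 9):

```python
import torch

block_size = 4
next_input_pos = torch.tensor([5, 9])                      # next token position per sequence
block_tables = torch.tensor([[10, 11, 12], [20, 21, 22]])  # logical block -> physical block

block_idx = next_input_pos // block_size    # tensor([1, 2]): which logical block
block_offset = next_input_pos % block_size  # tensor([1, 1]): offset within that block
current_block = block_tables.gather(1, block_idx.unsqueeze(-1)).squeeze(-1)  # tensor([11, 22])
slot_num = current_block * block_size + block_offset
print(slot_num)  # tensor([45, 89]) -> physical KV-cache slots for the next tokens
```

The custom op performs this same update for all tensors in a single kernel launch instead of a chain of small device ops.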
@@ -749,11 +761,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache: shape = [2, num_blocks, block_size,
-                               num_kv_heads * head_size]
+                               num_kv_heads, head_size]
             key_cache = [num_blocks, block_size,
-                         num_kv_heads * head_size]
+                         num_kv_heads, head_size]
             value_cache = [num_blocks, block_size,
-                           num_kv_heads * head_size]
+                           num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [batch_size, seq_len * num_heads * head_size]
@@ -220,11 +220,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key: shape = [batch_size, seq_len, num_kv_heads * head_size]
             value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             kv_cache: shape = [2, num_blocks, block_size,
-                               num_kv_heads * head_size]
+                               num_kv_heads, head_size]
             key_cache = [num_blocks, block_size,
-                         num_kv_heads * head_size]
+                         num_kv_heads, head_size]
             value_cache = [num_blocks, block_size,
-                           num_kv_heads * head_size]
+                           num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [batch_size * seq_len, num_heads, head_size]
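The corrected docstrings split the cache's trailing dimension into separate num_kv_heads and head_size axes rather than a flattened product. A quick shape sketch, with all dimension sizes invented:

```python
import torch

# Dimensions are illustrative only.
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 32

# kv_cache stacks the key and value caches along dim 0.
kv_cache = torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)
key_cache, value_cache = kv_cache.unbind(0)
print(key_cache.shape)  # torch.Size([8, 16, 4, 32])
```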
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
-from vllm.utils import current_stream
 from vllm.worker.model_runner_base import (
     _init_attn_metadata_from_tensor_dict,
     _init_frozen_model_input_from_tensor_dict,
@@ -23,6 +22,7 @@ from vllm.worker.multi_step_model_runner import (ModelOutput,
                                                  PythonizationCache,
                                                  StatefulModelInput)
 
+from vllm_ascend.utils import current_stream
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
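This is the multi_step_runner.py bug fix mentioned above: the runner now takes current_stream from vllm_ascend.utils instead of vllm.utils, so the stream handle comes from the NPU runtime rather than the CUDA one. A hypothetical sketch of what such a helper could look like (not the actual vllm_ascend implementation; it assumes torch_npu is installed):

```python
import torch
import torch_npu  # noqa: F401  # registers the torch.npu device backend

_CURRENT_STREAM = None


def current_stream() -> "torch.npu.Stream":
    """Return a cached handle to the current NPU stream.

    Caching avoids querying the device runtime on every call in the
    multi-step hot path; this sketch assumes the worker never switches
    streams after the first call.
    """
    global _CURRENT_STREAM
    if _CURRENT_STREAM is None:
        _CURRENT_STREAM = torch.npu.current_stream()
    return _CURRENT_STREAM
```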