[Refactor] Move the prefill/decode metadata from attention_v1 to the attention utils (in preparation for extracting common_cp) and make AscendCommonAttentionMetadata inherit from its parent class. (#5203)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
1. Remove the pcp-related code from attention_v1.
2. Make AscendCommonAttentionMetadata inherit from vLLM's CommonAttentionMetadata.
TODO:
1. Extract common_cp.
2. Move the cp metadata to common_cp.
3. Remove CommonAttentionMetadata for aclgraph.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -117,7 +117,8 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
spec_attn_mask=None,
|
spec_attn_mask=None,
|
||||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None)
|
seq_lens=None,
|
||||||
|
max_seq_len=6)
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
|
|
||||||
self.builder.build(1, common_attn_metadata, mock_model)
|
self.builder.build(1, common_attn_metadata, mock_model)
|
||||||
|
|||||||
@@ -606,7 +606,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
spec_attn_mask=None,
|
spec_attn_mask=None,
|
||||||
attn_state=AscendAttentionState.PrefillNoCache,
|
attn_state=AscendAttentionState.PrefillNoCache,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None)
|
seq_lens=None,
|
||||||
|
max_seq_len=6)
|
||||||
|
|
||||||
base_inputs = {
|
base_inputs = {
|
||||||
"num_actual_tokens": 10,
|
"num_actual_tokens": 10,
|
||||||
@@ -673,7 +674,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
spec_attn_mask=None,
|
spec_attn_mask=None,
|
||||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None)
|
seq_lens=None,
|
||||||
|
max_seq_len=6)
|
||||||
|
|
||||||
base_inputs = {
|
base_inputs = {
|
||||||
"num_actual_tokens": 15,
|
"num_actual_tokens": 15,
|
||||||
@@ -729,7 +731,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
spec_attn_mask=None,
|
spec_attn_mask=None,
|
||||||
attn_state=AscendAttentionState.DecodeOnly,
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None)
|
seq_lens=None,
|
||||||
|
max_seq_len=6)
|
||||||
|
|
||||||
base_inputs = {
|
base_inputs = {
|
||||||
"num_actual_tokens": 3,
|
"num_actual_tokens": 3,
|
||||||
@@ -784,7 +787,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
spec_attn_mask=None,
|
spec_attn_mask=None,
|
||||||
attn_state=AscendAttentionState.DecodeOnly,
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None)
|
seq_lens=None,
|
||||||
|
max_seq_len=6)
|
||||||
|
|
||||||
base_inputs = {
|
base_inputs = {
|
||||||
"num_actual_tokens": 3,
|
"num_actual_tokens": 3,
|
||||||
@@ -839,7 +843,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
spec_attn_mask=None,
|
spec_attn_mask=None,
|
||||||
attn_state=AscendAttentionState.PrefillNoCache,
|
attn_state=AscendAttentionState.PrefillNoCache,
|
||||||
num_computed_tokens_cpu=None,
|
num_computed_tokens_cpu=None,
|
||||||
seq_lens=None)
|
seq_lens=None,
|
||||||
|
max_seq_len=6)
|
||||||
|
|
||||||
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
||||||
layer_names=["layer_0", "layer_1"],
|
layer_names=["layer_0", "layer_1"],
|
||||||
|
|||||||
@@ -33,10 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
|
|
||||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
|
||||||
AscendAttentionMetadataBuilder,
|
AscendAttentionMetadataBuilder,
|
||||||
AscendMetadata,
|
AscendMetadata)
|
||||||
AscendMetadataForDecode,
|
|
||||||
AscendMetadataForPrefill)
|
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
|
AscendMetadataForDecode,
|
||||||
|
AscendMetadataForPrefill,
|
||||||
filter_chunked_req_indices,
|
filter_chunked_req_indices,
|
||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
|
|||||||
@@ -34,7 +34,9 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
enable_cp, split_decodes_and_prefills,
|
AscendMetadataForDecode,
|
||||||
|
AscendMetadataForPrefill, enable_cp,
|
||||||
|
split_decodes_and_prefills,
|
||||||
using_paged_attention)
|
using_paged_attention)
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
@@ -118,51 +120,6 @@ class AscendAttentionState(Enum):
|
|||||||
SpecDecoding = 4
|
SpecDecoding = 4
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AscendMetadataForPrefill:
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AscendPCPMetadata:
|
|
||||||
q_head_idx: torch.Tensor = None
|
|
||||||
q_tail_idx: torch.Tensor = None
|
|
||||||
kv_with_q_head_nomask_idx: torch.Tensor = None
|
|
||||||
kv_with_q_head_mask_idx: torch.Tensor = None
|
|
||||||
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
|
||||||
kv_with_q_tail_mask_idx: torch.Tensor = None
|
|
||||||
attn_mask_seqlens: torch.Tensor = None
|
|
||||||
head_attn_nomask_seqlens: torch.Tensor = None
|
|
||||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
|
||||||
q_full_idx: torch.Tensor = None
|
|
||||||
pcp_prefill_mask: torch.Tensor = None
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChunkedContextMetadata:
|
|
||||||
actual_chunk_seq_lengths: torch.Tensor
|
|
||||||
actual_seq_lengths_kv: torch.Tensor
|
|
||||||
starts: torch.Tensor
|
|
||||||
chunk_seq_mask_filtered_indices: torch.Tensor
|
|
||||||
chunked_req_mask: Optional[list[bool]] = None
|
|
||||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
|
||||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
|
||||||
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
|
||||||
batch_chunk_seq_mask: Optional[list[bool]] = None
|
|
||||||
|
|
||||||
""" Prefill Specific Metadata for Ascend"""
|
|
||||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
|
||||||
pcp_allgather_restore_idx: Optional[List[int]] = None
|
|
||||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
|
||||||
block_tables: torch.Tensor = None
|
|
||||||
actual_seq_lengths_q: torch.Tensor = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AscendMetadataForDecode:
|
|
||||||
""" Decode Specific Metadata for Ascend"""
|
|
||||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
|
||||||
batch_seq_mask: torch.Tensor = None
|
|
||||||
block_tables: torch.Tensor = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendMetadata:
|
class AscendMetadata:
|
||||||
# **************************** Basic Properties ************************** #
|
# **************************** Basic Properties ************************** #
|
||||||
@@ -274,14 +231,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
block_table = common_attn_metadata.block_table_tensor
|
block_table = common_attn_metadata.block_table_tensor
|
||||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||||
|
|
||||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
|
||||||
if num_actual_tokens_pcp_padded is None:
|
|
||||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
|
||||||
|
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
|
||||||
num_actual_tokens_pcp_padded]
|
|
||||||
|
|
||||||
attn_mask = common_attn_metadata.attn_mask
|
attn_mask = common_attn_metadata.attn_mask
|
||||||
attn_state = common_attn_metadata.attn_state
|
attn_state = common_attn_metadata.attn_state
|
||||||
|
|
||||||
@@ -292,7 +242,6 @@ class AscendAttentionMetadataBuilder:
|
|||||||
attn_metadata = AscendMetadata(
|
attn_metadata = AscendMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|||||||
has_kv_transfer_group,
|
has_kv_transfer_group,
|
||||||
is_v1_kv_transfer_group)
|
is_v1_kv_transfer_group)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
|
|
||||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
|
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
|
||||||
get_ascend_device_type)
|
get_ascend_device_type)
|
||||||
@@ -34,6 +35,51 @@ def enable_cp():
|
|||||||
or prefill_config.decode_context_parallel_size > 1
|
or prefill_config.decode_context_parallel_size > 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AscendMetadataForPrefill:
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AscendPCPMetadata:
|
||||||
|
q_head_idx: torch.Tensor = None
|
||||||
|
q_tail_idx: torch.Tensor = None
|
||||||
|
kv_with_q_head_nomask_idx: torch.Tensor = None
|
||||||
|
kv_with_q_head_mask_idx: torch.Tensor = None
|
||||||
|
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
||||||
|
kv_with_q_tail_mask_idx: torch.Tensor = None
|
||||||
|
attn_mask_seqlens: torch.Tensor = None
|
||||||
|
head_attn_nomask_seqlens: torch.Tensor = None
|
||||||
|
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||||
|
q_full_idx: torch.Tensor = None
|
||||||
|
pcp_prefill_mask: torch.Tensor = None
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkedContextMetadata:
|
||||||
|
actual_chunk_seq_lengths: torch.Tensor
|
||||||
|
actual_seq_lengths_kv: torch.Tensor
|
||||||
|
starts: torch.Tensor
|
||||||
|
chunk_seq_mask_filtered_indices: torch.Tensor
|
||||||
|
chunked_req_mask: Optional[list[bool]] = None
|
||||||
|
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||||
|
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||||
|
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||||
|
batch_chunk_seq_mask: Optional[list[bool]] = None
|
||||||
|
|
||||||
|
""" Prefill Specific Metadata for Ascend"""
|
||||||
|
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||||
|
pcp_allgather_restore_idx: Optional[List[int]] = None
|
||||||
|
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||||
|
block_tables: torch.Tensor = None
|
||||||
|
actual_seq_lengths_q: torch.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AscendMetadataForDecode:
|
||||||
|
""" Decode Specific Metadata for Ascend"""
|
||||||
|
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||||
|
batch_seq_mask: torch.Tensor = None
|
||||||
|
block_tables: torch.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
# class AscendCommonLongSequenceMetadata:
|
# class AscendCommonLongSequenceMetadata:
|
||||||
class AscendPrefillContextParallelMetadata:
|
class AscendPrefillContextParallelMetadata:
|
||||||
@@ -75,45 +121,20 @@ class AscendPrefillContextParallelMetadata:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendCommonAttentionMetadata:
|
class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
||||||
"""
|
"""
|
||||||
Per-batch attention metadata, shared across layers and backends.
|
Per-batch attention metadata, shared across layers and backends.
|
||||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||||
|
|
||||||
For many of the tensors we keep both NPU and CPU versions.
|
For many of the tensors we keep both NPU and CPU versions.
|
||||||
"""
|
"""
|
||||||
|
seq_lens_cpu: torch.Tensor = None
|
||||||
|
num_computed_tokens_cpu: torch.Tensor = None
|
||||||
|
|
||||||
query_start_loc: torch.Tensor
|
decode_token_per_req: int = 1
|
||||||
query_start_loc_cpu: torch.Tensor
|
|
||||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
|
||||||
|
|
||||||
seq_lens_cpu: torch.Tensor
|
|
||||||
"""(batch_size,), the length of each request including both computed tokens
|
|
||||||
and newly scheduled tokens"""
|
|
||||||
|
|
||||||
seq_lens: torch.Tensor
|
|
||||||
"""same to seq_lens_cpu, for compatibility with some new attn metadata
|
|
||||||
(such as GDN)."""
|
|
||||||
|
|
||||||
num_computed_tokens_cpu: torch.Tensor
|
|
||||||
"""(batch_size,), the number of computed tokens for each request"""
|
|
||||||
|
|
||||||
num_reqs: int
|
|
||||||
"""Number of requests"""
|
|
||||||
num_actual_tokens: int
|
|
||||||
"""Total number of tokens in batch"""
|
|
||||||
|
|
||||||
max_query_len: int
|
|
||||||
"""Max token number of request in batch"""
|
|
||||||
|
|
||||||
decode_token_per_req: int
|
|
||||||
"""decode token number per request"""
|
"""decode token number per request"""
|
||||||
|
|
||||||
block_table_tensor: torch.Tensor
|
actual_seq_lengths_q: list[int] = field(default_factory=list)
|
||||||
|
|
||||||
slot_mapping: torch.Tensor
|
|
||||||
|
|
||||||
actual_seq_lengths_q: list[int]
|
|
||||||
|
|
||||||
positions: torch.Tensor = None
|
positions: torch.Tensor = None
|
||||||
|
|
||||||
@@ -132,8 +153,6 @@ class AscendCommonAttentionMetadata:
|
|||||||
prefill_context_parallel_metadata: Optional[
|
prefill_context_parallel_metadata: Optional[
|
||||||
AscendPrefillContextParallelMetadata] = None
|
AscendPrefillContextParallelMetadata] = None
|
||||||
|
|
||||||
causal: bool = True
|
|
||||||
|
|
||||||
# TODO: Remove it when vLLM no longer uses this function.
|
# TODO: Remove it when vLLM no longer uses this function.
|
||||||
def unpadded(self, num_actual_tokens: int,
|
def unpadded(self, num_actual_tokens: int,
|
||||||
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
|
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
|
||||||
@@ -161,7 +180,7 @@ class AscendCommonAttentionMetadata:
|
|||||||
num_input_tokens=num_actual_tokens,
|
num_input_tokens=num_actual_tokens,
|
||||||
prefill_context_parallel_metadata=self.
|
prefill_context_parallel_metadata=self.
|
||||||
prefill_context_parallel_metadata,
|
prefill_context_parallel_metadata,
|
||||||
)
|
max_seq_len=self.max_seq_len)
|
||||||
|
|
||||||
|
|
||||||
def filter_chunked_req_indices(
|
def filter_chunked_req_indices(
|
||||||
|
|||||||
@@ -742,7 +742,7 @@ class EagleProposer(Proposer):
|
|||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
spec_attn_mask=self.runner.spec_attn_mask,
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
)
|
max_seq_len=0)
|
||||||
return spec_common_attn_metadata, token_indices
|
return spec_common_attn_metadata, token_indices
|
||||||
|
|
||||||
def prepare_inputs_padded(
|
def prepare_inputs_padded(
|
||||||
@@ -800,7 +800,8 @@ class EagleProposer(Proposer):
|
|||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
num_computed_tokens_cpu=common_attn_metadata.
|
num_computed_tokens_cpu=common_attn_metadata.
|
||||||
num_computed_tokens_cpu,
|
num_computed_tokens_cpu,
|
||||||
seq_lens=common_attn_metadata.seq_lens)
|
seq_lens=common_attn_metadata.seq_lens,
|
||||||
|
max_seq_len=0)
|
||||||
|
|
||||||
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
|
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
|
||||||
1 - num_rejected_tokens_gpu)
|
1 - num_rejected_tokens_gpu)
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ class MtpProposer(Proposer):
|
|||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
spec_attn_mask=self.runner.spec_attn_mask,
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
)
|
max_seq_len=0)
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
# update long_seq related params and flatten block_table
|
# update long_seq related params and flatten block_table
|
||||||
common_attn_metadata.prefill_context_parallel_metadata = \
|
common_attn_metadata.prefill_context_parallel_metadata = \
|
||||||
@@ -599,7 +599,7 @@ class MtpProposer(Proposer):
|
|||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
spec_attn_mask=self.runner.spec_attn_mask,
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
)
|
max_seq_len=0)
|
||||||
return spec_common_attn_metadata, token_indices
|
return spec_common_attn_metadata, token_indices
|
||||||
|
|
||||||
def _propose(
|
def _propose(
|
||||||
@@ -1221,7 +1221,8 @@ class MtpProposer(Proposer):
|
|||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
num_computed_tokens_cpu=common_attn_metadata.
|
num_computed_tokens_cpu=common_attn_metadata.
|
||||||
num_computed_tokens_cpu,
|
num_computed_tokens_cpu,
|
||||||
seq_lens=common_attn_metadata.seq_lens)
|
seq_lens=common_attn_metadata.seq_lens,
|
||||||
|
max_seq_len=0)
|
||||||
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc[
|
query_start_loc = common_attn_metadata.query_start_loc[
|
||||||
1:1 + num_rejected_tokens_gpu.shape[0]]
|
1:1 + num_rejected_tokens_gpu.shape[0]]
|
||||||
|
|||||||
@@ -1045,7 +1045,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
prefill_context_parallel_metadata=long_seq_metadata,
|
prefill_context_parallel_metadata=long_seq_metadata,
|
||||||
)
|
max_seq_len=0)
|
||||||
|
|
||||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||||
# For pcp + spec decode, we flatten block_table
|
# For pcp + spec decode, we flatten block_table
|
||||||
@@ -1874,7 +1874,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
prefill_context_parallel_metadata=long_seq_metadata,
|
prefill_context_parallel_metadata=long_seq_metadata,
|
||||||
)
|
max_seq_len=0)
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
common_attn_metadata.block_table_tensor = \
|
common_attn_metadata.block_table_tensor = \
|
||||||
block_table_tensor[:num_reqs * self.decode_threshold]
|
block_table_tensor[:num_reqs * self.decode_threshold]
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ def build_attn_metadata(
|
|||||||
"""Build attention metadata for Ascend NPUs."""
|
"""Build attention metadata for Ascend NPUs."""
|
||||||
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
|
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
|
||||||
max_query_len = int(query_start_loc_cpu.max())
|
max_query_len = int(query_start_loc_cpu.max())
|
||||||
|
max_seq_len = int(seq_lens_cpu.max())
|
||||||
|
|
||||||
attn_metadata: dict[str, Any] = {}
|
attn_metadata: dict[str, Any] = {}
|
||||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||||
@@ -80,7 +81,7 @@ def build_attn_metadata(
|
|||||||
graph_pad_size=graph_pad_size,
|
graph_pad_size=graph_pad_size,
|
||||||
num_input_tokens=num_input_tokens,
|
num_input_tokens=num_input_tokens,
|
||||||
prefill_context_parallel_metadata=prefill_context_parallel_metadata,
|
prefill_context_parallel_metadata=prefill_context_parallel_metadata,
|
||||||
)
|
max_seq_len=max_seq_len)
|
||||||
|
|
||||||
attn_metadata_builder = attn_metadata_builders[i]
|
attn_metadata_builder = attn_metadata_builders[i]
|
||||||
metadata = attn_metadata_builder.build(
|
metadata = attn_metadata_builder.build(
|
||||||
|
|||||||
Reference in New Issue
Block a user