[Misc] Clean up useless code in attention (#1933)
Before do attention module refactor, we can do some code cleanup to make
the next step easier.
What this PR does:
1. remove useless `common_prefix_len` for attention builder
2. remove useless `is_only_prefill` and `num_input_tokens` in attention
metadata.
3. remove `CommonAttentionMetadata` and use `query_start_loc` instead;
`CommonAttentionMetadata` is over-designed and useless
4. update the attention backend input parameters to keep the same as
vLLM.
5. Rename attention name to the same style with `ASCEND` prefix
- vLLM version: v0.9.2
- vLLM main:
107111a859
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -96,7 +96,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
num_reqs = 2
|
num_reqs = 2
|
||||||
num_actual_tokens = 10
|
num_actual_tokens = 10
|
||||||
max_query_len = 5
|
max_query_len = 5
|
||||||
common_prefix_len = 1
|
|
||||||
|
|
||||||
self.mock_runner.input_batch.block_table = [MagicMock()]
|
self.mock_runner.input_batch.block_table = [MagicMock()]
|
||||||
self.mock_runner.input_batch.block_table[
|
self.mock_runner.input_batch.block_table[
|
||||||
@@ -114,8 +113,11 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
mock_nd_to_nz_2d.return_value = mock_nz_tensor
|
mock_nd_to_nz_2d.return_value = mock_nz_tensor
|
||||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
mock_npu_format_cast.return_value = mock_nz_tensor
|
||||||
|
|
||||||
self.builder.build(num_reqs, num_actual_tokens, max_query_len,
|
self.builder.build(
|
||||||
common_prefix_len)
|
num_reqs,
|
||||||
|
num_actual_tokens,
|
||||||
|
max_query_len,
|
||||||
|
)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||||
@patch('torch_npu.npu_format_cast')
|
@patch('torch_npu.npu_format_cast')
|
||||||
@@ -148,7 +150,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
mock_nd_to_nz_spec.return_value = mock_nz_tensor
|
mock_nd_to_nz_spec.return_value = mock_nz_tensor
|
||||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
mock_npu_format_cast.return_value = mock_nz_tensor
|
||||||
|
|
||||||
self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0)
|
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||||
@@ -169,7 +171,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
|
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
|
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
|
||||||
|
|
||||||
self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0)
|
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
|
||||||
|
|
||||||
|
|
||||||
class TestAscendAttentionBackendImpl(TestBase):
|
class TestAscendAttentionBackendImpl(TestBase):
|
||||||
@@ -201,7 +203,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
sliding_window=None,
|
sliding_window=None,
|
||||||
kv_cache_dtype="float16",
|
kv_cache_dtype="float16",
|
||||||
attn_type=self.attention_type.DECODER)
|
logits_soft_cap=None,
|
||||||
|
attn_type=self.attention_type.DECODER,
|
||||||
|
kv_sharing_target_layer_name=None)
|
||||||
|
|
||||||
self.impl_192 = AscendAttentionBackendImpl(
|
self.impl_192 = AscendAttentionBackendImpl(
|
||||||
num_heads=8,
|
num_heads=8,
|
||||||
@@ -211,16 +215,21 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
sliding_window=None,
|
sliding_window=None,
|
||||||
kv_cache_dtype="float16",
|
kv_cache_dtype="float16",
|
||||||
attn_type=self.attention_type.DECODER)
|
logits_soft_cap=None,
|
||||||
|
attn_type=self.attention_type.DECODER,
|
||||||
|
kv_sharing_target_layer_name=None)
|
||||||
|
|
||||||
self.impl_error = AscendAttentionBackendImpl(num_heads=8,
|
self.impl_error = AscendAttentionBackendImpl(
|
||||||
head_size=192,
|
num_heads=8,
|
||||||
scale=1.0,
|
head_size=192,
|
||||||
num_kv_heads=8,
|
scale=1.0,
|
||||||
alibi_slopes=None,
|
num_kv_heads=8,
|
||||||
sliding_window=None,
|
alibi_slopes=None,
|
||||||
kv_cache_dtype="float16",
|
sliding_window=None,
|
||||||
attn_type=None)
|
kv_cache_dtype="float16",
|
||||||
|
logits_soft_cap=None,
|
||||||
|
attn_type=None,
|
||||||
|
kv_sharing_target_layer_name=None)
|
||||||
|
|
||||||
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
|
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
|
||||||
def test_forward_trace_flag_true(self, mock_unified_attention):
|
def test_forward_trace_flag_true(self, mock_unified_attention):
|
||||||
|
|||||||
@@ -130,10 +130,8 @@ class AscendMetadata:
|
|||||||
query_start_loc: torch.Tensor
|
query_start_loc: torch.Tensor
|
||||||
query_lens: torch.Tensor
|
query_lens: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
|
|
||||||
# max value of number of tokens across dp group
|
# max value of number of tokens across dp group
|
||||||
max_num_tokens_across_dp: int = 0
|
max_num_tokens_across_dp: int = 0
|
||||||
|
|
||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||||
@@ -141,18 +139,9 @@ class AscendMetadata:
|
|||||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||||
# in block 0, and 1st slot in block 1, respectively.
|
# in block 0, and 1st slot in block 1, respectively.
|
||||||
slot_mapping: torch.Tensor = None
|
slot_mapping: torch.Tensor = None
|
||||||
# TODO: Indicates whether there are only prefill requests.
|
|
||||||
# FlashAttention can be used when there are only prefill requests.
|
|
||||||
# FlashAttention has better performance than PageAtttention,
|
|
||||||
# but it does not support decode requests.
|
|
||||||
is_only_prefill: bool = False
|
|
||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# For logging.
|
|
||||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
|
||||||
|
|
||||||
with_prefill_across_dp: bool = False
|
with_prefill_across_dp: bool = False
|
||||||
|
|
||||||
|
|
||||||
@@ -169,7 +158,6 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_reqs,
|
num_reqs,
|
||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
common_prefix_len,
|
|
||||||
max_num_tokens_across_dp: int = 0,
|
max_num_tokens_across_dp: int = 0,
|
||||||
with_prefill_across_dp: bool = False):
|
with_prefill_across_dp: bool = False):
|
||||||
|
|
||||||
@@ -224,10 +212,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
alibi_slopes: Optional[List[float]],
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int],
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float],
|
||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str,
|
||||||
kv_sharing_target_layer_name: Optional[str] = None,
|
kv_sharing_target_layer_name: Optional[str],
|
||||||
use_irope: bool = False,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class AscendAttentionTorchairBackend(AttentionBackend):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "ASCEND"
|
return "ASCEND_TORCHAIR"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
|
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
|
||||||
@@ -129,10 +129,8 @@ class AscendTorchairMetadata:
|
|||||||
query_start_loc: torch.Tensor
|
query_start_loc: torch.Tensor
|
||||||
query_lens: torch.Tensor
|
query_lens: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
|
|
||||||
# max value of number of tokens across dp group
|
# max value of number of tokens across dp group
|
||||||
max_num_tokens_across_dp: int = 0
|
max_num_tokens_across_dp: int = 0
|
||||||
|
|
||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||||
@@ -140,20 +138,10 @@ class AscendTorchairMetadata:
|
|||||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||||
# in block 0, and 1st slot in block 1, respectively.
|
# in block 0, and 1st slot in block 1, respectively.
|
||||||
slot_mapping: torch.Tensor = None
|
slot_mapping: torch.Tensor = None
|
||||||
# TODO: Indicates whether there are only prefill requests.
|
|
||||||
# FlashAttention can be used when there are only prefill requests.
|
|
||||||
# FlashAttention has better performance than PageAtttention,
|
|
||||||
# but it does not support decode requests.
|
|
||||||
is_only_prefill: bool = False
|
|
||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# For logging.
|
|
||||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
|
||||||
|
|
||||||
with_prefill_across_dp: bool = False
|
with_prefill_across_dp: bool = False
|
||||||
|
|
||||||
decode: Optional[AscendDecodeMetadata] = None
|
decode: Optional[AscendDecodeMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -236,7 +224,6 @@ class AscendAttentionTorchairMetadataBuilder:
|
|||||||
num_reqs,
|
num_reqs,
|
||||||
num_actual_tokens,
|
num_actual_tokens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
common_prefix_len,
|
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
max_num_tokens_across_dp: int = 0,
|
max_num_tokens_across_dp: int = 0,
|
||||||
with_prefill_across_dp: bool = False):
|
with_prefill_across_dp: bool = False):
|
||||||
@@ -335,10 +322,10 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
|||||||
alibi_slopes: Optional[List[float]],
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int],
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float],
|
||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str,
|
||||||
kv_sharing_target_layer_name: Optional[str] = None,
|
kv_sharing_target_layer_name: Optional[str],
|
||||||
use_irope: bool = False,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
|
|||||||
@@ -31,27 +31,13 @@ if TYPE_CHECKING:
|
|||||||
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CommonAttentionMetadata:
|
|
||||||
"""
|
|
||||||
Attention metadata attributes that can be shared by layers in different KV
|
|
||||||
cache groups and thus having different block table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
query_start_loc: torch.Tensor
|
|
||||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
|
||||||
seq_lens: torch.Tensor
|
|
||||||
"""(batch_size,), the length of each request including both computed tokens
|
|
||||||
and newly scheduled tokens"""
|
|
||||||
|
|
||||||
|
|
||||||
class AscendMLABackend(AttentionBackend):
|
class AscendMLABackend(AttentionBackend):
|
||||||
|
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "VLLM_ASCEND_MLA"
|
return "ASCEND_MLA"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||||
@@ -368,11 +354,10 @@ class AscendMLAMetadataBuilder:
|
|||||||
num_reqs: int,
|
num_reqs: int,
|
||||||
num_actual_tokens: int,
|
num_actual_tokens: int,
|
||||||
max_query_len: int,
|
max_query_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
common_prefix_len: Optional[int] = None,
|
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
max_num_tokens_across_dp: int = 0,
|
max_num_tokens_across_dp: int = 0,
|
||||||
with_prefill_across_dp: bool = False,
|
with_prefill_across_dp: bool = False,
|
||||||
|
query_start_loc: torch.Tensor = None,
|
||||||
) -> AscendMLAMetadata:
|
) -> AscendMLAMetadata:
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
assert self._num_decodes + self._num_prefills == num_reqs
|
||||||
|
|
||||||
@@ -394,7 +379,6 @@ class AscendMLAMetadataBuilder:
|
|||||||
seq_lens = seq_lens_cpu
|
seq_lens = seq_lens_cpu
|
||||||
max_query_len = query_lens.max().item()
|
max_query_len = query_lens.max().item()
|
||||||
max_seq_lens = seq_lens.max().item()
|
max_seq_lens = seq_lens.max().item()
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
|
||||||
|
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
chunked_context_metadata = None
|
chunked_context_metadata = None
|
||||||
@@ -403,7 +387,6 @@ class AscendMLAMetadataBuilder:
|
|||||||
tokens_start = self._num_decode_tokens
|
tokens_start = self._num_decode_tokens
|
||||||
max_query_len = query_lens[tokens_start:].max().item()
|
max_query_len = query_lens[tokens_start:].max().item()
|
||||||
max_seq_lens = seq_lens[tokens_start:].max().item()
|
max_seq_lens = seq_lens[tokens_start:].max().item()
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
|
||||||
prefill_query_start_loc = query_start_loc[
|
prefill_query_start_loc = query_start_loc[
|
||||||
reqs_start:] - query_start_loc[reqs_start]
|
reqs_start:] - query_start_loc[reqs_start]
|
||||||
|
|
||||||
@@ -539,7 +522,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
logits_soft_cap: Optional[float],
|
logits_soft_cap: Optional[float],
|
||||||
attn_type: str,
|
attn_type: str,
|
||||||
kv_sharing_target_layer_name: Optional[str] = None,
|
kv_sharing_target_layer_name: Optional[str],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
|||||||
@@ -132,7 +132,6 @@ class EagleProposer:
|
|||||||
num_reqs=batch_size,
|
num_reqs=batch_size,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
common_prefix_len=0,
|
|
||||||
)
|
)
|
||||||
if self.use_cuda_graph and \
|
if self.use_cuda_graph and \
|
||||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||||
|
|||||||
@@ -75,8 +75,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|||||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||||
AscendMetadata)
|
AscendMetadata)
|
||||||
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
||||||
from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
|
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||||
CommonAttentionMetadata)
|
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||||
@@ -694,15 +693,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# in the same group share the same metadata.
|
# in the same group share the same metadata.
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||||
self.kv_cache_config.kv_cache_groups):
|
self.kv_cache_config.kv_cache_groups):
|
||||||
|
|
||||||
# Prepare for cascade attention if enabled & beneficial.
|
|
||||||
common_prefix_len = 0
|
|
||||||
|
|
||||||
attn_metadata_i = self.attn_metadata_builder.build(
|
attn_metadata_i = self.attn_metadata_builder.build(
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=total_num_scheduled_tokens,
|
num_actual_tokens=total_num_scheduled_tokens,
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
common_prefix_len=common_prefix_len,
|
|
||||||
)
|
)
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
@@ -1049,27 +1043,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||||
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
extra_builder_kwargs[
|
||||||
seq_lens = self.seq_lens[:num_reqs]
|
"query_start_loc"] = self.query_start_loc[:num_reqs + 1]
|
||||||
common_attn_metadata = CommonAttentionMetadata(
|
|
||||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
|
||||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=total_num_scheduled_tokens,
|
num_actual_tokens=total_num_scheduled_tokens,
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
common_attn_metadata=common_attn_metadata,
|
|
||||||
common_prefix_len=None,
|
|
||||||
**extra_builder_kwargs,
|
**extra_builder_kwargs,
|
||||||
)
|
)
|
||||||
|
attn_metadata.num_input_tokens = num_input_tokens
|
||||||
else:
|
else:
|
||||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=total_num_scheduled_tokens,
|
num_actual_tokens=total_num_scheduled_tokens,
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
common_prefix_len=None,
|
|
||||||
**extra_builder_kwargs,
|
**extra_builder_kwargs,
|
||||||
)
|
)
|
||||||
attn_metadata.num_input_tokens = num_input_tokens
|
|
||||||
|
|
||||||
# Prepare input_ids
|
# Prepare input_ids
|
||||||
token_indices = (positions_np +
|
token_indices = (positions_np +
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from vllm.model_executor.model_loader.utils import (
|
|||||||
process_weights_after_loading, set_default_torch_dtype)
|
process_weights_after_loading, set_default_torch_dtype)
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
|
||||||
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
||||||
|
|
||||||
|
|
||||||
@@ -100,11 +99,6 @@ class MtpProposer:
|
|||||||
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
|
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
|
||||||
max_query_len = query_lens.max().item()
|
max_query_len = query_lens.max().item()
|
||||||
|
|
||||||
seq_lens = (target_positions[last_token_indices] + 1)
|
|
||||||
|
|
||||||
common_attn_metadata = CommonAttentionMetadata(
|
|
||||||
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
|
|
||||||
|
|
||||||
# FIXME: reorder_batch() needs to be called before build()
|
# FIXME: reorder_batch() needs to be called before build()
|
||||||
# because fields of attn_metadata_builder needs to be updated.
|
# because fields of attn_metadata_builder needs to be updated.
|
||||||
# However, currently reorder_batch() takes input_batch and
|
# However, currently reorder_batch() takes input_batch and
|
||||||
@@ -120,8 +114,7 @@ class MtpProposer:
|
|||||||
num_reqs=batch_size,
|
num_reqs=batch_size,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
common_prefix_len=0,
|
query_start_loc=cu_num_tokens,
|
||||||
common_attn_metadata=common_attn_metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
|
|||||||
Reference in New Issue
Block a user