[2/N][refactor] torchair deepseek mla backend refactor (#2459)
### What this PR does / why we need it?
This PR moves the current unified MLA backend into the torchair folder and removes the
torchair-related code from attention/mla_v1.py (1.3k -> 0.9k lines).
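After the split, the eager-mode MLA path and the torchair MLA path each carry their own metadata type, which is why the model-runner type hints in the diff below are widened. The sketch below is illustrative only: the class and function names are placeholders, not the actual vllm-ascend APIs.

```python
# Minimal, self-contained sketch; names are placeholders, not vllm-ascend APIs.
from dataclasses import dataclass
from typing import Union


@dataclass
class MLAMetadata:
    """Metadata for the eager-mode MLA attention path."""
    num_tokens: int


@dataclass
class MLATorchairMetadata(MLAMetadata):
    """Extra fields the torchair (graph-mode) MLA path would need."""
    graph_pad_size: int = 0


def run_attention(meta: Union[MLAMetadata, MLATorchairMetadata]) -> str:
    # Each backend only ever sees its own metadata type; the runner's type
    # hints are widened to the Union so both paths can flow through it.
    if isinstance(meta, MLATorchairMetadata):
        return f"torchair MLA, padded to {meta.num_tokens + meta.graph_pad_size}"
    return f"eager MLA, {meta.num_tokens} tokens"


if __name__ == "__main__":
    print(run_attention(MLAMetadata(num_tokens=16)))
    print(run_attention(MLATorchairMetadata(num_tokens=16, graph_pad_size=8)))
```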
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tested by running eager mode with the MLA backend, and torchair mode with the code prior to
[#2445](https://github.com/vllm-project/vllm-ascend/pull/2445).
- vLLM version: v0.10.0
- vLLM main: f571ff8eb6
Signed-off-by: linfeng-yuan <1102311262@qq.com>
```diff
@@ -92,6 +92,7 @@ from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
+from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
                                maybe_converting_weight_acl_format)
@@ -624,7 +625,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
     ) -> dict[str, Union[AscendMetadata, AscendMLAMetadata,
-                         AscendTorchairMetadata]]:
+                         AscendTorchairMetadata, AscendMLATorchairMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -736,7 +737,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.query_start_loc[num_reqs + 1:].fill_(-1)

         attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata,
-                                       AscendTorchairMetadata]] = {}
+                                       AscendTorchairMetadata,
+                                       AscendMLATorchairMetadata]] = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -1000,8 +1002,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
-               AscendTorchairMetadata], torch.Tensor, np.ndarray, int,
+    ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata,
+               AscendMLATorchairMetadata], torch.Tensor, np.ndarray, int,
                torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
                Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -1466,7 +1468,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens: int,
         hidden_states: torch.Tensor,
         attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
-                             AscendTorchairMetadata],
+                             AscendTorchairMetadata,
+                             AscendMLATorchairMetadata],
         aux_hidden_states: torch.Tensor = None,
     ) -> Optional[list[list[int]]]:
         if not self.use_spec_decode:
@@ -2540,7 +2543,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens: int,
         hidden_states: torch.Tensor,
         attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
-                             AscendTorchairMetadata],
+                             AscendTorchairMetadata,
+                             AscendMLATorchairMetadata],
     ):
         assert isinstance(self.drafter, MtpProposer)
         next_token_ids: list[int] = []
```
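Since the same four-member Union now appears in several of the signatures above, a module-level alias could keep those hints short. This is a hypothetical refactor sketch, not part of the PR; the alias name and the placeholder classes below do not exist in vllm-ascend.

```python
# Hypothetical sketch; the alias and placeholder classes are illustrative only.
from typing import Union


class AscendMetadata: ...
class AscendMLAMetadata: ...
class AscendTorchairMetadata: ...
class AscendMLATorchairMetadata: ...


# One alias instead of repeating the Union at every touched signature.
AttnMetadataT = Union[AscendMetadata, AscendMLAMetadata,
                      AscendTorchairMetadata, AscendMLATorchairMetadata]


def _build_attn_metadata() -> dict[str, AttnMetadataT]:
    # Stand-in for the runner method whose annotation the diff widens.
    return {}
```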