[2/N][refactor] torchair deepseek mla backend refactor (#2459)

### What this PR does / why we need it?
This PR moves the current unified MLA backend into the torchair folder and removes
torchair-related code from attention/mla_v1.py (reducing it from ~1.3k to ~0.9k lines).

 
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tested by running eager mode with the MLA backend, and torchair mode with the code prior to
[#2445](https://github.com/vllm-project/vllm-ascend/pull/2445).


- vLLM version: v0.10.0
- vLLM main:
f571ff8eb6

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2025-08-21 14:02:30 +08:00
committed by GitHub
parent 67a222c383
commit 0ca3f48c90
7 changed files with 2192 additions and 747 deletions

View File

@@ -92,6 +92,7 @@ from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
maybe_converting_weight_acl_format)
@@ -624,7 +625,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
) -> dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata]]:
AscendTorchairMetadata, AscendMLATorchairMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -736,7 +737,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.query_start_loc[num_reqs + 1:].fill_(-1)
attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata]] = {}
AscendTorchairMetadata,
AscendMLATorchairMetadata]] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -1000,8 +1002,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, np.ndarray, int,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata,
AscendMLATorchairMetadata], torch.Tensor, np.ndarray, int,
torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@@ -1466,7 +1468,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata],
AscendTorchairMetadata,
AscendMLATorchairMetadata],
aux_hidden_states: torch.Tensor = None,
) -> Optional[list[list[int]]]:
if not self.use_spec_decode:
@@ -2540,7 +2543,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata],
AscendTorchairMetadata,
AscendMLATorchairMetadata],
):
assert isinstance(self.drafter, MtpProposer)
next_token_ids: list[int] = []