[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)

### What this PR does / why we need it?

A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:

Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;

### Does this PR introduce _any_ user-facing change?

No change at user-facing.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
57c22e57f9

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-07-28 14:06:20 +08:00
committed by GitHub
parent e3a2443c3a
commit ba3dfbd59e
22 changed files with 629 additions and 347 deletions

View File

@@ -127,8 +127,6 @@ class AscendTorchairMetadata:
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
# max value of number of tokens across dp group
max_num_tokens_across_dp: int = 0
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
@@ -139,7 +137,7 @@ class AscendTorchairMetadata:
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None
with_prefill_across_dp: bool = False
decode: Optional[AscendDecodeMetadata] = None
@@ -178,8 +176,9 @@ class AscendAttentionTorchairMetadataBuilder:
return graph_block_tables[:num_seqs, :max_blocks]
def build_dummy(self, num_reqs: int,
num_actual_tokens: int) -> AscendTorchairMetadata:
def build_torchair_graph_dummy(
self, num_reqs: int,
num_actual_tokens: int) -> AscendTorchairMetadata:
device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks),
@@ -214,7 +213,6 @@ class AscendAttentionTorchairMetadataBuilder:
seq_lens=seq_lens,
slot_mapping=slot_mapping,
attn_state=AscendAttentionState.DecodeOnly,
max_num_tokens_across_dp=num_reqs,
decode=decode_metadata)
return attn_metadata
@@ -222,9 +220,7 @@ class AscendAttentionTorchairMetadataBuilder:
num_reqs,
num_actual_tokens,
max_query_len,
graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
graph_pad_size: int = -1):
device = self.runner.device
@@ -263,7 +259,6 @@ class AscendAttentionTorchairMetadataBuilder:
pad_value = 1
padded_seq_lens = seq_lens.tolist() + [pad_value
] * graph_pad_size
max_num_tokens_across_dp = len(padded_seq_lens)
seq_lens = torch.from_numpy(
np.array(padded_seq_lens).astype(np.int32))
@@ -303,9 +298,7 @@ class AscendAttentionTorchairMetadataBuilder:
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp)
attn_state=attn_state)
return attn_metadata