[Feature] model_runner refactor (#4764)

### What this PR does / why we need it?
Refactor the NPU model runner so it stays structurally close to vLLM's GPU model runner: per-step tensors such as `input_ids`, `positions`, `query_start_loc`, `seq_lens`, and `slot_mapping` move into paired host/device buffers accessed via `.cpu` / `.gpu`, the `aclgraph_dispatcher` attribute is renamed to `cudagraph_dispatcher` to match upstream, and `reserved_mc2_mask`, `cos`, and `sin` are no longer threaded through the MTP proposer's metadata.
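
The central change visible in the diff below is this buffer layout. A minimal sketch of the paired host/device pattern, assuming an interface along the lines of vLLM's `CpuGpuBuffer` (the class name, pinning, and copy helper here are illustrative, not this PR's code):

```python
import torch


class PairedBuffer:
    """Sketch of a paired host/device buffer: a pinned CPU tensor plus a
    device mirror, exposed as `.cpu` / `.gpu` so call sites pick the side
    they need instead of juggling `foo` and `foo_cpu` attributes."""

    def __init__(self, size: int, dtype: torch.dtype, device: str) -> None:
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=device != "cpu")
        self.gpu = self.cpu.to(device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Non-blocking H2D copy of the first n elements; returns the device view.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]
```

With this layout the proposer reads `self.runner.query_start_loc.gpu[...]` and `self.runner.seq_lens.cpu` rather than maintaining separate `query_start_loc` / `query_start_loc_cpu` attributes, which is what the hunks below show.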

### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
Author: zhenwenqi2024
Date: 2025-12-12 17:27:09 +08:00
Committed by: GitHub
Parent: 5b12c068f9
Commit: f708d919f8

10 changed files with 676 additions and 1815 deletions


@@ -256,11 +256,12 @@ class MtpProposer(Proposer):
             self.runner.input_batch.
             num_computed_tokens_cpu_tensor[:num_reqs])
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.runner.
-            query_start_loc_cpu[:num_reqs + 1],
-            seq_lens_cpu=self.runner.seq_lens_cpu,
-            seq_lens=self.runner.seq_lens[:num_reqs],
+            query_start_loc=self.runner.query_start_loc.gpu[:num_reqs +
+                                                            1],
+            query_start_loc_cpu=self.runner.query_start_loc.
+            cpu[:num_reqs + 1],
+            seq_lens_cpu=self.runner.seq_lens.cpu,
+            seq_lens=self.runner.seq_lens.gpu[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=num_tokens,
             num_input_tokens=num_tokens,
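For reference, `query_start_loc` is the running sum of per-request query lengths, which is why the slice takes `num_reqs + 1` entries (one boundary per fencepost); the hunk only changes which buffer is read, not these semantics. A small worked example with hypothetical values:

```python
import torch

# Per-request scheduled token counts for a 3-request batch (hypothetical).
num_scheduled = torch.tensor([5, 1, 3], dtype=torch.int32)

# query_start_loc holds cumulative boundaries, so it has num_reqs + 1 entries:
# request i's tokens occupy [query_start_loc[i], query_start_loc[i + 1]).
query_start_loc = torch.zeros(len(num_scheduled) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(num_scheduled, dim=0)
print(query_start_loc)  # tensor([0, 5, 6, 9], dtype=torch.int32)
```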
@@ -268,16 +269,14 @@ class MtpProposer(Proposer):
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
-            get_device_tensor()[:num_reqs],
+            get_device_tensor(),
             slot_mapping=self.runner.input_batch.block_table[0].
-            slot_mapping,
-            positions=self.runner.positions,
+            slot_mapping.gpu,
+            positions=self.runner.positions.gpu,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
-            cos=self.runner.cos,
-            sin=self.runner.sin,
         )
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
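Two details in the hunk above: `get_device_tensor()` is now passed unsliced (presumably leaving any per-request slicing to the metadata builder), and `slot_mapping` is read from the block table's paired buffer via `.gpu`. For context, a slot mapping assigns each scheduled token a flat slot in the paged KV cache; a sketch of that arithmetic with illustrative values (not this PR's code):

```python
import torch

block_size = 4
# Hypothetical block-table row for one request: logical block -> physical block.
block_table_row = torch.tensor([7, 2, 9])
# Token positions scheduled for this request.
positions = torch.tensor([0, 1, 2, 3, 4, 5])

# Flat KV slot = physical_block * block_size + offset within the block.
slot_mapping = (block_table_row[positions // block_size] * block_size
                + positions % block_size)
print(slot_mapping)  # tensor([28, 29, 30, 31,  8,  9])
```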
@@ -304,7 +303,6 @@ class MtpProposer(Proposer):
             num_tokens=num_tokens,
             with_prefill=with_prefill,
             num_tokens_across_dp=num_tokens_across_dp,
-            reserved_mc2_mask=self.runner.reserved_mc2_mask,
             moe_comm_type=moe_comm_type,
             in_profile_run=self.runner.in_profile_run,
             num_actual_tokens=0,
@@ -406,7 +404,8 @@ class MtpProposer(Proposer):
             else:
                 token_indices_to_sample = None
             # input_ids can be None for multimodal models.
-            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
+            target_token_ids = self.runner.input_ids.gpu[:
+                                                         num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
             target_hidden_states = hidden_states[:num_scheduled_tokens]
         else:
@@ -435,7 +434,7 @@ class MtpProposer(Proposer):
             target_positions = positions
             target_hidden_states = hidden_states
         else:
-            target_token_ids = self.runner.input_ids[token_indices]
+            target_token_ids = self.runner.input_ids.gpu[token_indices]
             target_positions = positions[token_indices]
             target_hidden_states = hidden_states[token_indices]
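The two hunks above keep the proposer's existing access patterns, now routed through the device side of the buffer: a contiguous slice when every scheduled token is a target, and an index gather when only a subset survives. A toy illustration (values hypothetical):

```python
import torch

input_ids = torch.arange(10)  # stand-in for self.runner.input_ids.gpu
num_scheduled_tokens = 6

# Dense case: the first num_scheduled_tokens entries are the targets.
target_dense = input_ids[:num_scheduled_tokens]

# Sparse case: a precomputed gather selects a subset of token positions.
token_indices = torch.tensor([0, 2, 3, 7])
target_sparse = input_ids[token_indices]
print(target_dense.tolist(), target_sparse.tolist())
# [0, 1, 2, 3, 4, 5] [0, 2, 3, 7]
```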
@@ -748,7 +747,7 @@ class MtpProposer(Proposer):
         uniform_decode = False
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         aclgraph_runtime_mode, batch_descriptor = \
-            self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
+            self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
         if self.use_async_scheduling:
             # there is synchronization between mtp steps when enabling aclgraph,
             # disable aclgraph when use async scheduling to avoid the
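The rename from `aclgraph_dispatcher` to `cudagraph_dispatcher` adopts the attribute name used by vLLM's GPU model runner; on Ascend the dispatched graphs are still ACL graphs. A minimal sketch of the dispatch decision such a component makes, with mode names and capture-size logic assumed for illustration (not vLLM's implementation):

```python
from dataclasses import dataclass
from enum import Enum


class GraphMode(Enum):
    NONE = 0        # run eagerly, no captured graph
    PIECEWISE = 1   # replay captured graph pieces around attention
    FULL = 2        # replay one full captured graph


@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool


def dispatch(num_tokens: int, uniform_decode: bool, has_lora: bool,
             captured_sizes=frozenset({1, 2, 4, 8, 16, 32})):
    """Pick a graph runtime mode for this batch (sketch, not vLLM's code)."""
    if has_lora or num_tokens not in captured_sizes:
        return GraphMode.NONE, None
    if uniform_decode:
        # A uniform decode batch can replay a fully captured graph.
        return GraphMode.FULL, BatchDescriptor(num_tokens, True)
    return GraphMode.PIECEWISE, BatchDescriptor(num_tokens, False)
```

The call in the hunk above then unpacks the returned pair as `aclgraph_runtime_mode, batch_descriptor`.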
@@ -781,7 +780,6 @@ class MtpProposer(Proposer):
             num_tokens=num_input_tokens,
             with_prefill=with_prefill,
             num_tokens_across_dp=num_tokens_across_dp,
-            reserved_mc2_mask=self.runner.reserved_mc2_mask,
             moe_comm_type=moe_comm_type,
             aclgraph_runtime_mode=aclgraph_runtime_mode,
             batch_descriptor=batch_descriptor,