[Feature] model_runner refactor (#4764)
### What this PR does / why we need it?
Refactor npu_modelrunner so that its structure stays close to gpu_modelrunner.
### Does this PR introduce _any_ user-facing change?
NO
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
This commit is contained in:
@@ -256,11 +256,12 @@ class MtpProposer(Proposer):
|
||||
self.runner.input_batch.
|
||||
num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.runner.
|
||||
query_start_loc_cpu[:num_reqs + 1],
|
||||
seq_lens_cpu=self.runner.seq_lens_cpu,
|
||||
seq_lens=self.runner.seq_lens[:num_reqs],
|
||||
query_start_loc=self.runner.query_start_loc.gpu[:num_reqs +
|
||||
1],
|
||||
query_start_loc_cpu=self.runner.query_start_loc.
|
||||
cpu[:num_reqs + 1],
|
||||
seq_lens_cpu=self.runner.seq_lens.cpu,
|
||||
seq_lens=self.runner.seq_lens.gpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
num_input_tokens=num_tokens,
|
||||
@@ -268,16 +269,14 @@ class MtpProposer(Proposer):
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor()[:num_reqs],
|
||||
get_device_tensor(),
|
||||
slot_mapping=self.runner.input_batch.block_table[0].
|
||||
slot_mapping,
|
||||
positions=self.runner.positions,
|
||||
slot_mapping.gpu,
|
||||
positions=self.runner.positions.gpu,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
cos=self.runner.cos,
|
||||
sin=self.runner.sin,
|
||||
)
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
@@ -304,7 +303,6 @@ class MtpProposer(Proposer):
|
||||
num_tokens=num_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_type=moe_comm_type,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=0,
|
||||
@@ -406,7 +404,8 @@ class MtpProposer(Proposer):
|
||||
else:
|
||||
token_indices_to_sample = None
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
||||
target_token_ids = self.runner.input_ids.gpu[:
|
||||
num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
else:
|
||||
@@ -435,7 +434,7 @@ class MtpProposer(Proposer):
|
||||
target_positions = positions
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
target_token_ids = self.runner.input_ids[token_indices]
|
||||
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
|
||||
@@ -748,7 +747,7 @@ class MtpProposer(Proposer):
|
||||
uniform_decode = False
|
||||
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
|
||||
self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
|
||||
if self.use_async_scheduling:
|
||||
# there is synchronization between mtp steps when enabling aclgraph,
|
||||
# disable aclgraph when use async scheduling to avoid the
|
||||
@@ -781,7 +780,6 @@ class MtpProposer(Proposer):
|
||||
num_tokens=num_input_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_type=moe_comm_type,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
|
||||
Reference in New Issue
Block a user