[MTP][V1] Adapt mtp with graph mode in v1. (#1023)
Adapts DeepSeek MTP to work with the TorchAir graph mode in v1.

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
@@ -203,8 +203,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Set up speculative decoding.
         self.use_spec_decode = False
+        self.spec_attn_mask = None
         if self.speculative_config:
             self.use_spec_decode = True
+            self.spec_attn_mask = torch.triu(torch.ones(2048,
+                                                        2048,
+                                                        dtype=torch.bool),
+                                             diagonal=1).to("npu")
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
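The new spec_attn_mask is a full causal mask built once at init and kept on the device, so speculative-decoding steps can slice it instead of rebuilding a mask per batch. A minimal sketch of the same idea; the 2048 size and the "npu" device come from the diff, while the slicing helper is hypothetical and only illustrates how such a cached mask is typically consumed:

    import torch

    device = "npu"  # as in the diff; requires torch_npu, use "cpu"/"cuda" elsewhere

    # True above the diagonal marks positions a query token must not attend to,
    # i.e. a causal mask in boolean form, built once and reused every step.
    spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool),
                                diagonal=1).to(device)

    def sliced_mask(query_len: int, seq_len: int) -> torch.Tensor:
        # Hypothetical helper: slice the cached mask for the current step
        # instead of allocating a fresh (query_len, seq_len) mask each time.
        return spec_attn_mask[:query_len, :seq_len]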
@@ -779,10 +784,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
+        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
+            num_valid_tokens[i] = num_tokens - \
+                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
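num_valid_tokens separates a request's accepted tokens from the draft tokens the scheduler attached for verification, so a request doing one real step plus drafts still looks like decode. A rough sketch of the same bookkeeping with made-up inputs (the dict shapes mirror scheduler_output, the values are illustrative):

    import numpy as np

    # Illustrative scheduler output: request "b" got 1 new token plus 2 draft tokens.
    num_scheduled = {"a": 1, "b": 3}
    spec_decode_tokens = {"b": [1001, 1002]}  # draft token ids to verify

    req_ids = ["a", "b"]
    num_scheduled_tokens = np.empty(len(req_ids), dtype=np.int32)
    num_valid_tokens = np.empty(len(req_ids), dtype=np.int32)
    for i, req_id in enumerate(req_ids):
        n = num_scheduled[req_id]
        num_scheduled_tokens[i] = n
        num_valid_tokens[i] = n - len(spec_decode_tokens.get(req_id, []))

    print(num_scheduled_tokens)  # [1 3]
    print(num_valid_tokens)      # [1 1] -> every request is a decode-style step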
@@ -838,11 +846,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             out=self.slot_mapping_np[:total_num_scheduled_tokens])

         ascend_config = get_ascend_config()
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+        # Speculative decoding.
+        elif np.all(num_valid_tokens == 1):
+            attn_state = AscendAttentionState.SpecDecoding
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
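The batch is now classified in priority order: prefill with no cache hits, plain decode (one token per request), speculative decoding (one accepted token per request plus drafts), then chunked prefill. A condensed sketch of that decision order, using a stand-in enum and ignoring the ascend_scheduler/chunked-prefill config check that guards the last branch in the real code:

    import numpy as np
    from enum import Enum, auto

    class AttnState(Enum):  # stand-in for AscendAttentionState
        PrefillNoCache = auto()
        DecodeOnly = auto()
        SpecDecoding = auto()
        ChunkedPrefill = auto()

    def classify(seq_lens, num_scheduled_tokens, num_valid_tokens):
        if np.array_equal(seq_lens, num_scheduled_tokens):
            return AttnState.PrefillNoCache   # every token is new, nothing cached
        if np.all(num_scheduled_tokens == 1):
            return AttnState.DecodeOnly       # plain decode, one token per request
        if np.all(num_valid_tokens == 1):
            return AttnState.SpecDecoding     # one accepted token + draft tokens
        return AttnState.ChunkedPrefill       # mixed prefill/decode (splitfuse)

    print(classify(np.array([5, 7]), np.array([3, 3]), np.array([1, 1])))
    # AttnState.SpecDecoding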
@@ -873,7 +886,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         seq_lens = self.seq_lens[:num_reqs]
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
-        with_prefill = attn_state != AscendAttentionState.DecodeOnly
+        with_prefill = attn_state not in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]

         if self.dp_size > 1:
             max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
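Widening this check is what routes MTP verification steps onto the decode path: a SpecDecoding batch no longer counts as prefill, so it can take the padded TorchAir graph branch below. A tiny before/after illustration with a stand-in enum:

    from enum import Enum, auto

    class AttnState(Enum):  # stand-in for AscendAttentionState
        DecodeOnly = auto()
        SpecDecoding = auto()
        ChunkedPrefill = auto()

    attn_state = AttnState.SpecDecoding

    old_with_prefill = attn_state != AttnState.DecodeOnly             # True
    new_with_prefill = attn_state not in (AttnState.DecodeOnly,
                                          AttnState.SpecDecoding)     # False

    # Previously a speculative-decoding batch looked like prefill and fell back
    # to eager mode; now it is eligible for the compiled graph path.
    print(old_with_prefill, new_with_prefill)  # True False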
@@ -883,14 +898,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Add graph_pad_size here
         if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
                                            and not with_prefill):
             batch_size = len(seq_lens)
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
-                    batch_size)
-            graph_pad_size = padded_batch_size - batch_size
+                    total_num_scheduled_tokens)
+            graph_pad_size = padded_batch_size - total_num_scheduled_tokens

             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

         if self.vllm_config.model_config.use_mla:
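With MTP, a "decode" step carries more than one token per request (the accepted token plus its drafts), so the graph padding is now computed from the scheduled token count instead of the request count. A rough sketch of that computation; select_padded_size and its bucket list are illustrative stand-ins for select_torchair_padded_batch_size, not the runner's actual sizes:

    def select_padded_size(num_tokens: int,
                           candidate_sizes=(8, 16, 32, 64, 128)) -> int:
        # Pick the smallest pre-compiled graph size that fits this many tokens.
        for size in candidate_sizes:
            if num_tokens <= size:
                return size
        return candidate_sizes[-1]

    # Example: 3 requests, each with 1 accepted token + 1 MTP draft token.
    total_num_scheduled_tokens = 3 * 2
    padded = select_padded_size(total_num_scheduled_tokens)
    graph_pad_size = padded - total_num_scheduled_tokens
    print(padded, graph_pad_size)  # 8 2 -> pad the batch with 2 dummy tokens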