feat: add mtp ut and fix some bugs (#2453)
### What this PR does / why we need it?
Add unit tests for MTP (multi-token prediction) mode and fix several related bugs.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tested by the unit tests added in this PR.
- vLLM version: v0.10.0
- vLLM main: 53415653ff
Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
```diff
@@ -374,18 +374,12 @@ class AscendMLAMetadataBuilder:
         decode_metadata = None
         if num_decodes > 0:
-            actual_seq_lengths_q = query_start_loc[1:].tolist()
+            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decode_tokens]
             input_positions = input_positions[:num_decode_tokens]
             block_table = block_table[:num_decode_tokens, ...]
             seq_lens_list = seq_lens.tolist()
-            # TODO(xyx): whether this block is necessary without torchair
-            # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
-            batch_size = slot_mapping.size(0)
-            if actual_seq_lengths_q[-1] != batch_size \
-                    and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                actual_seq_lengths_q[-1] = batch_size
 
             cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                 1).unsqueeze(2)
```
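The slice-bound change above is the core fix for mixed batches. A toy illustration (made-up numbers, not the builder itself), assuming `query_start_loc` is the usual cumulative-offset tensor with a leading zero and decode requests ordered first: slicing `[1:]` drags prefill offsets into the decode metadata, while `[1:num_decodes + 1]` keeps only the decode entries.

```python
import torch

# Hypothetical batch: 3 decode requests (1 token each) followed by 2 prefills.
query_start_loc = torch.tensor([0, 1, 2, 3, 8, 15])
num_decodes = 3

old = query_start_loc[1:].tolist()                 # [1, 2, 3, 8, 15] - prefill offsets leak in
new = query_start_loc[1:num_decodes + 1].tolist()  # [1, 2, 3]        - decode requests only
```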
```diff
@@ -215,4 +215,4 @@ class CustomDeepSeekMTP(DeepSeekMTP):
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, previous_hidden_states,
                                    inputs_embeds, spec_step_idx)
-        return hidden_states
+        return hidden_states
```
```diff
@@ -1178,7 +1178,7 @@ class AscendFusedMoE(FusedMoE):
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-        self.moe = FusedMoEConfig.make(
+        moe = FusedMoEConfig.make(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
@@ -1188,8 +1188,10 @@ class AscendFusedMoE(FusedMoE):
             in_dtype=params_dtype,
             quant_config=quant_config)
 
+        self.moe_config = moe
+
         if quant_config is None:
-            self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe)
+            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
 
```
```diff
@@ -102,7 +102,7 @@ class AscendQuantConfig(QuantizationConfig):
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return AscendUnquantizedFusedMoEMethod(layer.moe)
+                return AscendUnquantizedFusedMoEMethod(layer.moe_config)
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
         elif isinstance(layer, VocabParallelEmbedding):
```
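The three MoE hunks above form one rename: the layer builds its `FusedMoEConfig` into a local `moe`, exposes it as `self.moe_config`, and the quant config reads `layer.moe_config` for skipped layers. A reduced sketch of that handshake with stand-in classes (everything except the `moe_config` attribute name is hypothetical):

```python
from dataclasses import dataclass


@dataclass
class MoEConfigStub:
    """Stand-in for vllm's FusedMoEConfig."""
    num_experts: int
    experts_per_token: int


class MoELayerStub:
    """Stand-in for AscendFusedMoE: builds the config once, exposes it as moe_config."""

    def __init__(self, num_experts: int, top_k: int) -> None:
        moe = MoEConfigStub(num_experts=num_experts, experts_per_token=top_k)
        self.moe_config = moe  # read later by the quant config


def unquantized_method_config(layer: MoELayerStub) -> MoEConfigStub:
    """Mirrors the skipped-layer branch in AscendQuantConfig: read layer.moe_config."""
    return layer.moe_config


layer = MoELayerStub(num_experts=64, top_k=8)
assert unquantized_method_config(layer) is layer.moe_config
```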
```diff
@@ -492,17 +492,17 @@ class AscendMLATorchairMetadataBuilder:
         graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size != -1
         if num_decodes > 0:
-            actual_seq_lengths_q = query_start_loc[1:].tolist()
+            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decode_tokens]
             input_positions = input_positions[:num_decode_tokens]
             block_table = block_table[:num_decode_tokens, ...]
+            num_token_pad_size = 0
             if use_torchair_graph and common_attn_metadata.attn_state in [
                     AscendAttentionState.DecodeOnly,
                     AscendAttentionState.SpecDecoding
             ]:
                 num_reqs_pad_size = 0
-                num_token_pad_size = 0
                 if graph_pad_size != 0:
                     pad_value = 0
                     num_token_pad_size = graph_pad_size - num_decode_tokens
```
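The torchair builder gets the same slice-bound fix, and `num_token_pad_size = 0` now runs before the graph branch. A reduced sketch of why the default matters (tensor handling and the attention-state check omitted): the batch-size computation in the next hunk reads the variable on every decode path, so it must be bound even when graph padding is skipped.

```python
def padded_batch_size(num_decode_tokens: int, use_torchair_graph: bool,
                      graph_pad_size: int) -> int:
    num_token_pad_size = 0  # default for the non-graph / unpadded path
    if use_torchair_graph and graph_pad_size != 0:
        num_token_pad_size = graph_pad_size - num_decode_tokens
    return num_decode_tokens + num_token_pad_size


assert padded_batch_size(6, use_torchair_graph=False, graph_pad_size=-1) == 6
assert padded_batch_size(6, use_torchair_graph=True, graph_pad_size=8) == 8
```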
```diff
@@ -535,13 +535,14 @@ class AscendMLATorchairMetadataBuilder:
                     device=input_positions.device)
                 input_positions = torch.cat(
                     [input_positions, position_padding])
-                actual_seq_lengths_q = query_start_loc[1:].tolist(
-                ) + common_attn_metadata.actual_seq_lengths_q[
-                    num_reqs:num_reqs + num_reqs_pad_size]
+                actual_seq_lengths_q = (
+                    actual_seq_lengths_q + common_attn_metadata.
+                    actual_seq_lengths_q[num_reqs:num_reqs +
+                                         num_reqs_pad_size])
             else:
                 seq_lens_list = seq_lens.tolist()
             # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
-            batch_size = slot_mapping.size(0)
+            batch_size = num_decode_tokens + num_token_pad_size
             if actual_seq_lengths_q[-1] != batch_size \
                     and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 actual_seq_lengths_q[-1] = batch_size
```
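The in-code comment states the invariant: in the MTP torchair + PD spec-decoding case, the last element of `actual_seq_lengths_q` must equal `batch_size` (the number of tokens), and `batch_size` is now derived from the decode tokens plus graph padding rather than `slot_mapping.size(0)`. A toy walk-through of that adjustment with made-up numbers:

```python
num_decode_tokens = 6
num_token_pad_size = 2  # padding added for the captured torchair graph
batch_size = num_decode_tokens + num_token_pad_size

actual_seq_lengths_q = [2, 4, 6]   # cumulative query lengths per decode request
is_spec_decoding = True            # stands in for the SpecDecoding attn_state check

if actual_seq_lengths_q[-1] != batch_size and is_spec_decoding:
    actual_seq_lengths_q[-1] = batch_size  # stretch the last entry to cover the padding

assert actual_seq_lengths_q == [2, 4, 8]
```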
```diff
@@ -190,11 +190,6 @@ class MtpProposer:
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
-        if attn_metadata.prefill is not None:
-            attn_metadata.prefill.query_lens = query_lens.cpu()
-            attn_metadata.prefill.input_positions = target_positions
-            attn_metadata.prefill.seq_lens = seq_lens
-
         if not self.torchair_graph_enabled:
             # torch mode need to update num_tokens_across_dp
             # TODO: adapt enable_dbo later
```
```diff
@@ -213,6 +208,7 @@ class MtpProposer:
                 num_tokens=num_input_tokens,
                 with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
@@ -315,6 +311,7 @@ class MtpProposer:
                 num_tokens=num_tokens,
                 with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=0):
             if is_running_torchair:
```
```diff
@@ -47,9 +47,14 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (init_ascend_soc_version,
                                register_ascend_customop, sleep_mode_enabled,
-                               try_register_lib)
+                               try_register_lib, vllm_version_is)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+if not vllm_version_is("0.10.1.1"):
+    from vllm.v1.outputs import DraftTokenIds
+else:
+    DraftTokenIds = None
+
 
 class NPUWorker(WorkerBase):
```
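One note on the fallback branch, as a minimal sketch (assuming the module does not use postponed annotation evaluation): the `Optional[DraftTokenIds]` return annotation added to the worker below is evaluated when the method is defined, so the name has to be bound even on a vLLM release that does not ship the class; binding it to `None` keeps the module importable.

```python
from typing import Optional

DraftTokenIds = None  # fallback binding, as in the hunk above


class _AnnotationDemo:
    # Optional[None] collapses to NoneType, so the class definition still succeeds.
    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        return None
```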
```diff
@@ -343,3 +348,6 @@ class NPUWorker(WorkerBase):
 
     def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
         return self.model_runner.get_supported_tasks()
+
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        return self.model_runner.take_draft_token_ids()
```
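For completeness, a self-contained toy of the delegation added above (stub classes with hypothetical names, not vllm code): the worker keeps no speculative-decoding state of its own and simply forwards `take_draft_token_ids()` to its model runner, which may return `None` when no draft tokens are pending.

```python
from typing import Optional


class RunnerStub:
    def take_draft_token_ids(self) -> Optional[list]:
        return None  # nothing proposed this step


class WorkerStub:
    def __init__(self) -> None:
        self.model_runner = RunnerStub()

    def take_draft_token_ids(self) -> Optional[list]:
        # Same thin forwarding as NPUWorker.take_draft_token_ids above.
        return self.model_runner.take_draft_token_ids()


assert WorkerStub().take_draft_token_ids() is None
```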