[1/N][UT][v1 MTP] add basic v1 mtp features (#890)

### What this PR does / why we need it?
add basic v1 mtp features
please merge it after
https://github.com/vllm-project/vllm-ascend/pull/874 and
https://github.com/vllm-project/vllm-ascend/pull/844.

### Does this PR introduce _any_ user-facing change?
now, we supported basic v1 mtp, only supported tp only、eager mode and
k=1
we will continue to expand more scenarios.

### How was this patch tested?
local tested

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
This commit is contained in:
XWFAlone
2025-05-30 08:59:58 +08:00
committed by GitHub
parent 5903547d09
commit 3442fbdb23
7 changed files with 477 additions and 13 deletions

View File

@@ -59,8 +59,10 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
@@ -201,6 +203,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
elif self.speculative_config.method == "eagle":
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
elif self.speculative_config.method == 'deepseek_mtp':
self.drafter = MtpProposer(self.vllm_config, self)
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
@@ -216,6 +220,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -591,18 +601,43 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_builder_kwargs = {}
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
# Add graph_pad_size here
if self.enable_torchair_graph_mode:
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=None,
**extra_builder_kwargs,
)
if self.vllm_config.model_config.use_mla:
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_attn_metadata=common_attn_metadata,
common_prefix_len=None,
**extra_builder_kwargs,
)
else:
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=None,
**extra_builder_kwargs,
)
attn_metadata.num_input_tokens = num_input_tokens
# Prepare input_ids
@@ -836,6 +871,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
raise NotImplementedError(
"eagle method for spec decode doesn't work on vllm-ascend currently"
)
elif self.speculative_config.method == 'deepseek_mtp':
assert isinstance(self.drafter, MtpProposer)
spec_token_ids = self._generate_mtp_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, attn_metadata)
return spec_token_ids
@torch.inference_mode()
@@ -1126,7 +1167,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.model = get_model(vllm_config=self.vllm_config)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
self.drafter.load_model()
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
@@ -1333,3 +1374,73 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def _generate_mtp_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
):
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids