[FEATURE][MTP] Support MTP > 1 (#2708)

### What this PR does / why we need it?
[RFC: Support MTP > 1 for DeepSeek](https://github.com/vllm-project/vllm-ascend/issues/2745)

Verified configurations:
- [x] dp1 tp16
- [x] dp4 tp4
- [x] dp2 tp8
- [x] torchair graph

- vLLM version: v0.10.1.1
- vLLM main: c9f7081f9c

Signed-off-by: 1092626063 <1092626063@qq.com>
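For reference, a minimal usage sketch of the capability this adds, mirroring the updated correctness test below; the model path, `tensor_parallel_size`, and sampling values are illustrative only:

```python
from vllm import LLM, SamplingParams

# Sketch only: model path and parallel settings are illustrative, not prescriptive.
spec_llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=4,
    enable_expert_parallel=True,
    speculative_config={
        "method": "deepseek_mtp",
        # num_speculative_tokens > 1 is what this PR enables
        "num_speculative_tokens": 2,
    },
    enforce_eager=True,
    max_model_len=2000,
)
outputs = spec_llm.generate(["Hello, my name is"],
                            SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)
```

With `num_speculative_tokens` above 1, the MTP proposer now runs that many draft steps per verification round, as shown in the `MtpProposer` changes below.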
@@ -20,9 +20,10 @@ def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"


def test_mtp_correctness(
def mtp_correctness(
sampling_config: SamplingParams,
model_name: str,
num_speculative_tokens: int,
):
example_prompts = [
"Hello, my name is",

@@ -50,7 +51,7 @@ def test_mtp_correctness(
enable_expert_parallel=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
"num_speculative_tokens": num_speculative_tokens,
},
enforce_eager=True,
max_model_len=2000,

@@ -74,3 +75,18 @@ def test_mtp_correctness(
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm


def test_mtp1_correctness(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1)


def test_mtp2_correctness(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2)

@@ -841,7 +841,6 @@ class AscendMLAImpl(MLAAttentionImpl):
input_layout = "BNSD"

if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)

@@ -27,6 +27,8 @@ from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
TorchairCommonAttentionMetadata)
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable

PADDING_SLOT_ID = -1


class MtpProposer(Proposer):

@@ -40,6 +42,7 @@ class MtpProposer(Proposer):
self.vllm_config = vllm_config
self.device = device
self.runner = runner
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens

# persistent buffers for graph
self.input_ids = torch.zeros(self.runner.max_num_tokens,

@@ -57,6 +60,12 @@ class MtpProposer(Proposer):
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=self.runner.device,
dtype=torch.int32)

def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
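Editorial aside (not part of the diff): a minimal sketch of why the arange above needs one extra element. `query_start_loc` holds cumulative per-request token offsets, so a batch of `batch_size` requests needs `batch_size + 1` entries; the batch and lengths below are made up.

```python
import torch

# Hypothetical per-request query lengths for a batch of 3 requests.
query_lens = torch.tensor([4, 2, 3], dtype=torch.int32)

# query_start_loc has batch_size + 1 entries: a leading 0 plus one
# cumulative end offset per request.
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)
print(query_start_loc)  # tensor([0, 4, 6, 9], dtype=torch.int32)
```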
@@ -125,43 +134,47 @@ class MtpProposer(Proposer):
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
for _ in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
if with_prefill:
break

def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],

@@ -385,57 +398,142 @@ class MtpProposer(Proposer):
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill

with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])

num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices,
(0, max_num_reqs_across_dp - num_indices))

num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)

if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)

if step == 0:
draft_token_ids_list = [draft_token_ids]
else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices, (0, max_num_reqs_across_dp - num_indices))
draft_token_ids_list.append(draft_token_ids)

sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
# prepare next mtp inputs
# mtp>1: skip preparing next-step inputs for prefill, or after the last decode loop
if with_prefill and self.torchair_graph_enabled:
for _ in range(self.num_speculative_tokens - 1):
draft_token_ids_list.append(draft_token_ids)
if step == self.num_speculative_tokens - 1 or with_prefill:
break

# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if step == 0:
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
slot_mapping = attn_metadata.slot_mapping[last_token_indices]
attn_metadata.slot_mapping.fill_(-1)
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata.num_decode_tokens != 0:
attn_metadata.num_decode_tokens = batch_size
if is_running_torchair:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.query_lens = [1] * batch_size

input_ids = draft_token_ids_list[-1].int()
positions += 1

# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
attn_metadata.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
attn_metadata.seq_lens.device, non_blocking=True)
attn_metadata.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len_cpu, 1)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping += 1
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
attn_metadata.slot_mapping[:batch_size] = slot_mapping

if attn_metadata.prefill is not None:
attn_metadata.prefill.seq_lens = attn_metadata.seq_lens
attn_metadata.prefill.context_lens = attn_metadata.seq_lens
attn_metadata.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata.prefill.max_seq_lens += 1
attn_metadata.prefill.max_seq_lens = min(
attn_metadata.prefill.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata.decode is not None:
attn_metadata.decode.seq_lens = attn_metadata.seq_lens
attn_metadata.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata.decode.max_seq_lens += 1
attn_metadata.decode.max_seq_lens = min(
attn_metadata.decode.max_seq_lens,
self.runner.model_config.max_model_len)

# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids

def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
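Editorial aside (not part of the diff): a toy illustration of the shape produced at the end of the loop above, where the per-step argmax results are stacked into a [batch_size, k] tensor with k = num_speculative_tokens; the values below are made up.

```python
import torch

# Toy values: batch_size = 3, num_speculative_tokens = 2.
draft_token_ids_list = [
    torch.tensor([11, 22, 33]),  # drafts from step 0
    torch.tensor([12, 23, 34]),  # drafts from step 1
]
# mtp>1: [batch_size, k]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
print(draft_token_ids.shape)  # torch.Size([3, 2])
print(draft_token_ids)        # tensor([[11, 12], [22, 23], [33, 34]])
```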
@@ -511,4 +609,4 @@ class MtpProposer(Proposer):

global_indices_flat = global_indices[mask]
values_flat = values[mask]
out_ptr[global_indices_flat] = values_flat
out_ptr[global_indices_flat] = values_flat

@@ -1020,7 +1020,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
input_layout = "BNSD"

if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)

@@ -17,6 +17,7 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file

import math
import types
from typing import Optional

@@ -427,6 +428,11 @@ class NPUTorchairModelRunner(NPUModelRunner):
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
# MTP > 1: pad to the least common multiple (LCM) of graph_batch_size and tp_size,
# which adapts both multi-DP and the FIA operator
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
cur_graph_batch_size = (tp_size * graph_batch_size) \
// math.gcd(tp_size, graph_batch_size)
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
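Editorial aside (not part of the diff): a toy illustration of the two padding rules above. With num_speculative_tokens <= 1 the graph batch size is rounded up to a multiple of tp_size; with num_speculative_tokens > 1 it is padded to lcm(tp_size, graph_batch_size). The helper name and sample values are made up.

```python
import math

def pad_graph_batch_size(graph_batch_size: int, tp_size: int,
                         num_speculative_tokens: int) -> int:
    # Default rule: round up to the nearest multiple of tp_size.
    padded = (graph_batch_size + tp_size - 1) // tp_size * tp_size
    if num_speculative_tokens > 1:
        # MTP > 1 rule: pad to lcm(tp_size, graph_batch_size).
        padded = (tp_size * graph_batch_size) // math.gcd(tp_size, graph_batch_size)
    return padded

print(pad_graph_batch_size(6, 4, 1))  # 8
print(pad_graph_batch_size(6, 4, 2))  # 12
```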