[FEATURE][MTP] Support MTP > 1 (#2708)
### What this PR does / why we need it?
[RFC: Support MTP > 1 for
DeepSeek](https://github.com/vllm-project/vllm-ascend/issues/2745)
- [x] dp1 tp16
- [x] dp4 tp4
- [x] dp2 tp8
- [x] torchair graph
- vLLM version: v0.10.1.1
- vLLM main:
c9f7081f9c
Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
@@ -20,9 +20,10 @@ def model_name():
|
|||||||
return "wemaster/deepseek_mtp_main_random_bf16"
|
return "wemaster/deepseek_mtp_main_random_bf16"
|
||||||
|
|
||||||
|
|
||||||
def test_mtp_correctness(
|
def mtp_correctness(
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
|
num_speculative_tokens: int,
|
||||||
):
|
):
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@@ -50,7 +51,7 @@ def test_mtp_correctness(
|
|||||||
enable_expert_parallel=True,
|
enable_expert_parallel=True,
|
||||||
speculative_config={
|
speculative_config={
|
||||||
"method": "deepseek_mtp",
|
"method": "deepseek_mtp",
|
||||||
"num_speculative_tokens": 1,
|
"num_speculative_tokens": num_speculative_tokens,
|
||||||
},
|
},
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
max_model_len=2000,
|
max_model_len=2000,
|
||||||
@@ -74,3 +75,18 @@ def test_mtp_correctness(
|
|||||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches > int(0.66 * len(ref_outputs))
|
assert matches > int(0.66 * len(ref_outputs))
|
||||||
|
del spec_llm
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp1_correctness(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp2_correctness(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 2)
|
||||||
|
|||||||
@@ -841,7 +841,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
input_layout = "BNSD"
|
input_layout = "BNSD"
|
||||||
|
|
||||||
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||||
assert num_tokens % self.spec_token_num == 0
|
|
||||||
input_layout = "TND"
|
input_layout = "TND"
|
||||||
# [bs * q_seq_len, num_heads_per_rank, dim]
|
# [bs * q_seq_len, num_heads_per_rank, dim]
|
||||||
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
|
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
|
|||||||
TorchairCommonAttentionMetadata)
|
TorchairCommonAttentionMetadata)
|
||||||
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
||||||
|
|
||||||
|
PADDING_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
class MtpProposer(Proposer):
|
class MtpProposer(Proposer):
|
||||||
|
|
||||||
@@ -40,6 +42,7 @@ class MtpProposer(Proposer):
|
|||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.device = device
|
self.device = device
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
|
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
|
||||||
|
|
||||||
# persistent buffers for graph
|
# persistent buffers for graph
|
||||||
self.input_ids = torch.zeros(self.runner.max_num_tokens,
|
self.input_ids = torch.zeros(self.runner.max_num_tokens,
|
||||||
@@ -57,6 +60,12 @@ class MtpProposer(Proposer):
|
|||||||
self.torchair_compiled_models = {} # type: ignore
|
self.torchair_compiled_models = {} # type: ignore
|
||||||
self.torchair_graph_enabled = get_ascend_config(
|
self.torchair_graph_enabled = get_ascend_config(
|
||||||
).torchair_graph_config.enabled
|
).torchair_graph_config.enabled
|
||||||
|
# We need +1 here because the arange is used to set query_start_loc,
|
||||||
|
# which has one more element than batch_size.
|
||||||
|
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
|
||||||
|
1,
|
||||||
|
device=self.runner.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
def load_model(self, model) -> None:
|
def load_model(self, model) -> None:
|
||||||
loader = get_model_loader(self.vllm_config.load_config)
|
loader = get_model_loader(self.vllm_config.load_config)
|
||||||
@@ -125,6 +134,7 @@ class MtpProposer(Proposer):
|
|||||||
input_ids = self.input_ids[:num_tokens]
|
input_ids = self.input_ids[:num_tokens]
|
||||||
positions = self.positions[:num_tokens]
|
positions = self.positions[:num_tokens]
|
||||||
previous_hidden_states = self.hidden_states[:num_tokens]
|
previous_hidden_states = self.hidden_states[:num_tokens]
|
||||||
|
for _ in range(self.num_speculative_tokens):
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
@@ -140,7 +150,8 @@ class MtpProposer(Proposer):
|
|||||||
torch._dynamo.mark_static(positions)
|
torch._dynamo.mark_static(positions)
|
||||||
torch._dynamo.mark_static(previous_hidden_states)
|
torch._dynamo.mark_static(previous_hidden_states)
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
torch._dynamo.mark_static(
|
||||||
|
attn_metadata.decode.input_positions)
|
||||||
if hasattr(attn_metadata.decode, "sin"):
|
if hasattr(attn_metadata.decode, "sin"):
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||||
@@ -162,6 +173,8 @@ class MtpProposer(Proposer):
|
|||||||
self.model(input_ids=input_ids,
|
self.model(input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
previous_hidden_states=previous_hidden_states)
|
previous_hidden_states=previous_hidden_states)
|
||||||
|
if with_prefill:
|
||||||
|
break
|
||||||
|
|
||||||
def generate_token_ids(self,
|
def generate_token_ids(self,
|
||||||
valid_sampled_token_ids: list[list[int]],
|
valid_sampled_token_ids: list[list[int]],
|
||||||
@@ -385,6 +398,7 @@ class MtpProposer(Proposer):
|
|||||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||||
with_prefill = self.runner.with_prefill
|
with_prefill = self.runner.with_prefill
|
||||||
|
|
||||||
|
for step in range(self.num_speculative_tokens):
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
@@ -426,7 +440,8 @@ class MtpProposer(Proposer):
|
|||||||
else:
|
else:
|
||||||
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
|
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
|
||||||
last_token_indices = nn.functional.pad(
|
last_token_indices = nn.functional.pad(
|
||||||
last_token_indices, (0, max_num_reqs_across_dp - num_indices))
|
last_token_indices,
|
||||||
|
(0, max_num_reqs_across_dp - num_indices))
|
||||||
|
|
||||||
sample_hidden_states = hidden_states[last_token_indices]
|
sample_hidden_states = hidden_states[last_token_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
@@ -434,9 +449,92 @@ class MtpProposer(Proposer):
|
|||||||
logits = logits[:num_indices]
|
logits = logits[:num_indices]
|
||||||
draft_token_ids = logits.argmax(dim=-1)
|
draft_token_ids = logits.argmax(dim=-1)
|
||||||
|
|
||||||
|
if self.num_speculative_tokens == 1:
|
||||||
# [batch_size, 1]
|
# [batch_size, 1]
|
||||||
return draft_token_ids.view(-1, 1)
|
return draft_token_ids.view(-1, 1)
|
||||||
|
|
||||||
|
if step == 0:
|
||||||
|
draft_token_ids_list = [draft_token_ids]
|
||||||
|
else:
|
||||||
|
draft_token_ids_list.append(draft_token_ids)
|
||||||
|
|
||||||
|
# prepare next mtp inputs
|
||||||
|
# mtp>1: stop after the first step when prefilling, or after the last step when decoding
|
||||||
|
if with_prefill and self.torchair_graph_enabled:
|
||||||
|
for _ in range(self.num_speculative_tokens - 1):
|
||||||
|
draft_token_ids_list.append(draft_token_ids)
|
||||||
|
if step == self.num_speculative_tokens - 1 or with_prefill:
|
||||||
|
break
|
||||||
|
|
||||||
|
if step == 0:
|
||||||
|
positions = target_positions[last_token_indices]
|
||||||
|
hidden_states = hidden_states[last_token_indices]
|
||||||
|
slot_mapping = attn_metadata.slot_mapping[last_token_indices]
|
||||||
|
attn_metadata.slot_mapping.fill_(-1)
|
||||||
|
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||||
|
last_token_indices = self.arange[:batch_size]
|
||||||
|
if attn_metadata.num_decode_tokens != 0:
|
||||||
|
attn_metadata.num_decode_tokens = batch_size
|
||||||
|
if is_running_torchair:
|
||||||
|
attn_metadata.num_actual_tokens = batch_size
|
||||||
|
attn_metadata.query_lens = [1] * batch_size
|
||||||
|
|
||||||
|
input_ids = draft_token_ids_list[-1].int()
|
||||||
|
positions += 1
|
||||||
|
|
||||||
|
# NOTE(woosuk): We should handle the case where the draft model
|
||||||
|
# generates tokens beyond the max model length. Since it is complex
|
||||||
|
# to remove such requests from the batch, we keep them in the batch
|
||||||
|
# but adjust the position ids and slot mappings to avoid the
|
||||||
|
# out-of-range access during the model execution. The draft tokens
|
||||||
|
# generated with this adjustment should be ignored.
|
||||||
|
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
|
||||||
|
# Mask out the position ids that exceed the max model length.
|
||||||
|
# Otherwise, we may get out-of-range error in RoPE.
|
||||||
|
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||||
|
positions)
|
||||||
|
# Increment the sequence lengths.
|
||||||
|
attn_metadata.seq_lens[:batch_size] += 1
|
||||||
|
# For the requests that exceed the max model length, we set the
|
||||||
|
# sequence length to 1 to minimize their overheads in attention.
|
||||||
|
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
||||||
|
attn_metadata.seq_lens.device, non_blocking=True)
|
||||||
|
attn_metadata.seq_lens[:batch_size].masked_fill_(
|
||||||
|
exceeds_max_model_len_cpu, 1)
|
||||||
|
# Mask out the slot mappings that exceed the max model length.
|
||||||
|
# Otherwise, the KV cache will be inadvertently updated with the
|
||||||
|
# padding tokens.
|
||||||
|
slot_mapping += 1
|
||||||
|
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
|
||||||
|
|
||||||
|
# copy inputs to buffer for cudagraph
|
||||||
|
self.input_ids[:batch_size] = input_ids
|
||||||
|
self.positions[:batch_size] = clamped_positions
|
||||||
|
self.hidden_states[:batch_size] = hidden_states
|
||||||
|
attn_metadata.slot_mapping[:batch_size] = slot_mapping
|
||||||
|
|
||||||
|
if attn_metadata.prefill is not None:
|
||||||
|
attn_metadata.prefill.seq_lens = attn_metadata.seq_lens
|
||||||
|
attn_metadata.prefill.context_lens = attn_metadata.seq_lens
|
||||||
|
attn_metadata.prefill.input_positions = self.positions[:
|
||||||
|
num_input_tokens]
|
||||||
|
attn_metadata.prefill.max_seq_lens += 1
|
||||||
|
attn_metadata.prefill.max_seq_lens = min(
|
||||||
|
attn_metadata.prefill.max_seq_lens,
|
||||||
|
self.runner.model_config.max_model_len)
|
||||||
|
if attn_metadata.decode is not None:
|
||||||
|
attn_metadata.decode.seq_lens = attn_metadata.seq_lens
|
||||||
|
attn_metadata.decode.input_positions = self.positions[:
|
||||||
|
num_input_tokens]
|
||||||
|
attn_metadata.decode.max_seq_lens += 1
|
||||||
|
attn_metadata.decode.max_seq_lens = min(
|
||||||
|
attn_metadata.decode.max_seq_lens,
|
||||||
|
self.runner.model_config.max_model_len)
|
||||||
|
|
||||||
|
# mtp>1: [batch_size, k]
|
||||||
|
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||||
|
return draft_token_ids
|
||||||
|
|
||||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||||
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
|
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
|
||||||
-1]:
|
-1]:
|
||||||
|
|||||||
@@ -1020,7 +1020,6 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
input_layout = "BNSD"
|
input_layout = "BNSD"
|
||||||
|
|
||||||
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||||
assert num_tokens % self.spec_token_num == 0
|
|
||||||
input_layout = "TND"
|
input_layout = "TND"
|
||||||
# [bs * q_seq_len, num_heads_per_rank, dim]
|
# [bs * q_seq_len, num_heads_per_rank, dim]
|
||||||
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
|
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||||
# isort: skip_file
|
# isort: skip_file
|
||||||
|
|
||||||
|
import math
|
||||||
import types
|
import types
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -427,6 +428,11 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||||
cur_graph_batch_size = (graph_batch_size + tp_size -
|
cur_graph_batch_size = (graph_batch_size + tp_size -
|
||||||
1) // tp_size * tp_size
|
1) // tp_size * tp_size
|
||||||
|
# MTP > 1: calculate the LCM (Least Common Multiple) of graph_batch_size and tp_size,
|
||||||
|
# to accommodate both multi-DP and the FIA operator
|
||||||
|
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
|
||||||
|
cur_graph_batch_size = (tp_size * graph_batch_size) \
|
||||||
|
// math.gcd(tp_size, graph_batch_size)
|
||||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user