[1/N][UT][v1 MTP] add basic v1 mtp features (#890)
### What this PR does / why we need it? add basic v1 mtp features please merge it after https://github.com/vllm-project/vllm-ascend/pull/874 and https://github.com/vllm-project/vllm-ascend/pull/844. ### Does this PR introduce _any_ user-facing change? now, we supported basic v1 mtp, only supported tp only、eager mode and k=1 we will continue to expand more scenarios. ### How was this patch tested? local tested Signed-off-by: XWFAlone <xuewenfei2@huawei.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
This commit is contained in:
@@ -59,8 +59,10 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr # type: ignore[import-untyped]
|
||||
@@ -201,6 +203,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
elif self.speculative_config.method == "eagle":
|
||||
self.drafter = EagleProposer(self.vllm_config,
|
||||
self.device) # type: ignore
|
||||
elif self.speculative_config.method == 'deepseek_mtp':
|
||||
self.drafter = MtpProposer(self.vllm_config, self)
|
||||
else:
|
||||
raise ValueError("Unknown speculative decoding method: "
|
||||
f"{self.speculative_config.method}")
|
||||
@@ -216,6 +220,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.positions = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.seq_lens = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
# None in the first PP rank. The rest are set after load_model.
|
||||
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
||||
|
||||
@@ -591,18 +601,43 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
extra_builder_kwargs = {}
|
||||
|
||||
self.query_start_loc_np[0] = 0
|
||||
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||
self.query_start_loc[:num_reqs + 1].copy_(
|
||||
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
|
||||
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache
|
||||
self.seq_lens[num_reqs:].fill_(0)
|
||||
self.query_start_loc[num_reqs + 1:].fill_(-1)
|
||||
|
||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||
# Add graph_pad_size here
|
||||
if self.enable_torchair_graph_mode:
|
||||
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
|
||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=None,
|
||||
**extra_builder_kwargs,
|
||||
)
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
common_prefix_len=None,
|
||||
**extra_builder_kwargs,
|
||||
)
|
||||
else:
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=None,
|
||||
**extra_builder_kwargs,
|
||||
)
|
||||
attn_metadata.num_input_tokens = num_input_tokens
|
||||
|
||||
# Prepare input_ids
|
||||
@@ -836,6 +871,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
raise NotImplementedError(
|
||||
"eagle method for spec decode doesn't work on vllm-ascend currently"
|
||||
)
|
||||
elif self.speculative_config.method == 'deepseek_mtp':
|
||||
assert isinstance(self.drafter, MtpProposer)
|
||||
spec_token_ids = self._generate_mtp_token_ids(
|
||||
valid_sampled_token_ids, sampling_metadata, scheduler_output,
|
||||
spec_decode_metadata, positions, num_scheduled_tokens,
|
||||
hidden_states, attn_metadata)
|
||||
return spec_token_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -1126,7 +1167,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
if hasattr(self, "drafter"):
|
||||
logger.info("Loading drafter model...")
|
||||
self.drafter.load_model(self.model)
|
||||
self.drafter.load_model()
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(self.model,
|
||||
self.model_config,
|
||||
@@ -1333,3 +1374,73 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
draft_token_ids.append(drafter_output.tolist())
|
||||
return draft_token_ids
|
||||
|
||||
def _generate_mtp_token_ids(
|
||||
self,
|
||||
valid_sampled_token_ids: list[list[int]],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
spec_decode_metadata: SpecDecodeMetadata,
|
||||
positions: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: SpecDecodeMetadata,
|
||||
):
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
||||
if token_ids:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
else:
|
||||
# Partial prefill (rare case).
|
||||
# Get the next token id from the request state.
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
target_slot_mapping = attn_metadata.slot_mapping
|
||||
cu_num_tokens = attn_metadata.query_start_loc
|
||||
else:
|
||||
# TODO(woosuk): Refactor this.
|
||||
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
num_rejected_tokens = torch.tensor(
|
||||
num_rejected_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
||||
attn_metadata.query_start_loc,
|
||||
num_rejected_tokens,
|
||||
)
|
||||
target_token_ids = self.input_ids[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
|
||||
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=attn_metadata.block_tables,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
return spec_token_ids
|
||||
Reference in New Issue
Block a user