feat: add mtp ut and fix some bugs (#2453)
### What this PR does / why we need it?
Fix mtp mode ut
### Does this PR introduce _any_ user-facing change?
Nothing
### How was this patch tested?
This can be tested in the same way as a unit test.
- vLLM version: v0.10.0
- vLLM main:
53415653ff
Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
@@ -190,11 +190,6 @@ class MtpProposer:
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
if attn_metadata.prefill is not None:
|
||||
attn_metadata.prefill.query_lens = query_lens.cpu()
|
||||
attn_metadata.prefill.input_positions = target_positions
|
||||
attn_metadata.prefill.seq_lens = seq_lens
|
||||
|
||||
if not self.torchair_graph_enabled:
|
||||
# torch mode need to update num_tokens_across_dp
|
||||
# TODO: adapt enable_dbo later
|
||||
@@ -213,6 +208,7 @@ class MtpProposer:
|
||||
num_tokens=num_input_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=num_tokens):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
@@ -315,6 +311,7 @@ class MtpProposer:
|
||||
num_tokens=num_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=0):
|
||||
if is_running_torchair:
|
||||
|
||||
@@ -47,9 +47,14 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import (init_ascend_soc_version,
|
||||
register_ascend_customop, sleep_mode_enabled,
|
||||
try_register_lib)
|
||||
try_register_lib, vllm_version_is)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
if not vllm_version_is("0.10.1.1"):
|
||||
from vllm.v1.outputs import DraftTokenIds
|
||||
else:
|
||||
DraftTokenIds = None
|
||||
|
||||
|
||||
class NPUWorker(WorkerBase):
|
||||
|
||||
@@ -343,3 +348,6 @@ class NPUWorker(WorkerBase):
|
||||
|
||||
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
|
||||
return self.model_runner.get_supported_tasks()
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
return self.model_runner.take_draft_token_ids()
|
||||
Reference in New Issue
Block a user