[MTP] follow custom deepseek modeling changes to support graph mode (#636)

### What this PR does / why we need it?

The custom DeepSeek modeling was changed in https://github.com/vllm-project/vllm-ascend/pull/585 to support graph mode, so this PR applies the same changes to the custom deepseek_mtp modeling.

In addition, some modifications needed for k > 1 were not carried over by https://github.com/vllm-project/vllm-ascend/pull/429; they are added here.

To better maintain the MTP feature in the vllm-ascend repository, this PR also adds test cases for graph mode (torchair), but they are skipped for now because torchair cannot correctly clean up memory inside VllmRunner.

Test cases for MTP quantization weights are added as well; since the quantized test weights are not ready yet, these cases are skipped and will be enabled once the weights are available.
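
For reference, a hypothetical sketch of how such skipped cases might look (the test names, skip reasons, and bodies below are illustrative, not the exact tests added by this PR):

```python
import pytest


@pytest.mark.skip(reason="torchair cannot correctly clean up memory in VllmRunner")
def test_deepseek_mtp_graph_mode():
    # Would run MTP with additional_config={'enable_graph_mode': True}.
    ...


@pytest.mark.skip(reason="MTP quantized test weights are not ready yet")
def test_deepseek_mtp_quant_weights():
    # Would run MTP with a quantized DeepSeek checkpoint.
    ...
```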

Finally, https://github.com/vllm-project/vllm-ascend/pull/648 did not completely fix the sampler change issue (https://github.com/vllm-project/vllm-ascend/issues/660), so the relevant changes are added here as well.

### Does this PR introduce _any_ user-facing change?
You can now enable MTP for DeepSeek V3/R1 (float or quantized weights) in eager mode as follows:
```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    enforce_eager=True,
    trust_remote_code=True,
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```
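
For completeness, a minimal sketch of running generation on the `llm` object created above (the prompt and sampling settings are just examples):

```python
from vllm import SamplingParams

# Arbitrary example prompt and sampling settings.
prompts = ["The capital of France is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    # Accepted draft tokens from MTP are folded into the normal output;
    # no extra handling is needed on the caller side.
    print(output.outputs[0].text)
```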

Or enable MTP for DeepSeek V3/R1 (float or quantized weights) in graph mode (torchair):
```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    additional_config={
        'enable_graph_mode': True,
    },
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```
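
Note that the graph-mode example does not pass `enforce_eager=True`; instead, setting `enable_graph_mode` in `additional_config` lets vllm-ascend run the MTP path through torchair graph compilation rather than eager execution.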

Some additional notes:
1. k > 1 is now supported, so you can set `num_speculative_tokens > 1` if there is sufficient spare compute; see the sketch after these notes.
2. MTP is not supported in V1; we will support it once vLLM does in https://github.com/vllm-project/vllm/issues/13500.
3. If running MTP fails with a `segmentation fault`, follow the v0.7.3 patch https://github.com/vllm-project/vllm-ascend/pull/236, file `vllm_ascend/patch/patch_metrics.py`, method `__npu_async_metrics_collector_init__`.
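
As mentioned in note 1, here is a minimal sketch of a k > 1 configuration, reusing the illustrative test model and settings from the examples above (whether a larger k pays off depends on how much spare compute is available):

```python
from vllm import LLM

# Same illustrative model as above; setting num_speculative_tokens > 1
# enables k > 1 drafting.
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 2,
    },
    enforce_eager=True,
    trust_remote_code=True,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```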

### How was this patch tested?
Tested locally (passed) and verified by CI.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
Author: wemaster · 2025-04-28 21:18:53 +08:00 · committed by GitHub
parent be9e3e8545 · commit 54c0e63df7
15 changed files with 288 additions and 39 deletions


```diff
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
@@ -61,15 +62,19 @@ def sampler_output(
     else:
         # Here we run multi-step directly, with every step prepared
         # on the CPU.
-        # TODO: Remove this branch once DraftModelRunner supports TP>1
+        # TODO Remove this branch once DraftModelRunner supports TP>1
         # and other restrictions that are part of DraftModelRunner's
         # supports_gpu_multi_step(..)
+        if expanded_request.previous_hidden_states is not None:
+            self.worker.model_runner.return_hidden_states = True
         for _ in range(sample_len):
             model_output: List[SamplerOutput] = self.worker.execute_model(
                 execute_model_req=expanded_request)
             assert (len(model_output) == 1
                     ), "composing multistep workers not supported"
             model_output = model_output[0]
+            self._maybe_update_previous_hidden_states(model_output,
+                                                      expanded_request)
             self._append_new_tokens(model_output,
                                     expanded_request.seq_group_metadata_list,
@@ -84,4 +89,22 @@ def sampler_output(
     return filtered_model_outputs, True
 
 
+def set_include_gpu_probs_tensor(self) -> None:
+    # Need include_gpu_probs_tensor for MultiSteoWorker
+    if hasattr(self.model_runner.model, "sampler"):
+        self.model_runner.model.sampler.include_gpu_probs_tensor = True
+    if not vllm_version_is("0.8.4"):
+        self.model_runner.sampler.include_gpu_probs_tensor = True
+
+
+def set_should_modify_greedy_probs_inplace(self) -> None:
+    if hasattr(self.model_runner.model, "sampler"):
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True)
+    if not vllm_version_is("0.8.4"):
+        self.model_runner.sampler.should_modify_greedy_probs_inplace = True
+
+
 MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
+MultiStepWorker.set_include_gpu_probs_tensor = set_include_gpu_probs_tensor
+MultiStepWorker.set_should_modify_greedy_probs_inplace = set_should_modify_greedy_probs_inplace
```


```diff
@@ -93,7 +93,7 @@ def create_worker(
         proposer_worker = MultiStepWorker(**draft_worker_kwargs)
         if draft_model_config.hf_config.model_type == "deepseek_mtp":
-            num_spec_prefill_steps = num_speculative_tokens
+            num_spec_prefill_steps = draft_model_config.hf_config.n_predict
 
     proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
         proposer_worker, draft_tp, target_tp)
```