[MTP] follow custom deepseek modeling changes to support graph mode (#636)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? As custom deepseek modeling do some changes to support graph mode in https://github.com/vllm-project/vllm-ascend/pull/585, so i follow it to change custom deepseek_mtp modeling. And some modifications for k>1 were not carried over by the https://github.com/vllm-project/vllm-ascend/pull/429, now i add it. In order to better take care of the MTP feature in the vllm-ascend repository, I added cases related to graph mode(torchair), but i skip it since torchair can not correctly clean up memory in vllmrunner. Also i add some case for MTP quantization weights, but test weight is not ready, so i skip it and i will open it when test quant weights is ready. https://github.com/vllm-project/vllm-ascend/pull/648 did not completely fix the sample change(https://github.com/vllm-project/vllm-ascend/issues/660) issue, I added the relevant changes. ### Does this PR introduce _any_ user-facing change? now, u can use following method to use mtp in deepseek v3/r1 float or quant weights with eager mode. ```python llm = LLM( model="wemaster/deepseek_mtp_main_random_bf16", tensor_parallel_size=2, speculative_config={ "num_speculative_tokens": 1, }, enforce_eager=True, trust_remote_code=True, disable_log_stats=False, gpu_memory_utilization=0.8, max_model_len=64, ) ``` or use mtp in deepseek v3/r1 float or quant weights with graph mode(torchair) ```python llm = LLM( model="wemaster/deepseek_mtp_main_random_bf16", tensor_parallel_size=2, speculative_config={ "num_speculative_tokens": 1, }, trust_remote_code=True, additional_config={ 'enable_graph_mode': True, }, disable_log_stats=False, gpu_memory_utilization=0.8, max_model_len=64, ) ``` add notes: 1. now, we support k>1, so u can set num_speculative_tokens > 1 if there is sufficient redundant computing power; 2. MTP is not supported in V1, we will support it when vLLM does it in https://github.com/vllm-project/vllm/issues/13500. 3. if u run MTP failed by `segmentation fault`, u can follow v0.7.3 patch https://github.com/vllm-project/vllm-ascend/pull/236 file `vllm_ascend/patch/patch_metrics.py` method `__npu_async_metrics_collector_init__` ### How was this patch tested? local tested passed and test by CI Signed-off-by: mengwei805 <mengwei25@huawei.com>
This commit is contained in:
@@ -17,11 +17,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@@ -34,6 +35,7 @@ from vllm.model_executor.models.deepseek_mtp import (
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
|
||||
|
||||
@@ -69,6 +71,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
@@ -88,6 +92,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
@@ -125,14 +131,20 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
step_kv_cache = kv_caches[
|
||||
current_step_idx] if kv_caches is not None else None
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
step_kv_cache,
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
@@ -170,3 +182,19 @@ class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
prefix, "model"))
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user