[refactor] refactor deepseek-related files (#2849)
### What this PR does / why we need it?
This PR deletes ~2K lines of code about deepseek modeling. It falls back
CustomDeepseekV2 modules to original vllm implementations and adapts
some modifications in vllm about deepseek and moe.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E vllm serving with torchair graph mode and eager mode.
- vLLM version: v0.10.2
- vLLM main:
759ef49b15
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: yiz-liu <136800916+yiz-liu@users.noreply.github.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -23,7 +23,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@@ -33,12 +34,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class CustomDeepSeekShareHead(SharedHead):
|
||||
|
||||
@@ -65,6 +65,7 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -75,10 +76,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "shared_head"))
|
||||
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
|
||||
model_config,
|
||||
cache_config,
|
||||
quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config,
|
||||
prefix=prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -103,8 +102,6 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user