[Fix]: support deepseek-vl2-tiny model (#5552)

Co-authored-by: bppps <zouyu.zzx@alibaba-inc.com>
This commit is contained in:
ZXN
2025-04-26 17:52:53 +08:00
committed by GitHub
parent feda9b11b3
commit 04d0123fd9
6 changed files with 80 additions and 6 deletions

View File

@@ -382,8 +382,14 @@ class DeepseekModel(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
@@ -416,14 +422,18 @@ class DeepseekForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config)
def get_input_embeddings(self) -> nn.Embedding:
    """Expose the token-embedding layer of the wrapped language model."""
    embedding_layer = self.model.embed_tokens
    return embedding_layer
@torch.no_grad()
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    input_embeds: torch.Tensor = None,
) -> torch.Tensor:
    """Run the language model and map its hidden states to logits.

    Args:
        input_ids: Token ids for the batch.
        positions: Position indices matching ``input_ids``.
        forward_batch: Batch metadata consumed by the model and the
            logits processor (project type; schema not visible here).
        input_embeds: Optional precomputed embeddings. When provided,
            the underlying model uses them instead of looking up
            ``input_ids`` in its embedding table (multimodal path).

    Returns:
        The output of ``self.logits_processor`` for the batch.
    """
    # Single model invocation, forwarding input_embeds. The previous code
    # additionally ran the model once WITHOUT input_embeds and discarded
    # that result — a leftover of the old call site that doubled compute
    # and ignored injected embeddings on the first pass.
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
    return self.logits_processor(
        input_ids, hidden_states, self.lm_head, forward_batch
    )

View File

@@ -18,6 +18,7 @@ from sglang.srt.managers.mm_utils import (
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek import DeepseekForCausalLM
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
@@ -189,7 +190,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
# ----------- language model ------------
language_config = config.language_config
self.language_model = DeepseekV2ForCausalLM(language_config)
if language_config.use_mla:
self.language_model = DeepseekV2ForCausalLM(language_config)
else:
# deepseek-vl2-tiny forbids mla
self.language_model = DeepseekForCausalLM(language_config)
def _init_vision_module(
self, vision_config, quant_config: Optional[QuantizationConfig]