[Fix]: support deepseek-vl2-tiny model (#5552)
Co-authored-by: bppps <zouyu.zzx@alibaba-inc.com>
This commit is contained in:
@@ -382,8 +382,14 @@ class DeepseekModel(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
@@ -416,14 +422,18 @@ class DeepseekForCausalLM(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.model.embed_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from sglang.srt.managers.mm_utils import (
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek import DeepseekForCausalLM
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
||||
|
||||
|
||||
@@ -189,7 +190,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
|
||||
# ----------- language model ------------
|
||||
language_config = config.language_config
|
||||
self.language_model = DeepseekV2ForCausalLM(language_config)
|
||||
if language_config.use_mla:
|
||||
self.language_model = DeepseekV2ForCausalLM(language_config)
|
||||
else:
|
||||
# deepseek-vl2-tiny forbids mla
|
||||
self.language_model = DeepseekForCausalLM(language_config)
|
||||
|
||||
def _init_vision_module(
|
||||
self, vision_config, quant_config: Optional[QuantizationConfig]
|
||||
|
||||
Reference in New Issue
Block a user