[Feature] Support Deepseek-VL2 (#2798)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Yi Zhang <1109276519@qq.com>
This commit is contained in:
萝卜菜
2025-03-17 14:07:59 +08:00
committed by GitHub
parent 0212d2e288
commit d6d21640d3
13 changed files with 1259 additions and 2 deletions

View File

@@ -1021,6 +1021,7 @@ class DeepseekV2Model(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
# Gather
@@ -1035,7 +1036,11 @@ class DeepseekV2Model(nn.Module):
)
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
hidden_states = self.embed_tokens(input_ids)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
@@ -1076,8 +1081,10 @@ class DeepseekV2ForCausalLM(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.