Qwen2.5-VL eagle3 infer (#8801)

This commit is contained in:
Lzhang-hub
2025-09-08 11:44:34 +08:00
committed by GitHub
parent 7802586cab
commit 37d83c6e6d
9 changed files with 114 additions and 5 deletions

View File

@@ -454,6 +454,9 @@ class Qwen2ForCausalLM(nn.Module):
# For EAGLE3 support
self.capture_aux_hidden_states = False
# For EAGLE3 support
self.capture_aux_hidden_states = False
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the embeddings the wrapped model produces for ``input_ids``.

    This is a pure pass-through: the call is forwarded unchanged to
    ``self.model.get_input_embedding`` and its result is returned as-is.

    Args:
        input_ids: Token ids handed straight to the wrapped model.

    Returns:
        Whatever ``self.model.get_input_embedding(input_ids)`` returns.
    """
    backbone = self.model
    return backbone.get_input_embedding(input_ids)
@@ -481,6 +484,10 @@ class Qwen2ForCausalLM(nn.Module):
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(