Qwen2.5-VL eagle3 infer (#8801)
This commit is contained in:
@@ -454,6 +454,9 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
# For EAGLE3 support
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
# For EAGLE3 support
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embedding(input_ids)
|
||||
|
||||
@@ -481,6 +484,10 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
if self.capture_aux_hidden_states:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
aux_hidden_states = None
|
||||
if self.capture_aux_hidden_states:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
|
||||
Reference in New Issue
Block a user