[feat] Support EAGLE3 for Qwen (#7745)

Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
Co-authored-by: zyksir <zyksir@outlook.com>
This commit is contained in:
Ximingwang-09
2025-07-05 10:50:28 +08:00
committed by GitHub
parent af5647748a
commit 1964c325de
4 changed files with 81 additions and 6 deletions

View File

@@ -293,6 +293,9 @@ class Qwen2Model(nn.Module):
else:
self.norm = PPMissingLayer(return_tuple=True)
# For EAGLE3 support
self.layers_to_capture = []
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
if hasattr(self.config, "scale_emb"):
return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -321,7 +324,12 @@ class Qwen2Model(nn.Module):
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
if i in self.layers_to_capture:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
layer = self.layers[i]
hidden_states, residual = layer(
positions,
@@ -342,7 +350,11 @@ class Qwen2Model(nn.Module):
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should