[feat] Support EAGLE3 for Qwen (#7745)
Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
Co-authored-by: zyksir <zyksir@outlook.com>
This commit is contained in:
@@ -293,6 +293,9 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
             return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -321,7 +324,12 @@ class Qwen2Model(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -342,7 +350,11 @@ class Qwen2Model(nn.Module):
                     hidden_states = self.norm(hidden_states)
                 else:
                     hidden_states, _ = self.norm(hidden_states, residual)
-            return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
Reference in New Issue
Block a user