[feat] Support EAGLE3 for Qwen (#7745)
Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
Co-authored-by: zyksir <zyksir@outlook.com>
@@ -293,6 +293,9 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
             return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -321,7 +324,12 @@ class Qwen2Model(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -342,7 +350,11 @@ class Qwen2Model(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
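
Note on the Qwen2Model hunks above: the capture hook records the input of each selected layer (folding in any pending residual) rather than its output, and the forward pass returns a tuple only when something was captured. A minimal runnable sketch of that pattern, using hypothetical toy layers in place of the real decoder blocks:

    import torch
    from torch import nn

    # Hypothetical stand-ins for the decoder layers; illustration only.
    layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(6))
    layers_to_capture = [2, 4]

    hidden_states = torch.randn(1, 8)
    residual = None
    aux_hidden_states = []

    for i, layer in enumerate(layers):
        if i in layers_to_capture:
            # Record the layer input, re-adding the pending residual if present.
            aux_hidden_states.append(
                hidden_states + residual if residual is not None else hidden_states
            )
        residual = hidden_states
        hidden_states = layer(hidden_states)

    # Mirrors the patched return convention: a tuple only when capture is active.
    out = (hidden_states, aux_hidden_states) if aux_hidden_states else hidden_states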
@@ -440,6 +440,9 @@ class Qwen2MoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -459,6 +462,7 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         if forward_batch.can_run_tbo:
             hidden_states, residual = model_forward_maybe_tbo(
                 layers=self.layers,
@@ -471,6 +475,12 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(
+                        hidden_states + residual
+                        if residual is not None
+                        else hidden_states
+                    )
                 with get_global_expert_distribution_recorder().with_current_layer(i):
                     layer = self.layers[i]
                     hidden_states, residual = layer(
@@ -489,7 +499,11 @@ class Qwen2MoeModel(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
 
 class Qwen2MoeForCausalLM(nn.Module):
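
One observation on the Qwen2MoeModel hunks: the TBO fast path (model_forward_maybe_tbo) bypasses the per-layer loop, so nothing is appended to aux_hidden_states there and the forward falls back to returning a bare tensor. A hedged guard a caller could add if capture must be guaranteed — a hypothetical helper, not part of this change, assuming the model's usual (input_ids, positions, forward_batch) call signature:

    def forward_with_guaranteed_capture(model, input_ids, positions, forward_batch):
        # Hypothetical guard: EAGLE3 capture only runs on the non-TBO path, so
        # refuse the combination rather than silently losing the aux states.
        if model.layers_to_capture and forward_batch.can_run_tbo:
            raise RuntimeError("TBO fast path skips EAGLE3 aux hidden state capture")
        return model(input_ids, positions, forward_batch)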
@@ -2,7 +2,7 @@
 
 import logging
 from functools import partial
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -325,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -346,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -447,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen3ForCausalLM
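
For intuition, the default selection in set_eagle3_layers_to_capture picks one early, one middle, and one late capture point, and explicit ids are shifted by one — presumably because capture happens before layer i runs, so recording the input of layer k + 1 is equivalent to recording the output of layer k. A quick check; the 36-layer count is an example value only, not tied to any particular checkpoint:

    num_layers = 36  # example value only
    print([2, num_layers // 2, num_layers - 3])  # [2, 18, 33]

    # Callers pass output-layer ids; the +1 converts them to input-capture ids.
    layer_ids = [1, 17, 32]
    print([val + 1 for val in layer_ids])  # [2, 18, 33]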
@@ -18,7 +18,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -717,6 +717,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -735,9 +736,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -750,6 +755,24 @@ class Qwen3MoeForCausalLM(nn.Module):
     def end_layer(self):
         return self.model.end_layer
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
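
End to end, the wrapper classes expose the same two-step contract: enable capture via set_eagle3_layers_to_capture, then unpack the inner model's tuple return. A runnable toy mirror of that set/unpack flow — every name here is a stand-in, not the actual SGLang call site:

    import torch

    class TinyInner:
        """Toy stand-in for the patched inner model's dual return convention."""
        def __init__(self):
            self.layers_to_capture = []

        def __call__(self, x):
            aux = [x + i for i in self.layers_to_capture]  # fake captured tensors
            return (x, aux) if aux else x

    class TinyWrapper:
        """Mirrors Qwen3ForCausalLM's enable/unpack logic (assumption: simplified)."""
        def __init__(self, num_layers=36):
            self.model = TinyInner()
            self.num_layers = num_layers
            self.capture_aux_hidden_states = False

        def set_eagle3_layers_to_capture(self, layer_ids=None):
            self.capture_aux_hidden_states = True
            n = self.num_layers
            self.model.layers_to_capture = (
                [2, n // 2, n - 3] if layer_ids is None else [v + 1 for v in layer_ids]
            )

        def forward(self, x):
            hidden_states = self.model(x)
            aux_hidden_states = None
            if self.capture_aux_hidden_states:
                hidden_states, aux_hidden_states = hidden_states
            return hidden_states, aux_hidden_states

    w = TinyWrapper()
    w.set_eagle3_layers_to_capture()
    h, aux = w.forward(torch.zeros(1))
    print(len(aux))  # 3 captured tensors, which feed the EAGLE3 draft model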