[feat] Support EAGLE3 for Qwen2 (#9216)
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
# Modify details for the adaptation of Qwen2 model.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
else:
|
||||
# ranks other than the last rank will have a placeholder layer
|
||||
self.lm_head = PPMissingLayer()
|
||||
@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
# For EAGLE3 support
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of ``self.model.get_input_embedding`` for *input_ids*.

    This is a pure pass-through: the argument is handed to the wrapped
    model unchanged and its return value is returned unchanged.
    """
    embed = self.model.get_input_embedding
    return embed(input_ids)
|
||||
@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
input_embeds,
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
aux_hidden_states = None
|
||||
if self.capture_aux_hidden_states:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids,
|
||||
hidden_states,
|
||||
self.lm_head,
|
||||
forward_batch,
|
||||
aux_hidden_states,
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Forward *quantization_param_path* to ``self.model.load_kv_cache_scales``.

    Nothing is returned; whatever the wrapped model's loader does with the
    path is the entire effect of this call.
    """
    backbone = self.model
    backbone.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
    """Enable auxiliary hidden-state capture for EAGLE3 and choose the layers.

    Does nothing unless ``self.pp_group.is_last_rank`` is true. Otherwise it
    sets ``self.capture_aux_hidden_states`` to ``True`` and stores the list
    of capture points on ``self.model.layers_to_capture``.

    Args:
        layer_ids: Layer indices to capture. When ``None``, a default
            triple derived from ``self.config.num_hidden_layers`` is used;
            otherwise each supplied index is stored incremented by one.
    """
    if self.pp_group.is_last_rank:
        self.capture_aux_hidden_states = True
        if layer_ids is not None:
            # Caller-supplied ids are stored shifted by one.
            capture_points = [layer_id + 1 for layer_id in layer_ids]
        else:
            depth = self.config.num_hidden_layers
            # Specific layers for EAGLE3 support
            capture_points = [2, depth // 2, depth - 3]
        self.model.layers_to_capture = capture_points
|
||||
|
||||
|
||||
EntryClass = Qwen2ForCausalLM
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
# For EAGLE3 support
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -553,9 +555,12 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
input_embeds,
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
aux_hidden_states = None
|
||||
if self.capture_aux_hidden_states:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
if self.pp_group.is_last_rank:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
|
||||
)
|
||||
else:
|
||||
return hidden_states
|
||||
@@ -705,5 +710,20 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
num_groups=None,
|
||||
)
|
||||
|
||||
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
    """Switch on EAGLE3 auxiliary hidden-state capture and record the layers.

    The call is a no-op when ``self.pp_group.is_last_rank`` is false. When it
    is true, ``self.capture_aux_hidden_states`` becomes ``True`` and
    ``self.model.layers_to_capture`` is assigned the selected capture points.

    Args:
        layer_ids: Layer indices to capture. ``None`` selects a default
            triple computed from ``self.config.num_hidden_layers``; any
            other value has one added to each element before being stored.
    """
    if self.pp_group.is_last_rank:
        self.capture_aux_hidden_states = True
        if layer_ids is not None:
            # Caller-supplied ids are stored shifted by one.
            selected = [idx + 1 for idx in layer_ids]
        else:
            total_layers = self.config.num_hidden_layers
            # Specific layers for EAGLE3 support
            selected = [2, total_layers // 2, total_layers - 3]
        self.model.layers_to_capture = selected
|
||||
|
||||
|
||||
EntryClass = Qwen2MoeForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user