[Feature] Layer-wise Prefill (#7634)
Signed-off-by: jason-fxz <jason341132@qq.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1328,6 +1328,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.model_config.vocab_size,
|
self.model_config.vocab_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_for_split_prefill(self):
    """Prepare this batch like an extend batch, then tag it as split prefill.

    All preparation is delegated to ``prepare_for_extend``; the only
    difference is the forward mode recorded on the batch afterwards.
    """
    self.prepare_for_extend()
    # Override whatever mode the extend preparation left on the batch.
    self.forward_mode = ForwardMode.SPLIT_PREFILL
|
||||||
|
|
||||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||||
self.forward_mode = ForwardMode.MIXED
|
self.forward_mode = ForwardMode.MIXED
|
||||||
running_bs = running_batch.batch_size()
|
running_bs = running_batch.batch_size()
|
||||||
|
|||||||
@@ -68,6 +68,8 @@ class ForwardMode(IntEnum):
|
|||||||
MIXED = auto()
|
MIXED = auto()
|
||||||
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
|
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
|
||||||
IDLE = auto()
|
IDLE = auto()
|
||||||
|
# Split Prefill for PD multiplexing
|
||||||
|
SPLIT_PREFILL = auto()
|
||||||
|
|
||||||
# Used in speculative decoding: verify a batch in the target model.
|
# Used in speculative decoding: verify a batch in the target model.
|
||||||
TARGET_VERIFY = auto()
|
TARGET_VERIFY = auto()
|
||||||
@@ -95,6 +97,9 @@ class ForwardMode(IntEnum):
|
|||||||
def is_mixed(self):
|
def is_mixed(self):
|
||||||
return self == ForwardMode.MIXED
|
return self == ForwardMode.MIXED
|
||||||
|
|
||||||
|
def is_split_prefill(self):
    """Return True when this mode is ``ForwardMode.SPLIT_PREFILL``."""
    return self == ForwardMode.SPLIT_PREFILL
|
||||||
|
|
||||||
def is_idle(self):
|
def is_idle(self):
|
||||||
return self == ForwardMode.IDLE
|
return self == ForwardMode.IDLE
|
||||||
|
|
||||||
@@ -194,6 +199,14 @@ class ForwardBatch:
|
|||||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||||
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# For split prefill
# Intermediate values carried from one split-prefill call to the next.
# They default to None, so the annotations are spelled Optional.
hidden_states: Optional[torch.Tensor] = None
residual: Optional[torch.Tensor] = None
# Free-form per-model state. `any` is the builtin function, not a type,
# so the value type is spelled `object` instead.
model_specific_states: Optional[Dict[str, object]] = None
# Current split index of layer
split_index: int = 0
|
||||||
|
|
||||||
# For MLA chunked prefix cache used in chunked prefill
|
# For MLA chunked prefix cache used in chunked prefill
|
||||||
# Tell attention backend whether the kv cache needs to be attended in current pass
|
# Tell attention backend whether the kv cache needs to be attended in current pass
|
||||||
attn_attend_prefix_cache: Optional[bool] = None
|
attn_attend_prefix_cache: Optional[bool] = None
|
||||||
|
|||||||
@@ -1513,11 +1513,34 @@ class ModelRunner:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward_split_prefill(
    self,
    forward_batch: ForwardBatch,
    reinit_attn_backend: bool = False,
    forward_count: int = 1,
) -> Optional[LogitsProcessorOutput]:
    """Advance a split prefill by up to ``forward_count`` layers.

    The layer interval starts at ``forward_batch.split_index`` and is
    clamped to ``self.model_config.num_hidden_layers``. After the model
    call, ``forward_batch.split_index`` is moved to the end of the interval.

    Args:
        forward_batch: Batch whose ``split_index`` marks the next layer.
        reinit_attn_backend: Re-run ``init_forward_metadata`` even when this
            is not the first split of the batch.
        forward_count: Number of layers to run in this call; must be >= 1.

    Returns:
        Whatever ``self.model.forward_split_prefill`` returns for the
        interval. The model implementations in this change return ``None``
        until the interval reaches the last layer.

    Raises:
        ValueError: If ``forward_count`` is smaller than 1, which would
            otherwise leave ``split_index`` unchanged without any signal.
    """
    if forward_count < 1:
        raise ValueError(f"forward_count must be >= 1, got {forward_count}")

    start = forward_batch.split_index
    # Attention metadata is set up on the first split, or on request.
    if start == 0 or reinit_attn_backend:
        self.attn_backend.init_forward_metadata(forward_batch)

    end = min(start + forward_count, self.model_config.num_hidden_layers)
    ret = self.model.forward_split_prefill(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
        (start, end),
    )
    forward_batch.split_index = end
    return ret
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
skip_attn_backend_init: bool = False,
|
skip_attn_backend_init: bool = False,
|
||||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
|
reinit_attn_backend: bool = False,
|
||||||
|
split_forward_count: int = 1,
|
||||||
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
||||||
self.forward_pass_id += 1
|
self.forward_pass_id += 1
|
||||||
|
|
||||||
@@ -1526,7 +1549,11 @@ class ModelRunner:
|
|||||||
forward_batch,
|
forward_batch,
|
||||||
):
|
):
|
||||||
output = self._forward_raw(
|
output = self._forward_raw(
|
||||||
forward_batch, skip_attn_backend_init, pp_proxy_tensors
|
forward_batch,
|
||||||
|
skip_attn_backend_init,
|
||||||
|
pp_proxy_tensors,
|
||||||
|
reinit_attn_backend,
|
||||||
|
split_forward_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.eplb_manager is not None:
|
if self.eplb_manager is not None:
|
||||||
@@ -1539,6 +1566,8 @@ class ModelRunner:
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
skip_attn_backend_init: bool,
|
skip_attn_backend_init: bool,
|
||||||
pp_proxy_tensors: Optional[PPProxyTensors],
|
pp_proxy_tensors: Optional[PPProxyTensors],
|
||||||
|
reinit_attn_backend: bool = False,
|
||||||
|
split_forward_count: int = 1,
|
||||||
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
||||||
can_run_cuda_graph = bool(
|
can_run_cuda_graph = bool(
|
||||||
forward_batch.forward_mode.is_cuda_graph()
|
forward_batch.forward_mode.is_cuda_graph()
|
||||||
@@ -1559,6 +1588,12 @@ class ModelRunner:
|
|||||||
skip_attn_backend_init=skip_attn_backend_init,
|
skip_attn_backend_init=skip_attn_backend_init,
|
||||||
pp_proxy_tensors=pp_proxy_tensors,
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
|
elif forward_batch.forward_mode.is_split_prefill():
|
||||||
|
ret = self.forward_split_prefill(
|
||||||
|
forward_batch,
|
||||||
|
reinit_attn_backend=reinit_attn_backend,
|
||||||
|
forward_count=split_forward_count,
|
||||||
|
)
|
||||||
elif forward_batch.forward_mode.is_idle():
|
elif forward_batch.forward_mode.is_idle():
|
||||||
ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: torch.Tensor = None,
):
    """Run decoder layers ``[start, end)`` of a split prefill.

    Intermediate activations are kept on ``forward_batch`` (``hidden_states``
    and ``residual``) so that a later call can continue from layer ``end``.

    Args:
        input_ids: Token ids; embedded when ``start == 0`` and no
            ``input_embeds`` is given, and passed to the logits processor.
        positions: Forwarded unchanged to every decoder layer.
        forward_batch: Carries the intermediate state between calls.
        split_interval: Half-open, 0-based layer interval to execute.
        input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``. Not modified by this call.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    if start == 0:
        # Normalize the embedding by sqrt(hidden_size)
        scale = model.config.hidden_size**0.5
        if input_embeds is None:
            hidden_states = model.embed_tokens(input_ids)
            hidden_states *= scale
        else:
            # Scale out of place so the caller's tensor is left untouched.
            hidden_states = input_embeds * scale
        forward_batch.hidden_states = hidden_states
        # A new pass must not hand a residual left over from an earlier
        # pass on this batch object to the first layer.
        forward_batch.residual = None

    # decoder layer
    for i in range(start, end):
        forward_batch.hidden_states, forward_batch.residual = model.layers[i](
            positions,
            forward_batch.hidden_states,
            forward_batch,
            forward_batch.residual,
        )

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states, _ = model.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    # logits process
    return self.logits_processor(
        input_ids,
        forward_batch.hidden_states,
        model.embed_tokens,
        forward_batch,
    )
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
|
|||||||
@@ -381,6 +381,57 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: torch.Tensor = None,
):
    """Run decoder layers ``[start, end)`` of a split prefill.

    Intermediate activations are kept on ``forward_batch`` (``hidden_states``
    and ``residual``) so that a later call can continue from layer ``end``.

    Args:
        input_ids: Token ids; embedded when ``start == 0`` and no
            ``input_embeds`` is given, and passed to the logits processor.
        positions: Forwarded unchanged to every decoder layer.
        forward_batch: Carries the intermediate state between calls.
        split_interval: Half-open, 0-based layer interval to execute.
        input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``. Not modified by this call.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    if start == 0:
        # Normalize. The float16 dtype of the scalar is kept as written.
        normalizer = torch.tensor(
            model.config.hidden_size**0.5, dtype=torch.float16
        )
        if input_embeds is None:
            hidden_states = model.embed_tokens(input_ids)
            hidden_states *= normalizer
        else:
            # Scale out of place so the caller's tensor is left untouched.
            hidden_states = input_embeds * normalizer
        forward_batch.hidden_states = hidden_states
        # A new pass must not hand a residual left over from an earlier
        # pass on this batch object to the first layer.
        forward_batch.residual = None

    # decoder layer
    for i in range(start, end):
        forward_batch.hidden_states, forward_batch.residual = model.layers[i](
            positions,
            forward_batch.hidden_states,
            forward_batch,
            forward_batch.residual,
        )

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states, _ = model.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    # logits process
    return self.logits_processor(
        input_ids,
        forward_batch.hidden_states,
        model.embed_tokens,
        forward_batch,
    )
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
# return input_dim, output_dim
|
# return input_dim, output_dim
|
||||||
if module_name in ["q_proj", "qkv_proj"]:
|
if module_name in ["q_proj", "qkv_proj"]:
|
||||||
|
|||||||
@@ -647,6 +647,69 @@ class Gemma3ForCausalLM(PreTrainedModel):
|
|||||||
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: torch.Tensor = None,
):
    """Run decoder layers ``[start, end)`` of a split prefill.

    On the first split (``start == 0``) the embeddings, the positions and
    both rotary embeddings are computed and stored on ``forward_batch``
    (``hidden_states`` and ``model_specific_states``). Later splits read
    them back from there.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    # embed
    if start == 0:
        embeds = model.embed_tokens(input_ids) if input_embeds is None else input_embeds

        if positions.dim() == 1:
            positions = einops.rearrange(positions, "s -> 1 s")
        rope_global = model.rotary_emb(embeds, positions)
        rope_local = model.rotary_emb_local(embeds, positions)

        forward_batch.hidden_states = embeds
        forward_batch.model_specific_states = {
            "positions": positions,
            "position_embeddings_global": rope_global,
            "position_embeddings_local": rope_local,
        }

    # decoder layer
    for i in range(start, end):
        saved = forward_batch.model_specific_states
        outputs = model.layers[i](
            positions=saved["positions"],
            position_embeddings_global=saved["position_embeddings_global"],
            position_embeddings_local=saved["position_embeddings_local"],
            hidden_states=forward_batch.hidden_states,
            forward_batch=forward_batch,
        )
        # Only the first element of the layer output is carried forward.
        forward_batch.hidden_states = outputs[0]

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states = model.norm(forward_batch.hidden_states)
    # logits process
    return self.logits_processor(
        input_ids,
        forward_batch.hidden_states,
        model.embed_tokens,
        forward_batch,
    )
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
|
|||||||
@@ -480,6 +480,47 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: Optional[torch.Tensor] = None,
) -> Optional[LogitsProcessorOutput]:
    """Run decoder layers ``[start, end)`` of a split prefill.

    Intermediate activations are kept on ``forward_batch`` (``hidden_states``
    and ``residual``) so that a later call can continue from layer ``end``.

    Args:
        input_ids: Token ids; embedded when ``start == 0`` and no
            ``input_embeds`` is given, and passed to the logits processor.
        positions: Forwarded unchanged to every decoder layer.
        forward_batch: Carries the intermediate state between calls.
        split_interval: Half-open, 0-based layer interval to execute.
        input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    # embed
    if start == 0:
        if input_embeds is None:
            forward_batch.hidden_states = model.embed_tokens(input_ids)
        else:
            forward_batch.hidden_states = input_embeds
        # A new pass must not hand a residual left over from an earlier
        # pass on this batch object to the first layer.
        forward_batch.residual = None

    # decoder layer
    for i in range(start, end):
        forward_batch.hidden_states, forward_batch.residual = model.layers[i](
            positions,
            forward_batch.hidden_states,
            forward_batch,
            forward_batch.residual,
        )

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states, _ = model.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    # logits process
    return self.logits_processor(
        input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
    )
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_layer(self):
|
def start_layer(self):
|
||||||
return self.model.start_layer
|
return self.model.start_layer
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
|
||||||
|
|
||||||
|
import time
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
input_ids, hidden_states, self.lm_head, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
):
    """Run transformer blocks ``[start, end)`` of a split prefill.

    The running activations live in ``forward_batch.hidden_states`` so that
    a later call can continue from block ``end``.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    transformer = self.transformer

    # embed
    if start == 0:
        forward_batch.hidden_states = transformer.wte(input_ids)

    # decoder layer
    for idx in range(start, end):
        block = transformer.h[idx]
        forward_batch.hidden_states = block(
            positions, forward_batch.hidden_states, forward_batch
        )

    if end != transformer.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states = transformer.ln_f(forward_batch.hidden_states)
    # logits process
    return self.logits_processor(
        input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
    )
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
|
|||||||
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: torch.Tensor = None,
):
    """Run decoder layers ``[start, end)`` of a split prefill.

    Intermediate activations are kept on ``forward_batch`` (``hidden_states``
    and ``residual``) so that a later call can continue from layer ``end``.

    Args:
        input_ids: Token ids; embedded when ``start == 0`` and no
            ``input_embeds`` is given, and passed to the logits processor.
        positions: Forwarded unchanged to every decoder layer.
        forward_batch: Carries the intermediate state between calls.
        split_interval: Half-open, 0-based layer interval to execute.
        input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    # embed
    if start == 0:
        if input_embeds is None:
            forward_batch.hidden_states = model.embed_tokens(input_ids)
        else:
            forward_batch.hidden_states = input_embeds
        # A new pass must not hand a residual left over from an earlier
        # pass on this batch object to the first layer.
        forward_batch.residual = None

    # decoder layer
    for i in range(start, end):
        forward_batch.hidden_states, forward_batch.residual = model.layers[i](
            positions,
            forward_batch.hidden_states,
            forward_batch,
            forward_batch.residual,
        )

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states, _ = model.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    # logits process
    return self.logits_processor(
        input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
    )
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_layer(self):
|
def start_layer(self):
|
||||||
return self.model.start_layer
|
return self.model.start_layer
|
||||||
|
|||||||
@@ -406,6 +406,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.pp_group = get_pp_group()
|
self.pp_group = get_pp_group()
|
||||||
@@ -554,6 +555,49 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: Optional[torch.Tensor] = None,
):
    """Run decoder layers ``[start, end)`` of a split prefill.

    Intermediate activations are kept on ``forward_batch`` (``hidden_states``
    and ``residual``) so that a later call can continue from layer ``end``.
    Each layer runs inside the global expert distribution recorder's
    ``with_current_layer(i)`` context.

    Args:
        input_ids: Token ids; embedded when ``start == 0`` and no
            ``input_embeds`` is given, and passed to the logits processor.
        positions: Forwarded unchanged to every decoder layer.
        forward_batch: Carries the intermediate state between calls.
        split_interval: Half-open, 0-based layer interval to execute.
        input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    # embed
    if start == 0:
        if input_embeds is None:
            forward_batch.hidden_states = model.embed_tokens(input_ids)
        else:
            forward_batch.hidden_states = input_embeds
        # A new pass must not hand a residual left over from an earlier
        # pass on this batch object to the first layer.
        forward_batch.residual = None

    # decoder layer
    for i in range(start, end):
        with get_global_expert_distribution_recorder().with_current_layer(i):
            layer = model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states, _ = model.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    # logits process
    return self.logits_processor(
        input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
    )
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_layer(self):
|
def start_layer(self):
|
||||||
return self.model.start_layer
|
return self.model.start_layer
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# Adapted from qwen2.py
|
# Adapted from qwen2.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
@@ -367,6 +366,47 @@ class Qwen3ForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: Optional[torch.Tensor] = None,
):
    """Run decoder layers ``[start, end)`` of a split prefill.

    Intermediate activations are kept on ``forward_batch`` (``hidden_states``
    and ``residual``) so that a later call can continue from layer ``end``.

    Args:
        input_ids: Token ids; embedded when ``start == 0`` and no
            ``input_embeds`` is given, and passed to the logits processor.
        positions: Forwarded unchanged to every decoder layer.
        forward_batch: Carries the intermediate state between calls.
        split_interval: Half-open, 0-based layer interval to execute.
        input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    # embed
    if start == 0:
        if input_embeds is None:
            forward_batch.hidden_states = model.embed_tokens(input_ids)
        else:
            forward_batch.hidden_states = input_embeds
        # A new pass must not hand a residual left over from an earlier
        # pass on this batch object to the first layer.
        forward_batch.residual = None

    # decoder layer
    for i in range(start, end):
        forward_batch.hidden_states, forward_batch.residual = model.layers[i](
            positions,
            forward_batch.hidden_states,
            forward_batch,
            forward_batch.residual,
        )

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states, _ = model.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    # logits process
    return self.logits_processor(
        input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
    )
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_layer(self):
|
def start_layer(self):
|
||||||
return self.model.start_layer
|
return self.model.start_layer
|
||||||
|
|||||||
@@ -745,6 +745,49 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: torch.Tensor = None,
):
    """Run decoder layers ``[start, end)`` of a split prefill.

    Intermediate activations are kept on ``forward_batch`` (``hidden_states``
    and ``residual``) so that a later call can continue from layer ``end``.
    Each layer runs inside the global expert distribution recorder's
    ``with_current_layer(i)`` context.

    Args:
        input_ids: Token ids; embedded when ``start == 0`` and no
            ``input_embeds`` is given, and passed to the logits processor.
        positions: Forwarded unchanged to every decoder layer.
        forward_batch: Carries the intermediate state between calls.
        split_interval: Half-open, 0-based layer interval to execute.
        input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``.

    Returns:
        The logits processor output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    start, end = split_interval
    model = self.model

    # embed
    if start == 0:
        if input_embeds is None:
            forward_batch.hidden_states = model.embed_tokens(input_ids)
        else:
            forward_batch.hidden_states = input_embeds
        # A new pass must not hand a residual left over from an earlier
        # pass on this batch object to the first layer.
        forward_batch.residual = None

    # decoder layer
    for i in range(start, end):
        with get_global_expert_distribution_recorder().with_current_layer(i):
            layer = model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

    if end != model.config.num_hidden_layers:
        return None

    # norm
    forward_batch.hidden_states, _ = model.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    # logits process
    return self.logits_processor(
        input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
    )
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_layer(self):
|
def start_layer(self):
|
||||||
return self.model.start_layer
|
return self.model.start_layer
|
||||||
|
|||||||
@@ -500,6 +500,7 @@ class TboForwardBatchPreparer:
|
|||||||
"capture_hidden_mode",
|
"capture_hidden_mode",
|
||||||
"padded_static_len",
|
"padded_static_len",
|
||||||
"mrope_positions", # only used by qwen2-vl, thus not care
|
"mrope_positions", # only used by qwen2-vl, thus not care
|
||||||
|
"split_index", # for split prefill
|
||||||
]:
|
]:
|
||||||
output_dict[key] = getattr(batch, key)
|
output_dict[key] = getattr(batch, key)
|
||||||
if not batch.forward_mode.is_target_verify():
|
if not batch.forward_mode.is_target_verify():
|
||||||
|
|||||||
Reference in New Issue
Block a user