[Feature] Layer-wise Prefill (#7634)
Signed-off-by: jason-fxz <jason341132@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -1328,6 +1328,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.model_config.vocab_size,
         )

+    def prepare_for_split_prefill(self):
+        self.prepare_for_extend()
+        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
+        self.forward_mode = ForwardMode.SPLIT_PREFILL
+
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
         running_bs = running_batch.batch_size()
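Usage: `prepare_for_split_prefill` reuses the extend-path preparation and only swaps the forward mode, so a scheduler opts a prefill batch into layer-wise execution with a single call. A minimal sketch of the caller side, assuming a freshly formed `ScheduleBatch` named `batch` (the helper below is hypothetical glue, not part of this commit):

```python
# Hypothetical scheduler-side glue: opt a prefill batch into layer-wise
# (split) execution instead of the one-shot extend path.
def schedule_as_split_prefill(batch):  # batch: ScheduleBatch
    batch.prepare_for_split_prefill()  # extend prep + ForwardMode.SPLIT_PREFILL
    assert batch.forward_mode.is_split_prefill()
    return batch
```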
@@ -68,6 +68,8 @@ class ForwardMode(IntEnum):
     MIXED = auto()
     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
     IDLE = auto()
+    # Split Prefill for PD multiplexing
+    SPLIT_PREFILL = auto()

     # Used in speculative decoding: verify a batch in the target model.
     TARGET_VERIFY = auto()
@@ -95,6 +97,9 @@ class ForwardMode(IntEnum):
     def is_mixed(self):
         return self == ForwardMode.MIXED

+    def is_split_prefill(self):
+        return self == ForwardMode.SPLIT_PREFILL
+
     def is_idle(self):
         return self == ForwardMode.IDLE
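The predicate method keeps call sites free of raw enum comparisons, matching the existing `is_mixed` / `is_idle` style. A self-contained toy mirroring this pattern (the member list is abridged relative to the real enum):

```python
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    MIXED = auto()
    IDLE = auto()
    SPLIT_PREFILL = auto()  # layer-wise prefill, as added by this commit

    def is_split_prefill(self) -> bool:
        return self == ForwardMode.SPLIT_PREFILL

# Call sites branch on the predicate rather than comparing members directly.
assert ForwardMode.SPLIT_PREFILL.is_split_prefill()
assert not ForwardMode.DECODE.is_split_prefill()
```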
@@ -194,6 +199,14 @@ class ForwardBatch:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

+    # For split prefill
+    # intermediate values for split prefill
+    hidden_states: torch.Tensor = None
+    residual: torch.Tensor = None
+    model_specific_states: Dict[str, any] = None
+    # current split index of layer
+    split_index: int = 0
+
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether the kv cache needs to be attended in current pass
     attn_attend_prefix_cache: Optional[bool] = None
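These fields make `ForwardBatch` the carrier of state between layer slices: `hidden_states` and `residual` hold the activations where the previous slice stopped, `model_specific_states` holds per-model extras, and `split_index` is the resume cursor. A standalone toy, with all names hypothetical, showing why running a layer stack in slices is numerically identical to one full pass:

```python
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ToySplitState:
    # Mirrors the new ForwardBatch fields: activations plus a resume cursor.
    hidden_states: Optional[torch.Tensor] = None
    split_index: int = 0

class ToyLayerStack(nn.Module):
    def __init__(self, dim: int = 8, num_layers: int = 6):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    @torch.no_grad()
    def forward_split_prefill(self, x, state, split_interval: Tuple[int, int]):
        start, end = split_interval  # [start, end), as in the commit
        if start == 0:
            state.hidden_states = x  # "embedding" happens only on the first slice
        for i in range(start, end):
            state.hidden_states = torch.relu(self.layers[i](state.hidden_states))
        state.split_index = end
        # Only the final slice produces an output, like the real methods below.
        return state.hidden_states if end == len(self.layers) else None

torch.manual_seed(0)
model = ToyLayerStack()
x = torch.randn(2, 8)

full = model.forward_split_prefill(x, ToySplitState(), (0, 6))
state = ToySplitState()
assert model.forward_split_prefill(x, state, (0, 3)) is None
sliced = model.forward_split_prefill(x, state, (3, 6))
assert torch.allclose(full, sliced)  # two slices == one full pass
```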
@@ -1513,11 +1513,34 @@ class ModelRunner:
             **kwargs,
         )

+    def forward_split_prefill(
+        self,
+        forward_batch: ForwardBatch,
+        reinit_attn_backend: bool = False,
+        forward_count: int = 1,
+    ) -> LogitsProcessorOutput:
+        if forward_batch.split_index == 0 or reinit_attn_backend:
+            self.attn_backend.init_forward_metadata(forward_batch)
+        next_split_index = min(
+            forward_batch.split_index + forward_count,
+            self.model_config.num_hidden_layers,
+        )
+        ret = self.model.forward_split_prefill(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            (forward_batch.split_index, next_split_index),
+        )
+        forward_batch.split_index = next_split_index
+        return ret
+
     def forward(
         self,
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         self.forward_pass_id += 1

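Each call to `forward_split_prefill` advances the batch by at most `forward_count` layers, and attention metadata is initialized only on the first slice (or when `reinit_attn_backend` forces a rebuild). Only the slice that reaches the last layer returns logits; earlier slices return `None`. A hedged sketch of a driver loop built on the API above (the loop itself is hypothetical, not part of this commit):

```python
# Hypothetical driver: run prefill a few layers at a time, leaving room to
# interleave other work (e.g. decode batches) between slices.
def run_split_prefill_to_completion(model_runner, forward_batch, layers_per_step=4):
    num_layers = model_runner.model_config.num_hidden_layers
    logits = None
    while forward_batch.split_index < num_layers:
        logits = model_runner.forward_split_prefill(
            forward_batch, forward_count=layers_per_step
        )
        # A PD-multiplexing scheduler would yield to decode work here.
    return logits  # LogitsProcessorOutput from the final slice
```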
@@ -1526,7 +1549,11 @@ class ModelRunner:
             forward_batch,
         ):
             output = self._forward_raw(
-                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+                forward_batch,
+                skip_attn_backend_init,
+                pp_proxy_tensors,
+                reinit_attn_backend,
+                split_forward_count,
             )

         if self.eplb_manager is not None:
@@ -1539,6 +1566,8 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool,
         pp_proxy_tensors: Optional[PPProxyTensors],
+        reinit_attn_backend: bool = False,
+        split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
@@ -1559,6 +1588,12 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
+        elif forward_batch.forward_mode.is_split_prefill():
+            ret = self.forward_split_prefill(
+                forward_batch,
+                reinit_attn_backend=reinit_attn_backend,
+                forward_count=split_forward_count,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
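With the dispatch in place, a batch whose mode is `SPLIT_PREFILL` flows from `forward` through `_forward_raw` into `forward_split_prefill`. A sketch of one slice through the public entry point, assuming `model_runner` and `forward_batch` are already constructed:

```python
# One slice through the public API; `forward` threads both new knobs down to
# forward_split_prefill via _forward_raw.
output, can_run_cuda_graph = model_runner.forward(
    forward_batch,                # forward_batch.forward_mode == SPLIT_PREFILL
    reinit_attn_backend=False,    # True rebuilds attention metadata mid-pass
    split_forward_count=4,        # advance up to four layers in this slice
)
```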
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize the embedding by sqrt(hidden_size)
+            forward_batch.hidden_states *= self.model.config.hidden_size**0.5
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
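Gemma scales embeddings by sqrt(hidden_size), and in the split path this scaling lives inside the `start == 0` branch so it runs exactly once; a later slice must not rescale already-scaled activations. A tiny standalone illustration of the invariant (hypothetical sizes):

```python
import torch

hidden_size = 2048
emb = torch.randn(4, hidden_size)

# Applied once, on slice 0 ...
scaled_once = emb * hidden_size**0.5
# ... versus accidentally reapplied on a later slice.
scaled_twice = scaled_once * hidden_size**0.5
assert not torch.allclose(scaled_once, scaled_twice)
```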
@@ -381,6 +381,57 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize
+            normalizer = torch.tensor(
+                self.model.config.hidden_size**0.5, dtype=torch.float16
+            )
+            forward_batch.hidden_states *= normalizer
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "qkv_proj"]:
@@ -647,6 +647,69 @@ class Gemma3ForCausalLM(PreTrainedModel):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+
+            if positions.dim() == 1:
+                positions = einops.rearrange(positions, "s -> 1 s")
+            position_embeddings_global = self.model.rotary_emb(hidden_states, positions)
+            position_embeddings_local = self.model.rotary_emb_local(
+                hidden_states, positions
+            )
+
+            forward_batch.hidden_states = hidden_states
+            forward_batch.model_specific_states = {
+                "positions": positions,
+                "position_embeddings_global": position_embeddings_global,
+                "position_embeddings_local": position_embeddings_local,
+            }
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            layer_output = layer(
+                positions=forward_batch.model_specific_states["positions"],
+                position_embeddings_global=forward_batch.model_specific_states[
+                    "position_embeddings_global"
+                ],
+                position_embeddings_local=forward_batch.model_specific_states[
+                    "position_embeddings_local"
+                ],
+                hidden_states=forward_batch.hidden_states,
+                forward_batch=forward_batch,
+            )
+            forward_batch.hidden_states = layer_output[0]
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.model.norm(forward_batch.hidden_states)
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
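Gemma3 is the one model here that needs `model_specific_states`: its rotary embeddings are computed from the positions tensor on slice 0 and reused verbatim by every later slice. A standalone toy of that compute-once, reuse-later pattern (all names hypothetical):

```python
class BatchState:
    def __init__(self):
        self.model_specific_states = None

def run_slice(state, start, end, layers):
    if start == 0:
        state.model_specific_states = {"rope": "computed-once"}  # slice 0 only
    extras = state.model_specific_states  # later slices reuse the cached value
    return [layer(extras) for layer in layers[start:end]]

layers = [lambda extras: extras["rope"] for _ in range(4)]
state = BatchState()
run_slice(state, 0, 2, layers)
assert run_slice(state, 2, 4, layers) == ["computed-once", "computed-once"]
```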
@@ -480,6 +480,47 @@ class LlamaForCausalLM(nn.Module):
         else:
             return hidden_states

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ) -> Optional[LogitsProcessorOutput]:
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
@@ -15,6 +15,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1

+import time
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -286,6 +287,42 @@ class QWenLMHeadModel(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            forward_batch.hidden_states = self.transformer.wte(input_ids)
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.transformer.h[i]
+            forward_batch.hidden_states = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+            )
+
+        if end == self.transformer.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states = self.transformer.ln_f(
+                forward_batch.hidden_states
+            )
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
         else:
             return hidden_states

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
@@ -406,6 +406,7 @@ class Qwen2MoeModel(nn.Module):
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
+        self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.pp_group = get_pp_group()
@@ -554,6 +555,49 @@ class Qwen2MoeForCausalLM(nn.Module):
         else:
             return hidden_states

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
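The MoE variants keep the per-layer loop under `with_current_layer(i)` so expert-distribution recording attributes tokens to the right layer even when layers run in separate slices. A self-contained toy of a layer-scoped recorder used the same way (hypothetical class, not sglang's API):

```python
from contextlib import contextmanager

class ToyRecorder:
    def __init__(self):
        self.current_layer = None
        self.seen = []

    @contextmanager
    def with_current_layer(self, layer_idx):
        # Scope all work inside the block to one layer index.
        self.current_layer = layer_idx
        try:
            yield
        finally:
            self.current_layer = None

recorder = ToyRecorder()
for i in range(2, 5):  # a slice covering layers [2, 5)
    with recorder.with_current_layer(i):
        recorder.seen.append(recorder.current_layer)
assert recorder.seen == [2, 3, 4]
```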
@@ -1,5 +1,4 @@
 # Adapted from qwen2.py
-
 import logging
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
@@ -367,6 +366,47 @@ class Qwen3ForCausalLM(nn.Module):
         else:
             return hidden_states

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
@@ -745,6 +745,49 @@ class Qwen3MoeForCausalLM(nn.Module):
         else:
             return hidden_states

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+        # decoder layer
+        for i in range(start, end):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.model.layers[i]
+                forward_batch.hidden_states, forward_batch.residual = layer(
+                    positions,
+                    forward_batch.hidden_states,
+                    forward_batch,
+                    forward_batch.residual,
+                )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
     @property
     def start_layer(self):
         return self.model.start_layer
@@ -500,6 +500,7 @@ class TboForwardBatchPreparer:
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
+            "split_index",  # for split prefill
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
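This one-line change matters because TBO (two-batch overlap) copies a parent batch's fields into child micro-batches; without `split_index` in the whitelist, a partially prefilled batch would restart at layer 0 after the split. A minimal sketch of the copy contract (field list abridged, names from the diff):

```python
# Child micro-batches must inherit the resume cursor along with the other
# whitelisted scalar fields when a batch is split for two-batch overlap.
def copy_whitelisted_fields(batch, output_dict,
                            keys=("padded_static_len", "split_index")):
    for key in keys:
        output_dict[key] = getattr(batch, key)
```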