[Feature] Layer-wise Prefill (#7634)
Signed-off-by: jason-fxz <jason341132@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -481,6 +481,47 @@ class Qwen2ForCausalLM(nn.Module):
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ):
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds
        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer
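The new forward_split_prefill runs the prefill forward pass over a half-open interval [start, end) of decoder layers, stashing the intermediate activations on forward_batch.hidden_states / forward_batch.residual so prefill can be resumed on a later call. Below is a minimal sketch (not part of this diff) of how a caller might drive it chunk by chunk; model, input_ids, positions, and forward_batch are assumed to be prepared as for a normal prefill, and LAYERS_PER_STEP is an illustrative choice:

    # Hypothetical driver loop: run prefill a few decoder layers at a time,
    # letting the scheduler interleave other work between chunks.
    LAYERS_PER_STEP = 8
    num_layers = model.model.config.num_hidden_layers

    result = None
    for start in range(0, num_layers, LAYERS_PER_STEP):
        end = min(start + LAYERS_PER_STEP, num_layers)
        # Intermediate chunks return None; the final chunk (end == num_layers)
        # applies the final norm and the logits processor and returns the output.
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )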