Input_embeds support (#2052)

Rin Intachuen authored on 2024-11-25 19:35:04 -05:00, committed by GitHub
parent 1f76fc6e3f
commit 1aea19f64b
9 changed files with 204 additions and 15 deletions

@@ -130,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None
 
+    # For input embeddings
+    input_embeds: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
@@ -231,6 +234,7 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )
         if ret.global_num_tokens is not None:

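The two hunks above declare the new field on ForwardBatch and copy it through the batch constructor. A minimal, self-contained sketch of the kind of tensor the field is meant to carry (the vocabulary size, hidden size, and token ids are illustrative assumptions, not values from this commit):

import torch

# Precompute embeddings for a 4-token prompt with a toy embedding table.
embed_tokens = torch.nn.Embedding(32000, 4096)
input_ids = torch.tensor([101, 2009, 2003, 102])
input_embeds = embed_tokens(input_ids)  # shape: (4, 4096)

# Any float dtype is acceptable here: ModelRunner.forward_extend
# (next hunk) casts the tensor to bfloat16 before the model call.
print(input_embeds.shape, input_embeds.dtype)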

@@ -606,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
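For the keyword call in the new else branch to succeed, a model's forward must accept an input_embeds argument and bypass its token-embedding lookup when one is supplied. A self-contained sketch of that contract (ToyModel, its sizes, and its attribute names are hypothetical, not code from this commit):

import torch
from typing import Optional

class ToyModel(torch.nn.Module):
    # Minimal stand-in for a generation model that honors input_embeds.
    def __init__(self, vocab_size: int = 32000, hidden_size: int = 4096):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch=None,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Bypass the lookup when precomputed embeddings are provided,
        # mirroring the dispatch in forward_extend above.
        if input_embeds is not None:
            hidden_states = input_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)
        return hidden_states  # downstream layers would consume this

model = ToyModel()
ids, pos = torch.tensor([1, 2, 3]), torch.arange(3)
out = model.forward(ids, pos, None, input_embeds=torch.randn(3, 4096).bfloat16())
assert out.dtype == torch.bfloat16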