Input_embeds support (#2052)
This commit is contained in:
@@ -130,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None

+    # For input embeddings
+    input_embeds: Optional[torch.tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
@@ -231,6 +234,7 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )

         if ret.global_num_tokens is not None:
||||
@@ -606,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
||||
Reference in New Issue
Block a user