Input_embeds support (#2052)

This commit is contained in:
Rin Intachuen
2024-11-25 19:35:04 -05:00
committed by GitHub
parent 1f76fc6e3f
commit 1aea19f64b
9 changed files with 204 additions and 15 deletions

View File

@@ -606,9 +606,17 @@ class ModelRunner:
def forward_extend(self, forward_batch: ForwardBatch):
self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
if forward_batch.input_embeds is None:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
else:
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
input_embeds=forward_batch.input_embeds.bfloat16(),
)
else:
# Only embedding models have get_embedding parameter
return self.model.forward(