Input_embeds support (#2052)
This commit is contained in:
@@ -606,9 +606,17 @@ class ModelRunner:
|
||||
def forward_extend(self, forward_batch: ForwardBatch):
|
||||
self.attn_backend.init_forward_metadata(forward_batch)
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
if forward_batch.input_embeds is None:
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
forward_batch,
|
||||
input_embeds=forward_batch.input_embeds.bfloat16(),
|
||||
)
|
||||
else:
|
||||
# Only embedding models have get_embedding parameter
|
||||
return self.model.forward(
|
||||
|
||||
Reference in New Issue
Block a user