Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Chayenne
2024-08-26 01:29:12 +08:00
committed by GitHub
parent 66e7dcaf70
commit 30b4f771b0
15 changed files with 167 additions and 55 deletions

View File

@@ -204,7 +204,7 @@ class ModelRunner:
else None
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
logger.info(
@@ -522,9 +522,18 @@ class ModelRunner:
batch,
forward_mode=ForwardMode.EXTEND,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
if self.is_generation:
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
else:
# Only embedding models have get_embedding parameter
return self.model.forward(
batch.input_ids,
input_metadata.positions,
input_metadata,
get_embedding=True,
)
@torch.inference_mode()
def forward_extend_multi_modal(self, batch: ScheduleBatch):