Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Chayenne
2024-08-26 01:29:12 +08:00
committed by GitHub
parent 66e7dcaf70
commit 30b4f771b0
15 changed files with 167 additions and 55 deletions

View File

@@ -14,7 +14,7 @@ limitations under the License.
"""
import json
import multiprocessing
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union
@@ -63,37 +63,35 @@ class HFRunner:
self,
model_path,
torch_dtype,
is_generation_model,
is_generation,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
self.is_generation = is_generation
self.model_proc = multiprocessing.Process(
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()
self.model_proc = mp.Process(
target=self.start_model_process,
args=(
self.in_queue,
self.out_queue,
model_path,
torch_dtype,
is_generation_model,
),
)
self.model_proc.start()
def start_model_process(
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
):
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
torch_dtype=torch_dtype,
)
self.is_generation_model = is_generation_model
if self.is_generation_model:
if self.is_generation:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
low_cpu_mem_usage=True,
).cuda()
else:
@@ -107,7 +105,7 @@ class HFRunner:
while True:
prompts, max_new_tokens = in_queue.get()
if prompts is not None:
if self.is_generation_model:
if self.is_generation:
output_strs = []
prefill_logprobs = []
for p in prompts:
@@ -171,17 +169,19 @@ class SRTRunner:
self,
model_path,
torch_dtype,
is_generation_model,
is_generation,
tp_size=1,
port=5157,
):
self.is_generation_model = is_generation_model
self.is_generation = is_generation
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.7,
trust_remote_code=False,
is_embedding=not self.is_generation,
)
def forward(
@@ -189,7 +189,7 @@ class SRTRunner:
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
):
if self.is_generation_model:
if self.is_generation:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []