Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -14,7 +14,7 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
@@ -63,37 +63,35 @@ class HFRunner:
|
||||
self,
|
||||
model_path,
|
||||
torch_dtype,
|
||||
is_generation_model,
|
||||
is_generation,
|
||||
):
|
||||
self.in_queue = multiprocessing.Queue()
|
||||
self.out_queue = multiprocessing.Queue()
|
||||
self.is_generation = is_generation
|
||||
|
||||
self.model_proc = multiprocessing.Process(
|
||||
self.in_queue = mp.Queue()
|
||||
self.out_queue = mp.Queue()
|
||||
|
||||
self.model_proc = mp.Process(
|
||||
target=self.start_model_process,
|
||||
args=(
|
||||
self.in_queue,
|
||||
self.out_queue,
|
||||
model_path,
|
||||
torch_dtype,
|
||||
is_generation_model,
|
||||
),
|
||||
)
|
||||
self.model_proc.start()
|
||||
|
||||
def start_model_process(
|
||||
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
|
||||
):
|
||||
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
self.is_generation_model = is_generation_model
|
||||
|
||||
if self.is_generation_model:
|
||||
if self.is_generation:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=False,
|
||||
low_cpu_mem_usage=True,
|
||||
).cuda()
|
||||
else:
|
||||
@@ -107,7 +105,7 @@ class HFRunner:
|
||||
while True:
|
||||
prompts, max_new_tokens = in_queue.get()
|
||||
if prompts is not None:
|
||||
if self.is_generation_model:
|
||||
if self.is_generation:
|
||||
output_strs = []
|
||||
prefill_logprobs = []
|
||||
for p in prompts:
|
||||
@@ -171,17 +169,19 @@ class SRTRunner:
|
||||
self,
|
||||
model_path,
|
||||
torch_dtype,
|
||||
is_generation_model,
|
||||
is_generation,
|
||||
tp_size=1,
|
||||
port=5157,
|
||||
):
|
||||
self.is_generation_model = is_generation_model
|
||||
self.is_generation = is_generation
|
||||
self.runtime = Runtime(
|
||||
model_path=model_path,
|
||||
tp_size=tp_size,
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
port=port,
|
||||
mem_fraction_static=0.7,
|
||||
trust_remote_code=False,
|
||||
is_embedding=not self.is_generation,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -189,7 +189,7 @@ class SRTRunner:
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=8,
|
||||
):
|
||||
if self.is_generation_model:
|
||||
if self.is_generation:
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
|
||||
Reference in New Issue
Block a user