Add e5-mistral embedding model - step 3/3 (#988)

Ying Sheng
2024-08-08 16:31:19 -07:00
committed by GitHub
parent 9f662501a3
commit e040a2450b
14 changed files with 474 additions and 241 deletions


@@ -25,7 +25,11 @@ import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
@@ -66,6 +70,18 @@ class DetokenizerManager:
async def handle_loop(self):
while True:
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
if isinstance(recv_obj, BatchEmbeddingOut):
self.send_to_tokenizer.send_pyobj(
BatchEmbeddingOut(
rids=recv_obj.rids,
embeddings=recv_obj.embeddings,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
)
)
continue
assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)


@@ -143,6 +143,7 @@ class Req:
# Logprobs
self.return_logprob = False
self.embedding = None
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None


@@ -21,7 +21,7 @@ import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
import numpy as np
import transformers
@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_multimodal_model, load_image
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -85,6 +88,7 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
self.is_generation = is_generation_model(self.hf_config.architectures)
if server_args.context_length is not None:
self.context_len = server_args.context_length
@@ -133,7 +137,9 @@ class TokenizerManager:
image_data, aspect_ratio, grid_pinpoints, self.processor
)
async def generate_request(self, obj: GenerateReqInput, request=None):
async def generate_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
):
if self.to_create_loop:
self.create_handle_loop()
@@ -144,6 +150,8 @@ class TokenizerManager:
async for response in self._handle_single_request(obj, request):
yield response
else:
if isinstance(obj, EmbeddingReqInput):
raise NotImplementedError("Please send only one prompt in each request")
if obj.stream:
raise ValueError("Do not support stream for batch mode.")
@@ -151,39 +159,47 @@ class TokenizerManager:
yield response
async def _handle_single_request(
self, obj, request, index=None, is_cache_for_prefill=False
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request,
index=None,
is_cache_for_prefill=False,
):
if not is_cache_for_prefill: # The normal case with a single prompt
not_use_index = index is None
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
input_ids = (
self.tokenizer.encode(input_text)
if obj.input_ids is None
else obj.input_ids
)
if not not_use_index and obj.input_ids:
input_ids = obj.input_ids[index]
if obj.input_ids is None:
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
self._validate_input_length(input_ids)
sampling_params = self._get_sampling_params(
obj.sampling_params if not_use_index else obj.sampling_params[index]
)
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
)
if self.is_generation:
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len
if not_use_index
else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num
if not_use_index
else obj.top_logprobs_num[index]
)
else: # A prefill request to cache the common prompt for parallel sampling
assert self.is_generation
if obj.text is not None:
if isinstance(obj.text, list):
input_text = obj.text[index]
@@ -213,19 +229,28 @@ class TokenizerManager:
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
)
if self.is_generation:
tokenized_obj = TokenizedGenerateReqInput(
rid,
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
obj.stream,
)
else: # is embedding
tokenized_obj = TokenizedEmbeddingReqInput(
rid,
input_text,
input_ids,
sampling_params,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
@@ -368,7 +393,7 @@ class TokenizerManager:
self,
event: asyncio.Event,
state: ReqState,
obj: GenerateReqInput,
obj: Union[GenerateReqInput, EmbeddingReqInput],
rid: str,
request,
):
@@ -381,12 +406,15 @@ class TokenizerManager:
raise ValueError(f"Abort request {rid}")
continue
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
if self.is_generation:
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
else: # isinstance(obj, EmbeddingReqInput)
out = state.out_list[-1]
# Log requests
if self.server_args.log_requests and state.finished:
@@ -459,8 +487,10 @@ class TokenizerManager:
async def handle_loop(self):
while True:
recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj()
assert isinstance(recv_obj, BatchStrOut)
recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
await self.recv_from_detokenizer.recv_pyobj()
)
assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
@@ -468,10 +498,17 @@ class TokenizerManager:
continue
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()


@@ -20,7 +20,7 @@ import multiprocessing
import pickle
import time
import warnings
from typing import List, Optional
from typing import List, Optional, Union
import torch
import torch.distributed as dist
@@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
FlushCacheReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
@@ -205,7 +207,9 @@ class ModelTpServer:
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
if isinstance(
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
self.handle_generate_request(recv_req)
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
@@ -297,41 +301,42 @@ class ModelTpServer:
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
(
req.origin_input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
req.sampling_params = recv_req.sampling_params
if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
(
req.origin_input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
)
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -340,14 +345,17 @@ class ModelTpServer:
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
if self.model_runner.is_generation:
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
self.waiting_queue.append(req)
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
@@ -439,47 +447,68 @@ class ModelTpServer:
self.model_config.vocab_size, self.int_token_logit_bias
)
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu
if output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
# Move logprobs to cpu
if output.next_token_logprobs is not None:
output.next_token_logprobs = output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
# Check finish conditions
pt = 0
for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
next_token_ids = next_token_ids.tolist()
else:
self.tree_cache.cache_unfinished_req(req)
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
# Check finish conditions
pt = 0
for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
pt += req.extend_input_len
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
pt += req.extend_input_len
else:
assert batch.extend_num_tokens != 0
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
embeddings = output.embeddings.tolist()
# Check finish conditions
for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i]
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req:
# Inflight request would get a new req idx
self.req_to_token_pool.free(req.req_pool_idx)
self.handle_finished_requests(batch)
@@ -596,15 +625,19 @@ class ModelTpServer:
def handle_finished_requests(self, batch: ScheduleBatch):
output_rids = []
output_vids = []
decoded_texts = []
output_read_ids = []
output_read_offsets = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
if self.model_runner.is_generation:
output_vids = []
decoded_texts = []
output_read_ids = []
output_read_offsets = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
else: # for embedding model
output_embeddings = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if not req.finished() and req is not self.current_inflight_req:
unfinished_indices.append(i)
@@ -619,56 +652,73 @@ class ModelTpServer:
)
):
output_rids.append(req.rid)
output_vids.append(req.vid)
decoded_texts.append(req.decoded_text)
read_ids, read_offset = req.init_incremental_detokenize()
output_read_ids.append(read_ids)
output_read_offsets.append(read_offset)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finished_reason),
}
if req.return_logprob:
(
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
output_finished_reason.append(req.finished_reason)
if self.model_runner.is_generation:
output_vids.append(req.vid)
decoded_texts.append(req.decoded_text)
read_ids, read_offset = req.init_incremental_detokenize()
output_read_ids.append(read_ids)
output_read_offsets.append(read_offset)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finished_reason),
}
if req.return_logprob:
(
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
else: # for embedding model
output_embeddings.append(req.embedding)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
}
output_meta_info.append(meta_info)
# Send to detokenizer
if output_rids:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
output_vids,
decoded_texts,
output_read_ids,
output_read_offsets,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished_reason,
if self.model_runner.is_generation:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
output_vids,
decoded_texts,
output_read_ids,
output_read_offsets,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished_reason,
)
)
else: # for embedding model
self.out_pyobjs.append(
BatchEmbeddingOut(
output_rids,
output_embeddings,
output_meta_info,
output_finished_reason,
)
)
)
# Remove finished reqs: update batch tensors
batch.filter_batch(unfinished_indices)


@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
is_generation_model,
is_llama3_405b_fp8,
is_multimodal_model,
monkey_patch_vllm_dummy_weight_loader,
@@ -132,8 +133,10 @@ class ModelRunner:
self.init_cublas()
self.init_flashinfer()
# Capture cuda graphs
self.init_cuda_graphs()
if self.is_generation:
# FIXME: Currently, cuda graphs only capture decode steps, which only exist in causal models
# Capture cuda graphs
self.init_cuda_graphs()
def load_model(self):
logger.info(
@@ -184,6 +187,10 @@ class ModelRunner:
scheduler_config=None,
cache_config=None,
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures
)
logger.info(
f"[gpu={self.gpu_id}] Load weight end. "
f"type={type(self.model).__name__}, "
@@ -406,8 +413,10 @@ def import_model_classes():
entry, list
): # To support multiple model classes in one module
for tmp in entry:
assert tmp.__name__ not in model_arch_name_to_cls
model_arch_name_to_cls[tmp.__name__] = tmp
else:
assert entry.__name__ not in model_arch_name_to_cls
model_arch_name_to_cls[entry.__name__] = entry
# compat: some models such as chatglm have an incorrect class set in config.json
@@ -417,6 +426,7 @@ def import_model_classes():
):
for remap in module.EntryClassRemapping:
if isinstance(remap, tuple) and len(remap) == 2:
assert remap[0] not in model_arch_name_to_cls
model_arch_name_to_cls[remap[0]] = remap[1]
return model_arch_name_to_cls


@@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module):
EntryClass = LlamaEmbeddingModel
# compat: e5-mistral model.config class == MistralModel
EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]


@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
@@ -97,6 +97,7 @@ async def health() -> Response:
async def get_model_info():
result = {
"model_path": tokenizer_manager.model_path,
"is_generation": tokenizer_manager.is_generation,
}
return result
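
The new is_generation field lets clients pick the right endpoint before sending traffic, mirroring the warmup logic later in this commit. A minimal client-side sketch, assuming the handler is mounted at /get_model_info and the server listens on http://localhost:30000 (both the route path and the URL are assumptions not shown in this hunk):

import requests

# Route path and server URL are assumptions for illustration.
info = requests.get("http://localhost:30000/get_model_info").json()
endpoint = "/generate" if info["is_generation"] else "/encode"
print(info["model_path"], "->", endpoint)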
@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(tokenizer_manager, raw_request)
@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
except (AssertionError, requests.exceptions.RequestException) as e:
last_traceback = get_exception_traceback()
pass
model_info = res.json()
if not success:
if pipe_finish_writer is not None:
@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
sys.exit(1)
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
max_new_tokens = 8 if model_info["is_generation"] else 0
try:
for _ in range(server_args.dp_size):
res = requests.post(
url + "/generate",
url + request_name,
json={
"text": "The capital city of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
"max_new_tokens": max_new_tokens,
},
},
headers=headers,
@@ -529,5 +548,18 @@ class Runtime:
)
return json.dumps(response.json())
def encode(
self,
prompt: str,
):
json_data = {
"text": prompt,
}
response = requests.post(
self.url + "/encode",
json=json_data,
)
return json.dumps(response.json())
def __del__(self):
self.shutdown()
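
For in-process use, the new Runtime.encode wraps the same route. A minimal sketch, assuming default server arguments; the model id is illustrative:

import json
from sglang.srt.server import Runtime

# Model id is illustrative; any supported embedding model works here.
runtime = Runtime(model_path="intfloat/e5-mistral-7b-instruct")
embedding = json.loads(runtime.encode("The capital of France is"))["embedding"]
print(len(embedding))
runtime.shutdown()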


@@ -223,6 +223,15 @@ def is_multimodal_model(model):
raise ValueError("unrecognized type")
def is_generation_model(model_architectures):
if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
):
return False
return True
def decode_video_base64(video_base64):
from PIL import Image
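
The is_generation_model helper above only special-cases the two embedding architectures; everything else is treated as a causal generation model. For example:

from sglang.srt.utils import is_generation_model

# e5-mistral checkpoints report "MistralModel" in config.json, so they are
# routed down the embedding path; ordinary causal LMs stay on /generate.
assert not is_generation_model(["MistralModel"])
assert not is_generation_model(["LlamaEmbeddingModel"])
assert is_generation_model(["LlamaForCausalLM"])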


@@ -23,6 +23,7 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model
DEFAULT_PROMPTS = [
"The capital of France is",
@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
NUM_TOP_LOGPROBS = 5
def is_embedding_model(model_path):
# FIXME incomplete list
if "e5-mistral-7b-instruct" in model_path.lower():
return True
return False
def get_dtype_str(torch_dtype):
if torch_dtype is torch.float16:
return "float16"
@@ -60,7 +54,7 @@ class HFRunner:
self,
model_path,
torch_dtype=torch.float16,
is_embedding_model=None,
is_generation_model=None,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
@@ -72,13 +66,13 @@ class HFRunner:
self.out_queue,
model_path,
torch_dtype,
is_embedding_model,
is_generation_model,
),
)
self.model_proc.start()
def start_model_process(
self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
@@ -86,12 +80,12 @@ class HFRunner:
trust_remote_code=True,
)
self.is_embedding_model = (
is_embedding_model(model_path)
if is_embedding_model is None
else is_embedding_model
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
if not self.is_embedding_model:
if self.is_generation_model:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
@@ -103,13 +97,13 @@ class HFRunner:
self.model = SentenceTransformer(
model_path,
device="cpu",
).to(dtype=torch_dtype)
model_kwargs={"torch_dtype": torch_dtype},
)
while True:
prompts, max_new_tokens = in_queue.get()
if prompts is not None:
if not self.is_embedding_model:
if self.is_generation_model:
output_strs = []
prefill_logprobs = []
for p in prompts:
@@ -144,7 +138,6 @@ class HFRunner:
)
else:
assert isinstance(prompts, List[str])
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
@@ -175,16 +168,13 @@ class SRTRunner:
model_path,
tp_size=1,
torch_dtype=torch.float16,
is_embedding_model=None,
is_generation_model=None,
):
self.is_embedding_model = (
is_embedding_model(model_path)
if is_embedding_model is None
else is_embedding_model
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
if self.is_embedding_model:
raise NotImplementedError()
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
@@ -196,38 +186,45 @@ class SRTRunner:
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=64,
):
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for prompt in prompts:
response = self.runtime.generate(
prompt,
sampling_params=sampling_params,
return_logprob=True,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
if self.is_generation_model:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for prompt in prompts:
response = self.runtime.generate(
prompt,
sampling_params=sampling_params,
return_logprob=True,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
]
)
# print(response["meta_info"]["output_top_logprobs"][0])
)
return ModelOutput(
output_strs=output_strs, top_input_logprobs=top_input_logprobs
)
return ModelOutput(
output_strs=output_strs, top_input_logprobs=top_input_logprobs
)
else:
logits = []
for prompt in prompts:
response = self.runtime.encode(prompt)
response = json.loads(response)
logits.append(response["embedding"])
return ModelOutput(embed_logits=logits)
def __enter__(self):
return self
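
A hedged usage sketch of the embedding path in the test runner, assuming the module path, the runner's forward method name, and a matching __exit__ (none of which are shown in this hunk); the model id is illustrative:

import torch
from sglang.test.runners import SRTRunner  # module path is an assumption

with SRTRunner(
    "intfloat/e5-mistral-7b-instruct",
    torch_dtype=torch.float16,
    is_generation_model=False,
) as runner:
    # forward() and its max_new_tokens argument are assumed from the hunk above.
    out = runner.forward(["The capital of France is"], max_new_tokens=0)
    print(len(out.embed_logits[0]))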


@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
import numpy as np
import requests
import torch
import torch.nn.functional as F
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
return 0 if success else -1
def get_similarities(vec1, vec2):
return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
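
get_similarities returns the cosine similarity between two embedding vectors and is what the new tests use to compare SRT output against the HF reference. A minimal usage sketch (the module path in the import is an assumption; the vectors are placeholders for real model outputs):

from sglang.test.test_utils import get_similarities  # module path is an assumption

vec_a = [0.1, 0.2, 0.3]   # placeholder embedding
vec_b = [0.1, 0.2, 0.29]  # placeholder embedding
print(float(get_similarities(vec_a, vec_b)))  # close to 1.0 for near-identical vectors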