Add e5-mistral embedding model - step 3/3 (#988)
@@ -25,7 +25,11 @@ import zmq
import zmq.asyncio

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
@@ -66,6 +70,18 @@ class DetokenizerManager:
    async def handle_loop(self):
        while True:
            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()

            if isinstance(recv_obj, BatchEmbeddingOut):
                self.send_to_tokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        rids=recv_obj.rids,
                        embeddings=recv_obj.embeddings,
                        meta_info=recv_obj.meta_info,
                        finished_reason=recv_obj.finished_reason,
                    )
                )
                continue

            assert isinstance(recv_obj, BatchTokenIDOut)
            bs = len(recv_obj.rids)
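Embedding outputs carry no token IDs, so the detokenizer simply forwards BatchEmbeddingOut to the tokenizer manager unchanged. For reference, a minimal sketch of what the new io_struct presumably looks like (field names are taken from the call above; the type annotations are assumptions):

```python
# Hedged sketch of the BatchEmbeddingOut io_struct; field names come from the
# hunk above, the type annotations are assumptions.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class BatchEmbeddingOut:
    rids: List[str]                # request ids, one per entry in the batch
    embeddings: List[List[float]]  # one embedding vector per request
    meta_info: List[Dict]          # per-request metadata, e.g. {"prompt_tokens": ...}
    finished_reason: List          # per-request finish reason (None while unfinished)
```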
@@ -143,6 +143,7 @@ class Req:

        # Logprobs
        self.return_logprob = False
        self.embedding = None
        self.logprob_start_len = 0
        self.top_logprobs_num = 0
        self.normalized_prompt_logprob = None

@@ -21,7 +21,7 @@ import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import numpy as np
import transformers

@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_multimodal_model, load_image
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -85,6 +88,7 @@ class TokenizerManager:
            trust_remote_code=server_args.trust_remote_code,
            model_overide_args=model_overide_args,
        )
        self.is_generation = is_generation_model(self.hf_config.architectures)

        if server_args.context_length is not None:
            self.context_len = server_args.context_length

@@ -133,7 +137,9 @@ class TokenizerManager:
            image_data, aspect_ratio, grid_pinpoints, self.processor
        )

    async def generate_request(self, obj: GenerateReqInput, request=None):
    async def generate_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

@@ -144,6 +150,8 @@ class TokenizerManager:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            if isinstance(obj, EmbeddingReqInput):
                raise NotImplementedError("Please send only one prompt in each request")
            if obj.stream:
                raise ValueError("Do not support stream for batch mode.")
@@ -151,39 +159,47 @@ class TokenizerManager:
                yield response

    async def _handle_single_request(
        self, obj, request, index=None, is_cache_for_prefill=False
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request,
        index=None,
        is_cache_for_prefill=False,
    ):
        if not is_cache_for_prefill:  # The normal case with a single prompt
            not_use_index = index is None

            rid = obj.rid if not_use_index else obj.rid[index]
            input_text = obj.text if not_use_index else obj.text[index]
            input_ids = (
                self.tokenizer.encode(input_text)
                if obj.input_ids is None
                else obj.input_ids
            )
            if not not_use_index and obj.input_ids:
                input_ids = obj.input_ids[index]
            if obj.input_ids is None:
                input_ids = self.tokenizer.encode(input_text)
            else:
                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]

            self._validate_input_length(input_ids)

            sampling_params = self._get_sampling_params(
                obj.sampling_params if not_use_index else obj.sampling_params[index]
            )
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data if not_use_index else obj.image_data[index]
            )
            return_logprob = (
                obj.return_logprob if not_use_index else obj.return_logprob[index]
            )
            logprob_start_len = (
                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
            )
            top_logprobs_num = (
                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
            )

            if self.is_generation:
                pixel_values, image_hash, image_size = await self._get_pixel_values(
                    obj.image_data if not_use_index else obj.image_data[index]
                )
                return_logprob = (
                    obj.return_logprob if not_use_index else obj.return_logprob[index]
                )
                logprob_start_len = (
                    obj.logprob_start_len
                    if not_use_index
                    else obj.logprob_start_len[index]
                )
                top_logprobs_num = (
                    obj.top_logprobs_num
                    if not_use_index
                    else obj.top_logprobs_num[index]
                )
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[index]
@@ -213,19 +229,28 @@ class TokenizerManager:
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        tokenized_obj = TokenizedGenerateReqInput(
            rid,
            input_text,
            input_ids,
            pixel_values,
            image_hash,
            image_size,
            sampling_params,
            return_logprob,
            logprob_start_len,
            top_logprobs_num,
            obj.stream,
        )
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                pixel_values,
                image_hash,
                image_size,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
            )
        else:  # is embedding
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )

        self.send_to_router.send_pyobj(tokenized_obj)

        event = asyncio.Event()
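For embedding models the tokenizer manager skips image and logprob handling and packs only the tokenized prompt. A minimal sketch of TokenizedEmbeddingReqInput, inferred from the positional arguments above (the dataclass itself is defined in an earlier step of this PR series; field types are assumptions):

```python
# Hedged sketch of TokenizedEmbeddingReqInput; the field order mirrors the
# positional arguments used above, the annotations are assumptions.
from dataclasses import dataclass
from typing import List


@dataclass
class TokenizedEmbeddingReqInput:
    rid: str                 # request id
    input_text: str          # original prompt text
    input_ids: List[int]     # tokenized prompt
    sampling_params: object  # carried along for scheduling; nothing is sampled
```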
@@ -368,7 +393,7 @@ class TokenizerManager:
        self,
        event: asyncio.Event,
        state: ReqState,
        obj: GenerateReqInput,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        rid: str,
        request,
    ):

@@ -381,12 +406,15 @@ class TokenizerManager:
                    raise ValueError(f"Abort request {rid}")
                continue

            out = self.convert_logprob_style(
                state.out_list[-1],
                obj.return_logprob,
                obj.top_logprobs_num,
                obj.return_text_in_logprobs,
            )
            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob,
                    obj.top_logprobs_num,
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, EmbeddingReqInput)
                out = state.out_list[-1]

            # Log requests
            if self.server_args.log_requests and state.finished:

@@ -459,8 +487,10 @@ class TokenizerManager:

    async def handle_loop(self):
        while True:
            recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj()
            assert isinstance(recv_obj, BatchStrOut)
            recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
                await self.recv_from_detokenizer.recv_pyobj()
            )
            assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
@@ -468,10 +498,17 @@ class TokenizerManager:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                out_dict = {
                    "text": recv_obj.output_strs[i],
                    "meta_info": recv_obj.meta_info[i],
                }
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()
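For a BatchEmbeddingOut the per-request payload carries the vector instead of generated text. An illustrative shape of the dict appended to state.out_list (values are made up):

```python
# Illustrative shape of the dict appended to state.out_list for an embedding
# request (values are made up; the real vector has the model's hidden size).
out_dict = {
    "embedding": [0.0123, -0.0456, 0.0789],
    "meta_info": {"prompt_tokens": 7, "id": "abcd1234"},
}
```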
@@ -20,7 +20,7 @@ import multiprocessing
import pickle
import time
import warnings
from typing import List, Optional
from typing import List, Optional, Union

import torch
import torch.distributed as dist

@@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder

@@ -205,7 +207,9 @@ class ModelTpServer:
        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                if isinstance(
                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                ):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
@@ -297,41 +301,42 @@ class ModelTpServer:

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
    ):
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            (
                req.origin_input_ids,
                req.image_offset,
            ) = self.model_runner.model.pad_input_ids(
                req.origin_input_ids_unpadded,
                req.pad_value,
                req.pixel_values.shape,
                req.image_size,
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
        req.sampling_params = recv_req.sampling_params
        if self.model_runner.is_generation:
            req.pixel_values = recv_req.pixel_values
            if req.pixel_values is not None:
                req.pad_value = [
                    (recv_req.image_hash) % self.model_config.vocab_size,
                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
                ]
                req.image_size = recv_req.image_size
                (
                    req.origin_input_ids,
                    req.image_offset,
                ) = self.model_runner.model.pad_input_ids(
                    req.origin_input_ids_unpadded,
                    req.pad_value,
                    req.pixel_values.shape,
                    req.image_size,
                )
            req.return_logprob = recv_req.return_logprob
            req.logprob_start_len = recv_req.logprob_start_len
            req.top_logprobs_num = recv_req.top_logprobs_num
            req.stream = recv_req.stream

            # Init regex fsm
            if req.sampling_params.regex is not None:
                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
                if not self.disable_regex_jump_forward:
                    req.jump_forward_map = self.jump_forward_cache.query(
                        req.sampling_params.regex
                    )

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -340,14 +345,17 @@ class ModelTpServer:
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_input_len - 1 - len(req.origin_input_ids),
        )

        if self.model_runner.is_generation:
            req.sampling_params.max_new_tokens = min(
                (
                    req.sampling_params.max_new_tokens
                    if req.sampling_params.max_new_tokens is not None
                    else 1 << 30
                ),
                self.max_req_input_len - 1 - len(req.origin_input_ids),
            )

        self.waiting_queue.append(req)

    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
@@ -439,47 +447,68 @@ class ModelTpServer:
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        # Forward and sample the next tokens
        if batch.extend_num_tokens != 0:
            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
            next_token_ids = batch.sample(output.next_token_logits)
        if self.model_runner.is_generation:
            # Forward and sample the next tokens
            if batch.extend_num_tokens != 0:
                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
                next_token_ids = batch.sample(output.next_token_logits)

            # Move logprobs to cpu
            if output.next_token_logprobs is not None:
                output.next_token_logprobs = output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()
                output.input_token_logprobs = output.input_token_logprobs.tolist()
                output.normalized_prompt_logprobs = (
                    output.normalized_prompt_logprobs.tolist()
                )
                # Move logprobs to cpu
                if output.next_token_logprobs is not None:
                    output.next_token_logprobs = output.next_token_logprobs[
                        torch.arange(len(next_token_ids), device=next_token_ids.device),
                        next_token_ids,
                    ].tolist()
                    output.input_token_logprobs = output.input_token_logprobs.tolist()
                    output.normalized_prompt_logprobs = (
                        output.normalized_prompt_logprobs.tolist()
                    )

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish conditions
        pt = 0
        for i, req in enumerate(batch.reqs):
            if req is not self.current_inflight_req:
                # Inflight reqs' prefill is not finished
                req.completion_tokens_wo_jump_forward += 1
                req.output_ids.append(next_token_ids[i])
                req.check_finished()

            if req.finished():
                self.tree_cache.cache_finished_req(req)
                next_token_ids = next_token_ids.tolist()
            else:
                self.tree_cache.cache_unfinished_req(req)
                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

            if req is self.current_inflight_req:
                # Inflight request would get a new req idx
                self.req_to_token_pool.free(req.req_pool_idx)
            # Check finish conditions
            pt = 0
            for i, req in enumerate(batch.reqs):
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

            if req.return_logprob:
                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                pt += req.extend_input_len
                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

                if req.return_logprob:
                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                    pt += req.extend_input_len
        else:
            assert batch.extend_num_tokens != 0
            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
            embeddings = output.embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                req.embedding = embeddings[i]
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

        self.handle_finished_requests(batch)
@@ -596,15 +625,19 @@ class ModelTpServer:

    def handle_finished_requests(self, batch: ScheduleBatch):
        output_rids = []
        output_vids = []
        decoded_texts = []
        output_read_ids = []
        output_read_offsets = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.model_runner.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
        else:  # for embedding model
            output_embeddings = []
        unfinished_indices = []

        for i, req in enumerate(batch.reqs):
            if not req.finished() and req is not self.current_inflight_req:
                unfinished_indices.append(i)
@@ -619,56 +652,73 @@ class ModelTpServer:
                )
            ):
                output_rids.append(req.rid)
                output_vids.append(req.vid)
                decoded_texts.append(req.decoded_text)
                read_ids, read_offset = req.init_incremental_detokenize()
                output_read_ids.append(read_ids)
                output_read_offsets.append(read_offset)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
                    "completion_tokens": len(req.output_ids),
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finished_reason),
                }
                if req.return_logprob:
                    (
                        meta_info["input_token_logprobs"],
                        meta_info["output_token_logprobs"],
                        meta_info["input_top_logprobs"],
                        meta_info["output_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.input_token_logprobs,
                        req.output_token_logprobs,
                        req.input_top_logprobs,
                        req.output_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished_reason.append(req.finished_reason)
                if self.model_runner.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "finish_reason": str(req.finished_reason),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # for embedding model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_vids,
                    decoded_texts,
                    output_read_ids,
                    output_read_offsets,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished_reason,
            if self.model_runner.is_generation:
                self.out_pyobjs.append(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                    )
                )
            else:  # for embedding model
                self.out_pyobjs.append(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )
            )

        # Remove finished reqs: update batch tensors
        batch.filter_batch(unfinished_indices)
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_llama3_405b_fp8,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,

@@ -132,8 +133,10 @@ class ModelRunner:
        self.init_cublas()
        self.init_flashinfer()

        # Capture cuda graphs
        self.init_cuda_graphs()
        if self.is_generation:
            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
            # Capture cuda graphs
            self.init_cuda_graphs()

    def load_model(self):
        logger.info(

@@ -184,6 +187,10 @@ class ModelRunner:
            scheduler_config=None,
            cache_config=None,
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures
        )

        logger.info(
            f"[gpu={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
@@ -406,8 +413,10 @@ def import_model_classes():
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert entry.__name__ not in model_arch_name_to_cls
                    model_arch_name_to_cls[entry.__name__] = entry

            # compat: some models such as chatglm has incorrect class set in config.json

@@ -417,6 +426,7 @@ def import_model_classes():
            ):
                for remap in module.EntryClassRemapping:
                    if isinstance(remap, tuple) and len(remap) == 2:
                        assert remap[0] not in model_arch_name_to_cls
                        model_arch_name_to_cls[remap[0]] = remap[1]

    return model_arch_name_to_cls
@@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module):


EntryClass = LlamaEmbeddingModel
# compat: e5-mistral model.config class == MistralModel
EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
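e5-mistral's config.json reports the architecture name MistralModel, so the loader is taught to remap that name onto the LlamaEmbeddingModel wrapper. A simplified sketch of the lookup this enables (condensed from import_model_classes above, not a verbatim copy):

```python
# Simplified sketch of how EntryClassRemapping is consumed (condensed from
# import_model_classes; not a verbatim copy of that function).
model_arch_name_to_cls = {}
for remap in EntryClassRemapping:  # [("MistralModel", LlamaEmbeddingModel)]
    if isinstance(remap, tuple) and len(remap) == 2:
        model_arch_name_to_cls[remap[0]] = remap[1]

# e5-mistral-7b-instruct's config.json lists "MistralModel", which now resolves
# to the embedding wrapper instead of a causal-LM class.
model_cls = model_arch_name_to_cls["MistralModel"]  # -> LlamaEmbeddingModel
```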
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
    start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,

@@ -97,6 +97,7 @@ async def health() -> Response:
async def get_model_info():
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result
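Exposing is_generation lets clients (and the warmup logic later in this diff) pick between /generate and /encode. An illustrative result for an embedding model (the model path is an example, and the route decorator sits outside this hunk):

```python
# Illustrative get_model_info result when serving an embedding model
# (the model path is an example; the route decorator is outside this hunk).
result = {
    "model_path": "intfloat/e5-mistral-7b-instruct",
    "is_generation": False,
}
```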
@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
app.put("/generate")(generate_request)


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)
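A minimal client-side sketch of the new endpoint (host and port are assumptions; the request and response shapes follow encode_request and the tokenizer manager code above):

```python
# Minimal client sketch for the new /encode endpoint (host/port assumed).
import requests

resp = requests.post(
    "http://localhost:30000/encode",
    json={"text": "The capital of France is"},
)
out = resp.json()
embedding = out["embedding"]   # list[float], one entry per hidden dimension
meta_info = out["meta_info"]   # e.g. {"prompt_tokens": ..., "id": ...}
```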
@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
        except (AssertionError, requests.exceptions.RequestException) as e:
            last_traceback = get_exception_traceback()
            pass
    model_info = res.json()

    if not success:
        if pipe_finish_writer is not None:

@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
        sys.exit(1)

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 0
    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + "/generate",
                url + request_name,
                json={
                    "text": "The capital city of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 8,
                        "max_new_tokens": max_new_tokens,
                    },
                },
                headers=headers,
@@ -529,5 +548,18 @@ class Runtime:
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: str,
    ):
        json_data = {
            "text": prompt,
        }
        response = requests.post(
            self.url + "/encode",
            json=json_data,
        )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()
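A hedged usage sketch of the new Runtime.encode helper (the model path is an example; Runtime arguments other than model_path are omitted):

```python
# Hedged usage sketch of Runtime.encode (model path is an example).
import json

from sglang.srt.server import Runtime

runtime = Runtime(model_path="intfloat/e5-mistral-7b-instruct")
response = json.loads(runtime.encode("The capital of France is"))
print(len(response["embedding"]))  # hidden size of the embedding model
runtime.shutdown()
```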
@@ -223,6 +223,15 @@ def is_multimodal_model(model):
    raise ValueError("unrecognized type")


def is_generation_model(model_architectures):
    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
    ):
        return False
    return True


def decode_video_base64(video_base64):
    from PIL import Image
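A quick illustration of the helper: only the two embedding architecture names return False, any other architecture string returns True ("LlamaForCausalLM" is just an illustrative example):

```python
# Behavior of the helper on a few architecture lists; "LlamaForCausalLM" is an
# illustrative name -- anything outside the two embedding entries returns True.
assert is_generation_model(["MistralModel"]) is False
assert is_generation_model(["LlamaEmbeddingModel"]) is False
assert is_generation_model(["LlamaForCausalLM"]) is True
```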
@@ -23,6 +23,7 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model

DEFAULT_PROMPTS = [
    "The capital of France is",

@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
NUM_TOP_LOGPROBS = 5


def is_embedding_model(model_path):
    # FIXME incomplete list
    if "e5-mistral-7b-instruct" in model_path.lower():
        return True
    return False


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"

@@ -60,7 +54,7 @@ class HFRunner:
        self,
        model_path,
        torch_dtype=torch.float16,
        is_embedding_model=None,
        is_generation_model=None,
    ):
        self.in_queue = multiprocessing.Queue()
        self.out_queue = multiprocessing.Queue()

@@ -72,13 +66,13 @@ class HFRunner:
                self.out_queue,
                model_path,
                torch_dtype,
                is_embedding_model,
                is_generation_model,
            ),
        )
        self.model_proc.start()

    def start_model_process(
        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,

@@ -86,12 +80,12 @@ class HFRunner:
            trust_remote_code=True,
        )

        self.is_embedding_model = (
            is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        self.is_generation_model = (
            is_generation_model(model_path)
            if is_generation_model is None
            else is_generation_model
        )
        if not self.is_embedding_model:
        if self.is_generation_model:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,

@@ -103,13 +97,13 @@ class HFRunner:

            self.model = SentenceTransformer(
                model_path,
                device="cpu",
            ).to(dtype=torch_dtype)
                model_kwargs={"torch_dtype": torch_dtype},
            )

        while True:
            prompts, max_new_tokens = in_queue.get()
            if prompts is not None:
                if not self.is_embedding_model:
                if self.is_generation_model:
                    output_strs = []
                    prefill_logprobs = []
                    for p in prompts:

@@ -144,7 +138,6 @@ class HFRunner:
                    )

                else:
                    assert isinstance(prompts, List[str])
                    logits = self.model.encode(prompts).tolist()

                    out_queue.put(ModelOutput(embed_logits=logits))

@@ -175,16 +168,13 @@ class SRTRunner:
        model_path,
        tp_size=1,
        torch_dtype=torch.float16,
        is_embedding_model=None,
        is_generation_model=None,
    ):
        self.is_embedding_model = (
            is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        self.is_generation_model = (
            is_generation_model(model_path)
            if is_generation_model is None
            else is_generation_model
        )
        if self.is_embedding_model:
            raise NotImplementedError()

        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
@@ -196,38 +186,45 @@ class SRTRunner:
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=64,
    ):
        # the return value contains logprobs from prefill
        output_strs = []
        top_input_logprobs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        for prompt in prompts:
            response = self.runtime.generate(
                prompt,
                sampling_params=sampling_params,
                return_logprob=True,
                top_logprobs_num=NUM_TOP_LOGPROBS,
            )
            response = json.loads(response)
            output_strs.append(response["text"])
            top_input_logprobs.append(
                [
                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["input_top_logprobs"][1:]
                ]
                + [
        if self.is_generation_model:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            for prompt in prompts:
                response = self.runtime.generate(
                    prompt,
                    sampling_params=sampling_params,
                    return_logprob=True,
                    top_logprobs_num=NUM_TOP_LOGPROBS,
                )
                response = json.loads(response)
                output_strs.append(response["text"])
                top_input_logprobs.append(
                    [
                        tup[0]
                        for tup in response["meta_info"]["output_top_logprobs"][0][
                            :NUM_TOP_LOGPROBS
                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                        for x in response["meta_info"]["input_top_logprobs"][1:]
                    ]
                    + [
                        [
                            tup[0]
                            for tup in response["meta_info"]["output_top_logprobs"][0][
                                :NUM_TOP_LOGPROBS
                            ]
                        ]
                    ]
                ]
                )
                # print(response["meta_info"]["output_top_logprobs"][0])
            )

        return ModelOutput(
            output_strs=output_strs, top_input_logprobs=top_input_logprobs
        )
            return ModelOutput(
                output_strs=output_strs, top_input_logprobs=top_input_logprobs
            )
        else:
            logits = []
            for prompt in prompts:
                response = self.runtime.encode(prompt)
                response = json.loads(response)
                logits.append(response["embedding"])
            return ModelOutput(embed_logits=logits)

    def __enter__(self):
        return self
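A hedged sketch of driving the test runner against an embedding model (the model path is an example; the context-manager form assumes the usual __enter__/__exit__ pair shown above):

```python
# Hedged sketch of using the test runner for an embedding model (model path is
# an example; the context-manager form assumes the usual __exit__ counterpart).
import torch

with SRTRunner(
    "intfloat/e5-mistral-7b-instruct",
    tp_size=1,
    torch_dtype=torch.float16,
    is_generation_model=False,
) as srt_runner:
    srt_outputs = srt_runner.forward(["The capital of France is"])
    print(len(srt_outputs.embed_logits[0]))  # one vector per prompt
```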
@@ -12,6 +12,8 @@ from typing import Callable, List, Optional

import numpy as np
import requests
import torch
import torch.nn.functional as F

from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI

@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")

    return 0 if success else -1


def get_similarities(vec1, vec2):
    return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
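A sketch of how a test might use get_similarities to compare HuggingFace and SRT embeddings (hf_outputs/srt_outputs come from the runners above; the tolerance is an assumption, not taken from this diff):

```python
# Sketch of comparing HF and SRT embeddings in a test (tolerance is assumed).
hf_vec = hf_outputs.embed_logits[0]
srt_vec = srt_outputs.embed_logits[0]
similarity = get_similarities(hf_vec, srt_vec)
assert similarity > 0.99, f"embeddings diverge, cosine similarity = {similarity}"
```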