Add e5-mistral embedding model - step 3/3 (#988)
This commit is contained in:
1
.github/workflows/unit-test.yml
vendored
1
.github/workflows/unit-test.yml
vendored
@@ -35,6 +35,7 @@ jobs:
|
|||||||
pip install -e "python[all]"
|
pip install -e "python[all]"
|
||||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
|
||||||
pip install accelerate
|
pip install accelerate
|
||||||
|
pip install sentence_transformers
|
||||||
|
|
||||||
- name: Test Frontend Language
|
- name: Test Frontend Language
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -25,7 +25,11 @@ import zmq
|
|||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
from sglang.srt.managers.io_struct import (
|
||||||
|
BatchEmbeddingOut,
|
||||||
|
BatchStrOut,
|
||||||
|
BatchTokenIDOut,
|
||||||
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
|
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
|
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
|
||||||
@@ -66,6 +70,18 @@ class DetokenizerManager:
|
|||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
while True:
|
while True:
|
||||||
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
|
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
|
||||||
|
|
||||||
|
if isinstance(recv_obj, BatchEmbeddingOut):
|
||||||
|
self.send_to_tokenizer.send_pyobj(
|
||||||
|
BatchEmbeddingOut(
|
||||||
|
rids=recv_obj.rids,
|
||||||
|
embeddings=recv_obj.embeddings,
|
||||||
|
meta_info=recv_obj.meta_info,
|
||||||
|
finished_reason=recv_obj.finished_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
assert isinstance(recv_obj, BatchTokenIDOut)
|
assert isinstance(recv_obj, BatchTokenIDOut)
|
||||||
bs = len(recv_obj.rids)
|
bs = len(recv_obj.rids)
|
||||||
|
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ class Req:
|
|||||||
|
|
||||||
# Logprobs
|
# Logprobs
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
|
self.embedding = None
|
||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
self.normalized_prompt_logprob = None
|
self.normalized_prompt_logprob = None
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import dataclasses
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import transformers
|
import transformers
|
||||||
@@ -38,16 +38,19 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
|
BatchEmbeddingOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
EmbeddingReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import is_multimodal_model, load_image
|
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
@@ -85,6 +88,7 @@ class TokenizerManager:
|
|||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
model_overide_args=model_overide_args,
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
|
self.is_generation = is_generation_model(self.hf_config.architectures)
|
||||||
|
|
||||||
if server_args.context_length is not None:
|
if server_args.context_length is not None:
|
||||||
self.context_len = server_args.context_length
|
self.context_len = server_args.context_length
|
||||||
@@ -133,7 +137,9 @@ class TokenizerManager:
|
|||||||
image_data, aspect_ratio, grid_pinpoints, self.processor
|
image_data, aspect_ratio, grid_pinpoints, self.processor
|
||||||
)
|
)
|
||||||
|
|
||||||
async def generate_request(self, obj: GenerateReqInput, request=None):
|
async def generate_request(
|
||||||
|
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
|
||||||
|
):
|
||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
|
|
||||||
@@ -144,6 +150,8 @@ class TokenizerManager:
|
|||||||
async for response in self._handle_single_request(obj, request):
|
async for response in self._handle_single_request(obj, request):
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
|
if isinstance(obj, EmbeddingReqInput):
|
||||||
|
raise NotImplementedError("Please send only one prompt in each request")
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
raise ValueError("Do not support stream for batch mode.")
|
raise ValueError("Do not support stream for batch mode.")
|
||||||
|
|
||||||
@@ -151,39 +159,47 @@ class TokenizerManager:
|
|||||||
yield response
|
yield response
|
||||||
|
|
||||||
async def _handle_single_request(
|
async def _handle_single_request(
|
||||||
self, obj, request, index=None, is_cache_for_prefill=False
|
self,
|
||||||
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
|
request,
|
||||||
|
index=None,
|
||||||
|
is_cache_for_prefill=False,
|
||||||
):
|
):
|
||||||
if not is_cache_for_prefill: # The normal case with a single prompt
|
if not is_cache_for_prefill: # The normal case with a single prompt
|
||||||
not_use_index = index is None
|
not_use_index = index is None
|
||||||
|
|
||||||
rid = obj.rid if not_use_index else obj.rid[index]
|
rid = obj.rid if not_use_index else obj.rid[index]
|
||||||
input_text = obj.text if not_use_index else obj.text[index]
|
input_text = obj.text if not_use_index else obj.text[index]
|
||||||
input_ids = (
|
if obj.input_ids is None:
|
||||||
self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
if obj.input_ids is None
|
else:
|
||||||
else obj.input_ids
|
input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
|
||||||
)
|
|
||||||
if not not_use_index and obj.input_ids:
|
|
||||||
input_ids = obj.input_ids[index]
|
|
||||||
|
|
||||||
self._validate_input_length(input_ids)
|
self._validate_input_length(input_ids)
|
||||||
|
|
||||||
sampling_params = self._get_sampling_params(
|
sampling_params = self._get_sampling_params(
|
||||||
obj.sampling_params if not_use_index else obj.sampling_params[index]
|
obj.sampling_params if not_use_index else obj.sampling_params[index]
|
||||||
)
|
)
|
||||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
|
||||||
obj.image_data if not_use_index else obj.image_data[index]
|
if self.is_generation:
|
||||||
)
|
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||||
return_logprob = (
|
obj.image_data if not_use_index else obj.image_data[index]
|
||||||
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
)
|
||||||
)
|
return_logprob = (
|
||||||
logprob_start_len = (
|
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
||||||
obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
|
)
|
||||||
)
|
logprob_start_len = (
|
||||||
top_logprobs_num = (
|
obj.logprob_start_len
|
||||||
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
|
if not_use_index
|
||||||
)
|
else obj.logprob_start_len[index]
|
||||||
|
)
|
||||||
|
top_logprobs_num = (
|
||||||
|
obj.top_logprobs_num
|
||||||
|
if not_use_index
|
||||||
|
else obj.top_logprobs_num[index]
|
||||||
|
)
|
||||||
else: # A prefill request to cache the common prompt for parallel sampling
|
else: # A prefill request to cache the common prompt for parallel sampling
|
||||||
|
assert self.is_generation
|
||||||
if obj.text is not None:
|
if obj.text is not None:
|
||||||
if isinstance(obj.text, list):
|
if isinstance(obj.text, list):
|
||||||
input_text = obj.text[index]
|
input_text = obj.text[index]
|
||||||
@@ -213,19 +229,28 @@ class TokenizerManager:
|
|||||||
logprob_start_len = obj.logprob_start_len[0]
|
logprob_start_len = obj.logprob_start_len[0]
|
||||||
top_logprobs_num = obj.top_logprobs_num[0]
|
top_logprobs_num = obj.top_logprobs_num[0]
|
||||||
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
if self.is_generation:
|
||||||
rid,
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
input_text,
|
rid,
|
||||||
input_ids,
|
input_text,
|
||||||
pixel_values,
|
input_ids,
|
||||||
image_hash,
|
pixel_values,
|
||||||
image_size,
|
image_hash,
|
||||||
sampling_params,
|
image_size,
|
||||||
return_logprob,
|
sampling_params,
|
||||||
logprob_start_len,
|
return_logprob,
|
||||||
top_logprobs_num,
|
logprob_start_len,
|
||||||
obj.stream,
|
top_logprobs_num,
|
||||||
)
|
obj.stream,
|
||||||
|
)
|
||||||
|
else: # is embedding
|
||||||
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
|
rid,
|
||||||
|
input_text,
|
||||||
|
input_ids,
|
||||||
|
sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
|
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
@@ -368,7 +393,7 @@ class TokenizerManager:
|
|||||||
self,
|
self,
|
||||||
event: asyncio.Event,
|
event: asyncio.Event,
|
||||||
state: ReqState,
|
state: ReqState,
|
||||||
obj: GenerateReqInput,
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
rid: str,
|
rid: str,
|
||||||
request,
|
request,
|
||||||
):
|
):
|
||||||
@@ -381,12 +406,15 @@ class TokenizerManager:
|
|||||||
raise ValueError(f"Abort request {rid}")
|
raise ValueError(f"Abort request {rid}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
out = self.convert_logprob_style(
|
if self.is_generation:
|
||||||
state.out_list[-1],
|
out = self.convert_logprob_style(
|
||||||
obj.return_logprob,
|
state.out_list[-1],
|
||||||
obj.top_logprobs_num,
|
obj.return_logprob,
|
||||||
obj.return_text_in_logprobs,
|
obj.top_logprobs_num,
|
||||||
)
|
obj.return_text_in_logprobs,
|
||||||
|
)
|
||||||
|
else: # isinstance(obj, EmbeddingReqInput)
|
||||||
|
out = state.out_list[-1]
|
||||||
|
|
||||||
# Log requests
|
# Log requests
|
||||||
if self.server_args.log_requests and state.finished:
|
if self.server_args.log_requests and state.finished:
|
||||||
@@ -459,8 +487,10 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
while True:
|
while True:
|
||||||
recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj()
|
recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
|
||||||
assert isinstance(recv_obj, BatchStrOut)
|
await self.recv_from_detokenizer.recv_pyobj()
|
||||||
|
)
|
||||||
|
assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
|
||||||
|
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
state = self.rid_to_state.get(rid, None)
|
state = self.rid_to_state.get(rid, None)
|
||||||
@@ -468,10 +498,17 @@ class TokenizerManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
recv_obj.meta_info[i]["id"] = rid
|
recv_obj.meta_info[i]["id"] = rid
|
||||||
out_dict = {
|
if isinstance(recv_obj, BatchStrOut):
|
||||||
"text": recv_obj.output_strs[i],
|
out_dict = {
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"text": recv_obj.output_strs[i],
|
||||||
}
|
"meta_info": recv_obj.meta_info[i],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
|
out_dict = {
|
||||||
|
"embedding": recv_obj.embeddings[i],
|
||||||
|
"meta_info": recv_obj.meta_info[i],
|
||||||
|
}
|
||||||
state.out_list.append(out_dict)
|
state.out_list.append(out_dict)
|
||||||
state.finished = recv_obj.finished_reason[i] is not None
|
state.finished = recv_obj.finished_reason[i] is not None
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import multiprocessing
|
|||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -31,8 +31,10 @@ from sglang.srt.constrained.jump_forward import JumpForwardCache
|
|||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
|
BatchEmbeddingOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
|
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
|
||||||
@@ -205,7 +207,9 @@ class ModelTpServer:
|
|||||||
try:
|
try:
|
||||||
# Recv requests
|
# Recv requests
|
||||||
for recv_req in recv_reqs:
|
for recv_req in recv_reqs:
|
||||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
if isinstance(
|
||||||
|
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
||||||
|
):
|
||||||
self.handle_generate_request(recv_req)
|
self.handle_generate_request(recv_req)
|
||||||
elif isinstance(recv_req, FlushCacheReq):
|
elif isinstance(recv_req, FlushCacheReq):
|
||||||
self.flush_cache()
|
self.flush_cache()
|
||||||
@@ -297,41 +301,42 @@ class ModelTpServer:
|
|||||||
|
|
||||||
def handle_generate_request(
|
def handle_generate_request(
|
||||||
self,
|
self,
|
||||||
recv_req: TokenizedGenerateReqInput,
|
recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
||||||
):
|
):
|
||||||
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
||||||
req.pixel_values = recv_req.pixel_values
|
|
||||||
if req.pixel_values is not None:
|
|
||||||
req.pad_value = [
|
|
||||||
(recv_req.image_hash) % self.model_config.vocab_size,
|
|
||||||
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
|
|
||||||
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
|
|
||||||
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
|
|
||||||
]
|
|
||||||
req.image_size = recv_req.image_size
|
|
||||||
(
|
|
||||||
req.origin_input_ids,
|
|
||||||
req.image_offset,
|
|
||||||
) = self.model_runner.model.pad_input_ids(
|
|
||||||
req.origin_input_ids_unpadded,
|
|
||||||
req.pad_value,
|
|
||||||
req.pixel_values.shape,
|
|
||||||
req.image_size,
|
|
||||||
)
|
|
||||||
req.sampling_params = recv_req.sampling_params
|
|
||||||
req.return_logprob = recv_req.return_logprob
|
|
||||||
req.logprob_start_len = recv_req.logprob_start_len
|
|
||||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
|
||||||
req.stream = recv_req.stream
|
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
req.sampling_params = recv_req.sampling_params
|
||||||
# Init regex fsm
|
if self.model_runner.is_generation:
|
||||||
if req.sampling_params.regex is not None:
|
req.pixel_values = recv_req.pixel_values
|
||||||
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
if req.pixel_values is not None:
|
||||||
if not self.disable_regex_jump_forward:
|
req.pad_value = [
|
||||||
req.jump_forward_map = self.jump_forward_cache.query(
|
(recv_req.image_hash) % self.model_config.vocab_size,
|
||||||
req.sampling_params.regex
|
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
|
||||||
|
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
|
||||||
|
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
|
||||||
|
]
|
||||||
|
req.image_size = recv_req.image_size
|
||||||
|
(
|
||||||
|
req.origin_input_ids,
|
||||||
|
req.image_offset,
|
||||||
|
) = self.model_runner.model.pad_input_ids(
|
||||||
|
req.origin_input_ids_unpadded,
|
||||||
|
req.pad_value,
|
||||||
|
req.pixel_values.shape,
|
||||||
|
req.image_size,
|
||||||
)
|
)
|
||||||
|
req.return_logprob = recv_req.return_logprob
|
||||||
|
req.logprob_start_len = recv_req.logprob_start_len
|
||||||
|
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||||
|
req.stream = recv_req.stream
|
||||||
|
|
||||||
|
# Init regex fsm
|
||||||
|
if req.sampling_params.regex is not None:
|
||||||
|
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||||
|
if not self.disable_regex_jump_forward:
|
||||||
|
req.jump_forward_map = self.jump_forward_cache.query(
|
||||||
|
req.sampling_params.regex
|
||||||
|
)
|
||||||
|
|
||||||
# Truncate prompts that are too long
|
# Truncate prompts that are too long
|
||||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||||
@@ -340,14 +345,17 @@ class ModelTpServer:
|
|||||||
"the max context length. Truncated!!!"
|
"the max context length. Truncated!!!"
|
||||||
)
|
)
|
||||||
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
|
||||||
req.sampling_params.max_new_tokens = min(
|
|
||||||
(
|
if self.model_runner.is_generation:
|
||||||
req.sampling_params.max_new_tokens
|
req.sampling_params.max_new_tokens = min(
|
||||||
if req.sampling_params.max_new_tokens is not None
|
(
|
||||||
else 1 << 30
|
req.sampling_params.max_new_tokens
|
||||||
),
|
if req.sampling_params.max_new_tokens is not None
|
||||||
self.max_req_input_len - 1 - len(req.origin_input_ids),
|
else 1 << 30
|
||||||
)
|
),
|
||||||
|
self.max_req_input_len - 1 - len(req.origin_input_ids),
|
||||||
|
)
|
||||||
|
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
|
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
|
||||||
@@ -439,47 +447,68 @@ class ModelTpServer:
|
|||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
# Forward and sample the next tokens
|
if self.model_runner.is_generation:
|
||||||
if batch.extend_num_tokens != 0:
|
# Forward and sample the next tokens
|
||||||
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
if batch.extend_num_tokens != 0:
|
||||||
next_token_ids = batch.sample(output.next_token_logits)
|
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
|
next_token_ids = batch.sample(output.next_token_logits)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if output.next_token_logprobs is not None:
|
if output.next_token_logprobs is not None:
|
||||||
output.next_token_logprobs = output.next_token_logprobs[
|
output.next_token_logprobs = output.next_token_logprobs[
|
||||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
output.input_token_logprobs = output.input_token_logprobs.tolist()
|
output.input_token_logprobs = output.input_token_logprobs.tolist()
|
||||||
output.normalized_prompt_logprobs = (
|
output.normalized_prompt_logprobs = (
|
||||||
output.normalized_prompt_logprobs.tolist()
|
output.normalized_prompt_logprobs.tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
else:
|
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
|
||||||
|
|
||||||
# Check finish conditions
|
|
||||||
pt = 0
|
|
||||||
for i, req in enumerate(batch.reqs):
|
|
||||||
if req is not self.current_inflight_req:
|
|
||||||
# Inflight reqs' prefill is not finished
|
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
|
||||||
req.output_ids.append(next_token_ids[i])
|
|
||||||
req.check_finished()
|
|
||||||
|
|
||||||
if req.finished():
|
|
||||||
self.tree_cache.cache_finished_req(req)
|
|
||||||
else:
|
else:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
|
|
||||||
if req is self.current_inflight_req:
|
# Check finish conditions
|
||||||
# Inflight request would get a new req idx
|
pt = 0
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
for i, req in enumerate(batch.reqs):
|
||||||
|
if req is not self.current_inflight_req:
|
||||||
|
# Inflight reqs' prefill is not finished
|
||||||
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
|
req.output_ids.append(next_token_ids[i])
|
||||||
|
req.check_finished()
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.finished():
|
||||||
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
|
self.tree_cache.cache_finished_req(req)
|
||||||
pt += req.extend_input_len
|
else:
|
||||||
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
|
if req is self.current_inflight_req:
|
||||||
|
# Inflight request would get a new req idx
|
||||||
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
|
if req.return_logprob:
|
||||||
|
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
|
||||||
|
pt += req.extend_input_len
|
||||||
|
else:
|
||||||
|
assert batch.extend_num_tokens != 0
|
||||||
|
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
|
embeddings = output.embeddings.tolist()
|
||||||
|
|
||||||
|
# Check finish conditions
|
||||||
|
for i, req in enumerate(batch.reqs):
|
||||||
|
req.embedding = embeddings[i]
|
||||||
|
if req is not self.current_inflight_req:
|
||||||
|
# Inflight reqs' prefill is not finished
|
||||||
|
req.check_finished()
|
||||||
|
|
||||||
|
if req.finished():
|
||||||
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
else:
|
||||||
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
|
if req is self.current_inflight_req:
|
||||||
|
# Inflight request would get a new req idx
|
||||||
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
@@ -596,15 +625,19 @@ class ModelTpServer:
|
|||||||
|
|
||||||
def handle_finished_requests(self, batch: ScheduleBatch):
|
def handle_finished_requests(self, batch: ScheduleBatch):
|
||||||
output_rids = []
|
output_rids = []
|
||||||
output_vids = []
|
|
||||||
decoded_texts = []
|
|
||||||
output_read_ids = []
|
|
||||||
output_read_offsets = []
|
|
||||||
output_skip_special_tokens = []
|
|
||||||
output_spaces_between_special_tokens = []
|
|
||||||
output_meta_info = []
|
output_meta_info = []
|
||||||
output_finished_reason: List[BaseFinishReason] = []
|
output_finished_reason: List[BaseFinishReason] = []
|
||||||
|
if self.model_runner.is_generation:
|
||||||
|
output_vids = []
|
||||||
|
decoded_texts = []
|
||||||
|
output_read_ids = []
|
||||||
|
output_read_offsets = []
|
||||||
|
output_skip_special_tokens = []
|
||||||
|
output_spaces_between_special_tokens = []
|
||||||
|
else: # for embedding model
|
||||||
|
output_embeddings = []
|
||||||
unfinished_indices = []
|
unfinished_indices = []
|
||||||
|
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if not req.finished() and req is not self.current_inflight_req:
|
if not req.finished() and req is not self.current_inflight_req:
|
||||||
unfinished_indices.append(i)
|
unfinished_indices.append(i)
|
||||||
@@ -619,56 +652,73 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
output_rids.append(req.rid)
|
output_rids.append(req.rid)
|
||||||
output_vids.append(req.vid)
|
|
||||||
decoded_texts.append(req.decoded_text)
|
|
||||||
read_ids, read_offset = req.init_incremental_detokenize()
|
|
||||||
output_read_ids.append(read_ids)
|
|
||||||
output_read_offsets.append(read_offset)
|
|
||||||
output_skip_special_tokens.append(
|
|
||||||
req.sampling_params.skip_special_tokens
|
|
||||||
)
|
|
||||||
output_spaces_between_special_tokens.append(
|
|
||||||
req.sampling_params.spaces_between_special_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
meta_info = {
|
|
||||||
"prompt_tokens": len(req.origin_input_ids),
|
|
||||||
"completion_tokens": len(req.output_ids),
|
|
||||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
|
||||||
"finish_reason": str(req.finished_reason),
|
|
||||||
}
|
|
||||||
if req.return_logprob:
|
|
||||||
(
|
|
||||||
meta_info["input_token_logprobs"],
|
|
||||||
meta_info["output_token_logprobs"],
|
|
||||||
meta_info["input_top_logprobs"],
|
|
||||||
meta_info["output_top_logprobs"],
|
|
||||||
meta_info["normalized_prompt_logprob"],
|
|
||||||
) = (
|
|
||||||
req.input_token_logprobs,
|
|
||||||
req.output_token_logprobs,
|
|
||||||
req.input_top_logprobs,
|
|
||||||
req.output_top_logprobs,
|
|
||||||
req.normalized_prompt_logprob,
|
|
||||||
)
|
|
||||||
output_meta_info.append(meta_info)
|
|
||||||
output_finished_reason.append(req.finished_reason)
|
output_finished_reason.append(req.finished_reason)
|
||||||
|
if self.model_runner.is_generation:
|
||||||
|
output_vids.append(req.vid)
|
||||||
|
decoded_texts.append(req.decoded_text)
|
||||||
|
read_ids, read_offset = req.init_incremental_detokenize()
|
||||||
|
output_read_ids.append(read_ids)
|
||||||
|
output_read_offsets.append(read_offset)
|
||||||
|
output_skip_special_tokens.append(
|
||||||
|
req.sampling_params.skip_special_tokens
|
||||||
|
)
|
||||||
|
output_spaces_between_special_tokens.append(
|
||||||
|
req.sampling_params.spaces_between_special_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_info = {
|
||||||
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
|
"completion_tokens": len(req.output_ids),
|
||||||
|
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||||
|
"finish_reason": str(req.finished_reason),
|
||||||
|
}
|
||||||
|
if req.return_logprob:
|
||||||
|
(
|
||||||
|
meta_info["input_token_logprobs"],
|
||||||
|
meta_info["output_token_logprobs"],
|
||||||
|
meta_info["input_top_logprobs"],
|
||||||
|
meta_info["output_top_logprobs"],
|
||||||
|
meta_info["normalized_prompt_logprob"],
|
||||||
|
) = (
|
||||||
|
req.input_token_logprobs,
|
||||||
|
req.output_token_logprobs,
|
||||||
|
req.input_top_logprobs,
|
||||||
|
req.output_top_logprobs,
|
||||||
|
req.normalized_prompt_logprob,
|
||||||
|
)
|
||||||
|
output_meta_info.append(meta_info)
|
||||||
|
else: # for embedding model
|
||||||
|
output_embeddings.append(req.embedding)
|
||||||
|
meta_info = {
|
||||||
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
|
}
|
||||||
|
output_meta_info.append(meta_info)
|
||||||
|
|
||||||
# Send to detokenizer
|
# Send to detokenizer
|
||||||
if output_rids:
|
if output_rids:
|
||||||
self.out_pyobjs.append(
|
if self.model_runner.is_generation:
|
||||||
BatchTokenIDOut(
|
self.out_pyobjs.append(
|
||||||
output_rids,
|
BatchTokenIDOut(
|
||||||
output_vids,
|
output_rids,
|
||||||
decoded_texts,
|
output_vids,
|
||||||
output_read_ids,
|
decoded_texts,
|
||||||
output_read_offsets,
|
output_read_ids,
|
||||||
output_skip_special_tokens,
|
output_read_offsets,
|
||||||
output_spaces_between_special_tokens,
|
output_skip_special_tokens,
|
||||||
output_meta_info,
|
output_spaces_between_special_tokens,
|
||||||
output_finished_reason,
|
output_meta_info,
|
||||||
|
output_finished_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else: # for embedding model
|
||||||
|
self.out_pyobjs.append(
|
||||||
|
BatchEmbeddingOut(
|
||||||
|
output_rids,
|
||||||
|
output_embeddings,
|
||||||
|
output_meta_info,
|
||||||
|
output_finished_reason,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Remove finished reqs: update batch tensors
|
# Remove finished reqs: update batch tensors
|
||||||
batch.filter_batch(unfinished_indices)
|
batch.filter_batch(unfinished_indices)
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
|
|||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
|
is_generation_model,
|
||||||
is_llama3_405b_fp8,
|
is_llama3_405b_fp8,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
monkey_patch_vllm_dummy_weight_loader,
|
monkey_patch_vllm_dummy_weight_loader,
|
||||||
@@ -132,8 +133,10 @@ class ModelRunner:
|
|||||||
self.init_cublas()
|
self.init_cublas()
|
||||||
self.init_flashinfer()
|
self.init_flashinfer()
|
||||||
|
|
||||||
# Capture cuda graphs
|
if self.is_generation:
|
||||||
self.init_cuda_graphs()
|
# FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
|
||||||
|
# Capture cuda graphs
|
||||||
|
self.init_cuda_graphs()
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -184,6 +187,10 @@ class ModelRunner:
|
|||||||
scheduler_config=None,
|
scheduler_config=None,
|
||||||
cache_config=None,
|
cache_config=None,
|
||||||
)
|
)
|
||||||
|
self.is_generation = is_generation_model(
|
||||||
|
self.model_config.hf_config.architectures
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[gpu={self.gpu_id}] Load weight end. "
|
f"[gpu={self.gpu_id}] Load weight end. "
|
||||||
f"type={type(self.model).__name__}, "
|
f"type={type(self.model).__name__}, "
|
||||||
@@ -406,8 +413,10 @@ def import_model_classes():
|
|||||||
entry, list
|
entry, list
|
||||||
): # To support multiple model classes in one module
|
): # To support multiple model classes in one module
|
||||||
for tmp in entry:
|
for tmp in entry:
|
||||||
|
assert tmp.__name__ not in model_arch_name_to_cls
|
||||||
model_arch_name_to_cls[tmp.__name__] = tmp
|
model_arch_name_to_cls[tmp.__name__] = tmp
|
||||||
else:
|
else:
|
||||||
|
assert entry.__name__ not in model_arch_name_to_cls
|
||||||
model_arch_name_to_cls[entry.__name__] = entry
|
model_arch_name_to_cls[entry.__name__] = entry
|
||||||
|
|
||||||
# compat: some models such as chatglm has incorrect class set in config.json
|
# compat: some models such as chatglm has incorrect class set in config.json
|
||||||
@@ -417,6 +426,7 @@ def import_model_classes():
|
|||||||
):
|
):
|
||||||
for remap in module.EntryClassRemapping:
|
for remap in module.EntryClassRemapping:
|
||||||
if isinstance(remap, tuple) and len(remap) == 2:
|
if isinstance(remap, tuple) and len(remap) == 2:
|
||||||
|
assert remap[0] not in model_arch_name_to_cls
|
||||||
model_arch_name_to_cls[remap[0]] = remap[1]
|
model_arch_name_to_cls[remap[0]] = remap[1]
|
||||||
|
|
||||||
return model_arch_name_to_cls
|
return model_arch_name_to_cls
|
||||||
|
|||||||
@@ -84,3 +84,5 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
EntryClass = LlamaEmbeddingModel
|
EntryClass = LlamaEmbeddingModel
|
||||||
|
# compat: e5-mistral model.config class == MistralModel
|
||||||
|
EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
|
|||||||
start_controller_process as start_controller_process_single,
|
start_controller_process as start_controller_process_single,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.openai_api.adapter import (
|
from sglang.srt.openai_api.adapter import (
|
||||||
load_chat_template_for_openai_api,
|
load_chat_template_for_openai_api,
|
||||||
@@ -97,6 +97,7 @@ async def health() -> Response:
|
|||||||
async def get_model_info():
|
async def get_model_info():
|
||||||
result = {
|
result = {
|
||||||
"model_path": tokenizer_manager.model_path,
|
"model_path": tokenizer_manager.model_path,
|
||||||
|
"is_generation": tokenizer_manager.is_generation,
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
|
|||||||
app.put("/generate")(generate_request)
|
app.put("/generate")(generate_request)
|
||||||
|
|
||||||
|
|
||||||
|
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||||
|
"""Handle an embedding request."""
|
||||||
|
try:
|
||||||
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app.post("/encode")(encode_request)
|
||||||
|
app.put("/encode")(encode_request)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def openai_v1_completions(raw_request: Request):
|
async def openai_v1_completions(raw_request: Request):
|
||||||
return await v1_completions(tokenizer_manager, raw_request)
|
return await v1_completions(tokenizer_manager, raw_request)
|
||||||
@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
except (AssertionError, requests.exceptions.RequestException) as e:
|
except (AssertionError, requests.exceptions.RequestException) as e:
|
||||||
last_traceback = get_exception_traceback()
|
last_traceback = get_exception_traceback()
|
||||||
pass
|
pass
|
||||||
|
model_info = res.json()
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
|
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
||||||
|
max_new_tokens = 8 if model_info["is_generation"] else 0
|
||||||
try:
|
try:
|
||||||
for _ in range(server_args.dp_size):
|
for _ in range(server_args.dp_size):
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
url + "/generate",
|
url + request_name,
|
||||||
json={
|
json={
|
||||||
"text": "The capital city of France is",
|
"text": "The capital city of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 8,
|
"max_new_tokens": max_new_tokens,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -529,5 +548,18 @@ class Runtime:
|
|||||||
)
|
)
|
||||||
return json.dumps(response.json())
|
return json.dumps(response.json())
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
):
|
||||||
|
json_data = {
|
||||||
|
"text": prompt,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
self.url + "/encode",
|
||||||
|
json=json_data,
|
||||||
|
)
|
||||||
|
return json.dumps(response.json())
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
|
|||||||
@@ -223,6 +223,15 @@ def is_multimodal_model(model):
|
|||||||
raise ValueError("unrecognized type")
|
raise ValueError("unrecognized type")
|
||||||
|
|
||||||
|
|
||||||
|
def is_generation_model(model_architectures):
|
||||||
|
if (
|
||||||
|
"LlamaEmbeddingModel" in model_architectures
|
||||||
|
or "MistralModel" in model_architectures
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def decode_video_base64(video_base64):
|
def decode_video_base64(video_base64):
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import torch.nn.functional as F
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
from sglang.srt.server import Runtime
|
from sglang.srt.server import Runtime
|
||||||
|
from sglang.srt.utils import is_generation_model
|
||||||
|
|
||||||
DEFAULT_PROMPTS = [
|
DEFAULT_PROMPTS = [
|
||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
|
|||||||
NUM_TOP_LOGPROBS = 5
|
NUM_TOP_LOGPROBS = 5
|
||||||
|
|
||||||
|
|
||||||
def is_embedding_model(model_path):
|
|
||||||
# FIXME incomplete list
|
|
||||||
if "e5-mistral-7b-instruct" in model_path.lower():
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_dtype_str(torch_dtype):
|
def get_dtype_str(torch_dtype):
|
||||||
if torch_dtype is torch.float16:
|
if torch_dtype is torch.float16:
|
||||||
return "float16"
|
return "float16"
|
||||||
@@ -60,7 +54,7 @@ class HFRunner:
|
|||||||
self,
|
self,
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
is_embedding_model=None,
|
is_generation_model=None,
|
||||||
):
|
):
|
||||||
self.in_queue = multiprocessing.Queue()
|
self.in_queue = multiprocessing.Queue()
|
||||||
self.out_queue = multiprocessing.Queue()
|
self.out_queue = multiprocessing.Queue()
|
||||||
@@ -72,13 +66,13 @@ class HFRunner:
|
|||||||
self.out_queue,
|
self.out_queue,
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype,
|
torch_dtype,
|
||||||
is_embedding_model,
|
is_generation_model,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.model_proc.start()
|
self.model_proc.start()
|
||||||
|
|
||||||
def start_model_process(
|
def start_model_process(
|
||||||
self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
|
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
|
||||||
):
|
):
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
@@ -86,12 +80,12 @@ class HFRunner:
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.is_embedding_model = (
|
self.is_generation_model = (
|
||||||
is_embedding_model(model_path)
|
is_generation_model(model_path)
|
||||||
if is_embedding_model is None
|
if is_generation_model is None
|
||||||
else is_embedding_model
|
else is_generation_model
|
||||||
)
|
)
|
||||||
if not self.is_embedding_model:
|
if self.is_generation_model:
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -103,13 +97,13 @@ class HFRunner:
|
|||||||
|
|
||||||
self.model = SentenceTransformer(
|
self.model = SentenceTransformer(
|
||||||
model_path,
|
model_path,
|
||||||
device="cpu",
|
model_kwargs={"torch_dtype": torch_dtype},
|
||||||
).to(dtype=torch_dtype)
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
prompts, max_new_tokens = in_queue.get()
|
prompts, max_new_tokens = in_queue.get()
|
||||||
if prompts is not None:
|
if prompts is not None:
|
||||||
if not self.is_embedding_model:
|
if self.is_generation_model:
|
||||||
output_strs = []
|
output_strs = []
|
||||||
prefill_logprobs = []
|
prefill_logprobs = []
|
||||||
for p in prompts:
|
for p in prompts:
|
||||||
@@ -144,7 +138,6 @@ class HFRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(prompts, List[str])
|
|
||||||
logits = self.model.encode(prompts).tolist()
|
logits = self.model.encode(prompts).tolist()
|
||||||
|
|
||||||
out_queue.put(ModelOutput(embed_logits=logits))
|
out_queue.put(ModelOutput(embed_logits=logits))
|
||||||
@@ -175,16 +168,13 @@ class SRTRunner:
|
|||||||
model_path,
|
model_path,
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
is_embedding_model=None,
|
is_generation_model=None,
|
||||||
):
|
):
|
||||||
self.is_embedding_model = (
|
self.is_generation_model = (
|
||||||
is_embedding_model(model_path)
|
is_generation_model(model_path)
|
||||||
if is_embedding_model is None
|
if is_generation_model is None
|
||||||
else is_embedding_model
|
else is_generation_model
|
||||||
)
|
)
|
||||||
if self.is_embedding_model:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
self.runtime = Runtime(
|
self.runtime = Runtime(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
@@ -196,38 +186,45 @@ class SRTRunner:
|
|||||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||||
max_new_tokens=64,
|
max_new_tokens=64,
|
||||||
):
|
):
|
||||||
# the return value contains logprobs from prefill
|
if self.is_generation_model:
|
||||||
output_strs = []
|
# the return value contains logprobs from prefill
|
||||||
top_input_logprobs = []
|
output_strs = []
|
||||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
top_input_logprobs = []
|
||||||
for prompt in prompts:
|
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||||
response = self.runtime.generate(
|
for prompt in prompts:
|
||||||
prompt,
|
response = self.runtime.generate(
|
||||||
sampling_params=sampling_params,
|
prompt,
|
||||||
return_logprob=True,
|
sampling_params=sampling_params,
|
||||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
return_logprob=True,
|
||||||
)
|
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||||
response = json.loads(response)
|
)
|
||||||
output_strs.append(response["text"])
|
response = json.loads(response)
|
||||||
top_input_logprobs.append(
|
output_strs.append(response["text"])
|
||||||
[
|
top_input_logprobs.append(
|
||||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
|
||||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
|
||||||
]
|
|
||||||
+ [
|
|
||||||
[
|
[
|
||||||
tup[0]
|
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||||
for tup in response["meta_info"]["output_top_logprobs"][0][
|
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||||
:NUM_TOP_LOGPROBS
|
]
|
||||||
|
+ [
|
||||||
|
[
|
||||||
|
tup[0]
|
||||||
|
for tup in response["meta_info"]["output_top_logprobs"][0][
|
||||||
|
:NUM_TOP_LOGPROBS
|
||||||
|
]
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
]
|
)
|
||||||
)
|
|
||||||
# print(response["meta_info"]["output_top_logprobs"][0])
|
|
||||||
|
|
||||||
return ModelOutput(
|
return ModelOutput(
|
||||||
output_strs=output_strs, top_input_logprobs=top_input_logprobs
|
output_strs=output_strs, top_input_logprobs=top_input_logprobs
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logits = []
|
||||||
|
for prompt in prompts:
|
||||||
|
response = self.runtime.encode(prompt)
|
||||||
|
response = json.loads(response)
|
||||||
|
logits.append(response["embedding"])
|
||||||
|
return ModelOutput(embed_logits=logits)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.lang.backend.openai import OpenAI
|
from sglang.lang.backend.openai import OpenAI
|
||||||
@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
|||||||
print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
|
print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
|
||||||
|
|
||||||
return 0 if success else -1
|
return 0 if success else -1
|
||||||
|
|
||||||
|
|
||||||
|
def get_similarities(vec1, vec2):
|
||||||
|
return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
|
||||||
|
|||||||
69
test/srt/models/test_embedding_models.py
Normal file
69
test/srt/models/test_embedding_models.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||||
|
from sglang.test.test_utils import get_similarities
|
||||||
|
|
||||||
|
MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
|
||||||
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingModels(unittest.TestCase):
|
||||||
|
|
||||||
|
def assert_close_prefill_logits(
|
||||||
|
self,
|
||||||
|
prompts,
|
||||||
|
model_path,
|
||||||
|
tp_size,
|
||||||
|
torch_dtype,
|
||||||
|
) -> None:
|
||||||
|
with HFRunner(
|
||||||
|
model_path, torch_dtype=torch_dtype, is_generation_model=False
|
||||||
|
) as hf_runner:
|
||||||
|
hf_outputs = hf_runner.forward(prompts)
|
||||||
|
|
||||||
|
with SRTRunner(
|
||||||
|
model_path,
|
||||||
|
tp_size=tp_size,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
is_generation_model=False,
|
||||||
|
) as srt_runner:
|
||||||
|
srt_outputs = srt_runner.forward(prompts)
|
||||||
|
|
||||||
|
for i in range(len(prompts)):
|
||||||
|
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
|
||||||
|
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
|
||||||
|
|
||||||
|
similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
|
||||||
|
|
||||||
|
tolerance = 1e-2
|
||||||
|
assert torch.all(
|
||||||
|
abs(similarities - 1) < tolerance
|
||||||
|
), f"embeddings not all close"
|
||||||
|
|
||||||
|
def test_prefill_logits(self):
|
||||||
|
for model, tp_size in MODELS:
|
||||||
|
for torch_dtype in TORCH_DTYPES:
|
||||||
|
self.assert_close_prefill_logits(
|
||||||
|
DEFAULT_PROMPTS, model, tp_size, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(warnings="ignore")
|
||||||
@@ -3,7 +3,9 @@ Copyright 2023-2024 SGLang Team
|
|||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
@@ -33,7 +35,7 @@ class TestCausalModels(unittest.TestCase):
|
|||||||
torch_dtype,
|
torch_dtype,
|
||||||
) -> None:
|
) -> None:
|
||||||
with HFRunner(
|
with HFRunner(
|
||||||
model_path, torch_dtype=torch_dtype, is_embedding_model=False
|
model_path, torch_dtype=torch_dtype, is_generation_model=True
|
||||||
) as hf_runner:
|
) as hf_runner:
|
||||||
hf_outputs = hf_runner.forward(prompts)
|
hf_outputs = hf_runner.forward(prompts)
|
||||||
|
|
||||||
@@ -41,7 +43,7 @@ class TestCausalModels(unittest.TestCase):
|
|||||||
model_path,
|
model_path,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
is_embedding_model=False,
|
is_generation_model=True,
|
||||||
) as srt_runner:
|
) as srt_runner:
|
||||||
srt_outputs = srt_runner.forward(prompts)
|
srt_outputs = srt_runner.forward(prompts)
|
||||||
|
|
||||||
@@ -10,7 +10,8 @@ suites = {
|
|||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
"test_torch_compile.py",
|
"test_torch_compile.py",
|
||||||
"models/test_causal_models.py",
|
"models/test_generation_models.py",
|
||||||
|
"models/test_embedding_models.py",
|
||||||
"sampling/penaltylib",
|
"sampling/penaltylib",
|
||||||
],
|
],
|
||||||
"sampling/penaltylib": glob.glob(
|
"sampling/penaltylib": glob.glob(
|
||||||
|
|||||||
Reference in New Issue
Block a user