Feat/refactor embedding server (#7322)
This commit is contained in:
@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
|
||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||
from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
add_prometheus_middleware,
|
||||
@@ -64,6 +65,7 @@ class AppState:
|
||||
server_args: Optional[ServerArgs] = None
|
||||
tokenizer_manager: Optional[TokenizerManager] = None
|
||||
scheduler_info: Optional[Dict] = None
|
||||
embedding_server: Optional[OpenAIServingEmbedding] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
|
||||
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
|
||||
app.state.tokenizer_manager = tokenizer_manager
|
||||
app.state.scheduler_info = scheduler_info
|
||||
app.state.serving_embedding = OpenAIServingEmbedding(
|
||||
tokenizer_manager=tokenizer_manager
|
||||
)
|
||||
|
||||
if server_args.enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):
|
||||
|
||||
@app.post("/v1/embeddings")
|
||||
async def openai_v1_embeddings(raw_request: Request):
|
||||
pass
|
||||
try:
|
||||
request_json = await raw_request.json()
|
||||
request = EmbeddingRequest(**request_json)
|
||||
except Exception as e:
|
||||
return app.state.serving_embedding.create_error_response(
|
||||
f"Invalid request body, error: {str(e)}"
|
||||
)
|
||||
|
||||
ret = await app.state.serving_embedding.handle_request(request, raw_request)
|
||||
return ret
|
||||
|
||||
|
||||
@app.post("/v1/score")
|
||||
|
||||
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
|
||||
|
||||
# Convert to internal format
|
||||
adapted_request, processed_request = self._convert_to_internal_request(
|
||||
[request], [self._generate_request_id_base(request)]
|
||||
request, self._generate_request_id_base(request)
|
||||
)
|
||||
|
||||
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
|
||||
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
|
||||
@abstractmethod
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
all_requests: List[OpenAIServingRequest],
|
||||
request_ids: List[str],
|
||||
request: OpenAIServingRequest,
|
||||
request_id: str,
|
||||
) -> tuple[
|
||||
GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
|
||||
]:
|
||||
|
||||
@@ -71,111 +71,61 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
||||
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
all_requests: List[EmbeddingRequest],
|
||||
request_ids: List[str],
|
||||
request: EmbeddingRequest,
|
||||
request_id: str,
|
||||
) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
|
||||
"""Convert OpenAI embedding request to internal format"""
|
||||
prompts = [request.input for request in all_requests]
|
||||
|
||||
# Handle single vs multiple requests
|
||||
if len(all_requests) == 1:
|
||||
prompt = prompts[0]
|
||||
if isinstance(prompt, str):
|
||||
# Single string input
|
||||
prompt = request.input
|
||||
if isinstance(prompt, str):
|
||||
# Single string input
|
||||
prompt_kwargs = {"text": prompt}
|
||||
elif isinstance(prompt, list):
|
||||
if len(prompt) > 0 and isinstance(prompt[0], str):
|
||||
# List of strings
|
||||
prompt_kwargs = {"text": prompt}
|
||||
elif isinstance(prompt, list):
|
||||
if len(prompt) > 0 and isinstance(prompt[0], str):
|
||||
# List of strings
|
||||
prompt_kwargs = {"text": prompt}
|
||||
elif len(prompt) > 0 and isinstance(
|
||||
prompt[0], MultimodalEmbeddingInput
|
||||
):
|
||||
# Handle multimodal embedding inputs
|
||||
texts = []
|
||||
images = []
|
||||
for item in prompt:
|
||||
# Use padding for text if None - this could be improved
|
||||
texts.append(item.text if item.text is not None else "padding")
|
||||
images.append(item.image if item.image is not None else None)
|
||||
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
|
||||
# Handle multimodal embedding inputs
|
||||
texts = []
|
||||
images = []
|
||||
for item in prompt:
|
||||
# Use padding for text if None - this could be improved
|
||||
texts.append(item.text if item.text is not None else "padding")
|
||||
images.append(item.image if item.image is not None else None)
|
||||
|
||||
generate_prompts = []
|
||||
# Check if we have a chat template for multimodal embeddings
|
||||
# This would need to be passed in from the server configuration
|
||||
chat_template_name = getattr(
|
||||
self.tokenizer_manager, "chat_template_name", None
|
||||
)
|
||||
if chat_template_name is not None:
|
||||
convs = generate_embedding_convs(
|
||||
texts, images, chat_template_name
|
||||
)
|
||||
for conv in convs:
|
||||
generate_prompts.append(conv.get_prompt())
|
||||
else:
|
||||
generate_prompts = texts
|
||||
|
||||
if len(generate_prompts) == 1:
|
||||
prompt_kwargs = {
|
||||
"text": generate_prompts[0],
|
||||
"image_data": images[0],
|
||||
}
|
||||
else:
|
||||
prompt_kwargs = {
|
||||
"text": generate_prompts,
|
||||
"image_data": images,
|
||||
}
|
||||
generate_prompts = []
|
||||
# Check if we have a chat template for multimodal embeddings
|
||||
# This would need to be passed in from the server configuration
|
||||
chat_template_name = getattr(
|
||||
self.tokenizer_manager, "chat_template_name", None
|
||||
)
|
||||
if chat_template_name is not None:
|
||||
convs = generate_embedding_convs(texts, images, chat_template_name)
|
||||
for conv in convs:
|
||||
generate_prompts.append(conv.get_prompt())
|
||||
else:
|
||||
# List of integers (token IDs) or empty list
|
||||
prompt_kwargs = {"input_ids": prompt}
|
||||
generate_prompts = texts
|
||||
|
||||
if len(generate_prompts) == 1:
|
||||
prompt_kwargs = {
|
||||
"text": generate_prompts[0],
|
||||
"image_data": images[0],
|
||||
}
|
||||
else:
|
||||
prompt_kwargs = {
|
||||
"text": generate_prompts,
|
||||
"image_data": images,
|
||||
}
|
||||
else:
|
||||
# Other types (should not happen but handle gracefully)
|
||||
# List of integers (token IDs) or empty list
|
||||
prompt_kwargs = {"input_ids": prompt}
|
||||
# Use the passed request_ids for single request
|
||||
final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
|
||||
else:
|
||||
# Handle batch requests
|
||||
if len(prompts) > 0:
|
||||
# Validate that all prompts have the same type
|
||||
first_prompt = prompts[0]
|
||||
first_type = type(first_prompt)
|
||||
for i, prompt in enumerate(prompts[1:], 1):
|
||||
if type(prompt) != first_type:
|
||||
raise AssertionError(
|
||||
f"All prompts in batch must have the same type, but prompt at index {i} has different type"
|
||||
)
|
||||
|
||||
if isinstance(first_prompt, str):
|
||||
# Batch of strings
|
||||
prompt_kwargs = {"text": prompts}
|
||||
elif isinstance(first_prompt, list):
|
||||
if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
|
||||
# Batch of lists of strings
|
||||
prompt_kwargs = {"text": prompts}
|
||||
elif len(first_prompt) > 0 and isinstance(
|
||||
first_prompt[0], MultimodalEmbeddingInput
|
||||
):
|
||||
# Handle multimodal batch requests
|
||||
raise NotImplementedError(
|
||||
"Multiple requests with multimodal inputs are not supported yet"
|
||||
)
|
||||
else:
|
||||
# Batch of token ID lists
|
||||
prompt_kwargs = {"input_ids": prompts}
|
||||
else:
|
||||
# Other types
|
||||
prompt_kwargs = {"input_ids": prompts}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompts}
|
||||
# Use the passed request_ids for batch requests
|
||||
final_request_id = request_ids
|
||||
|
||||
# Other types (should not happen but handle gracefully)
|
||||
prompt_kwargs = {"input_ids": prompt}
|
||||
adapted_request = EmbeddingReqInput(
|
||||
rid=final_request_id,
|
||||
**prompt_kwargs,
|
||||
)
|
||||
|
||||
return adapted_request, (
|
||||
all_requests[0] if len(all_requests) == 1 else all_requests
|
||||
)
|
||||
return adapted_request, request
|
||||
|
||||
async def _handle_non_streaming_request(
|
||||
self,
|
||||
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
||||
if not isinstance(ret, list):
|
||||
ret = [ret]
|
||||
|
||||
response = self._build_embedding_response(
|
||||
ret, self.tokenizer_manager.model_path
|
||||
)
|
||||
response = self._build_embedding_response(ret)
|
||||
return response
|
||||
|
||||
def _build_embedding_response(
|
||||
self, ret: List[Dict[str, Any]], model_path: str
|
||||
) -> EmbeddingResponse:
|
||||
def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
|
||||
"""Build the embedding response"""
|
||||
embedding_objects = []
|
||||
prompt_tokens = 0
|
||||
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
||||
|
||||
return EmbeddingResponse(
|
||||
data=embedding_objects,
|
||||
model=model_path,
|
||||
model=self.tokenizer_manager.model_path,
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
total_tokens=prompt_tokens,
|
||||
|
||||
Reference in New Issue
Block a user