Feat/refactor embedding server (#7322)

This commit is contained in:
woodx
2025-06-20 14:46:01 +08:00
committed by GitHub
parent a06912ad8b
commit 4df5fc2156
4 changed files with 75 additions and 119 deletions

View File

@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
server_args: Optional[ServerArgs] = None
tokenizer_manager: Optional[TokenizerManager] = None
scheduler_info: Optional[Dict] = None
embedding_server: Optional[OpenAIServingEmbedding] = None
@asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
app.state.tokenizer_manager = tokenizer_manager
app.state.scheduler_info = scheduler_info
app.state.serving_embedding = OpenAIServingEmbedding(
tokenizer_manager=tokenizer_manager
)
if server_args.enable_metrics:
add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):
@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
pass
try:
request_json = await raw_request.json()
request = EmbeddingRequest(**request_json)
except Exception as e:
return app.state.serving_embedding.create_error_response(
f"Invalid request body, error: {str(e)}"
)
ret = await app.state.serving_embedding.handle_request(request, raw_request)
return ret
@app.post("/v1/score")

View File

@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
# Convert to internal format
adapted_request, processed_request = self._convert_to_internal_request(
[request], [self._generate_request_id_base(request)]
request, self._generate_request_id_base(request)
)
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
@abstractmethod
def _convert_to_internal_request(
self,
all_requests: List[OpenAIServingRequest],
request_ids: List[str],
request: OpenAIServingRequest,
request_id: str,
) -> tuple[
GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
]:

View File

@@ -71,111 +71,61 @@ class OpenAIServingEmbedding(OpenAIServingBase):
def _convert_to_internal_request(
self,
all_requests: List[EmbeddingRequest],
request_ids: List[str],
request: EmbeddingRequest,
request_id: str,
) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
"""Convert OpenAI embedding request to internal format"""
prompts = [request.input for request in all_requests]
# Handle single vs multiple requests
if len(all_requests) == 1:
prompt = prompts[0]
if isinstance(prompt, str):
# Single string input
prompt = request.input
if isinstance(prompt, str):
# Single string input
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings
prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(
prompt[0], MultimodalEmbeddingInput
):
# Handle multimodal embedding inputs
texts = []
images = []
for item in prompt:
# Use padding for text if None - this could be improved
texts.append(item.text if item.text is not None else "padding")
images.append(item.image if item.image is not None else None)
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
# Handle multimodal embedding inputs
texts = []
images = []
for item in prompt:
# Use padding for text if None - this could be improved
texts.append(item.text if item.text is not None else "padding")
images.append(item.image if item.image is not None else None)
generate_prompts = []
# Check if we have a chat template for multimodal embeddings
# This would need to be passed in from the server configuration
chat_template_name = getattr(
self.tokenizer_manager, "chat_template_name", None
)
if chat_template_name is not None:
convs = generate_embedding_convs(
texts, images, chat_template_name
)
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
generate_prompts = texts
if len(generate_prompts) == 1:
prompt_kwargs = {
"text": generate_prompts[0],
"image_data": images[0],
}
else:
prompt_kwargs = {
"text": generate_prompts,
"image_data": images,
}
generate_prompts = []
# Check if we have a chat template for multimodal embeddings
# This would need to be passed in from the server configuration
chat_template_name = getattr(
self.tokenizer_manager, "chat_template_name", None
)
if chat_template_name is not None:
convs = generate_embedding_convs(texts, images, chat_template_name)
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
# List of integers (token IDs) or empty list
prompt_kwargs = {"input_ids": prompt}
generate_prompts = texts
if len(generate_prompts) == 1:
prompt_kwargs = {
"text": generate_prompts[0],
"image_data": images[0],
}
else:
prompt_kwargs = {
"text": generate_prompts,
"image_data": images,
}
else:
# Other types (should not happen but handle gracefully)
# List of integers (token IDs) or empty list
prompt_kwargs = {"input_ids": prompt}
# Use the passed request_ids for single request
final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
else:
# Handle batch requests
if len(prompts) > 0:
# Validate that all prompts have the same type
first_prompt = prompts[0]
first_type = type(first_prompt)
for i, prompt in enumerate(prompts[1:], 1):
if type(prompt) != first_type:
raise AssertionError(
f"All prompts in batch must have the same type, but prompt at index {i} has different type"
)
if isinstance(first_prompt, str):
# Batch of strings
prompt_kwargs = {"text": prompts}
elif isinstance(first_prompt, list):
if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
# Batch of lists of strings
prompt_kwargs = {"text": prompts}
elif len(first_prompt) > 0 and isinstance(
first_prompt[0], MultimodalEmbeddingInput
):
# Handle multimodal batch requests
raise NotImplementedError(
"Multiple requests with multimodal inputs are not supported yet"
)
else:
# Batch of token ID lists
prompt_kwargs = {"input_ids": prompts}
else:
# Other types
prompt_kwargs = {"input_ids": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
# Use the passed request_ids for batch requests
final_request_id = request_ids
# Other types (should not happen but handle gracefully)
prompt_kwargs = {"input_ids": prompt}
adapted_request = EmbeddingReqInput(
rid=final_request_id,
**prompt_kwargs,
)
return adapted_request, (
all_requests[0] if len(all_requests) == 1 else all_requests
)
return adapted_request, request
async def _handle_non_streaming_request(
self,
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
if not isinstance(ret, list):
ret = [ret]
response = self._build_embedding_response(
ret, self.tokenizer_manager.model_path
)
response = self._build_embedding_response(ret)
return response
def _build_embedding_response(
self, ret: List[Dict[str, Any]], model_path: str
) -> EmbeddingResponse:
def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
"""Build the embedding response"""
embedding_objects = []
prompt_tokens = 0
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
return EmbeddingResponse(
data=embedding_objects,
model=model_path,
model=self.tokenizer_manager.model_path,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
total_tokens=prompt_tokens,