Add io struct for embedding models [unreachable code] - step 2/3 (#987)
This commit is contained in:
@@ -22,6 +22,8 @@ import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
|
||||
@@ -166,6 +168,56 @@ class TokenizedGenerateReqInput:
|
||||
stream: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
text: Optional[Union[List[str], str]] = None
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The request id.
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Dummy sampling params for compatibility
|
||||
sampling_params: Union[List[Dict], Dict] = None
|
||||
|
||||
def post_init(self):
|
||||
if (self.text is None and self.input_ids is None) or (
|
||||
self.text is not None and self.input_ids is not None
|
||||
):
|
||||
raise ValueError("Either text or input_ids should be provided.")
|
||||
|
||||
if self.text is not None:
|
||||
is_single = isinstance(self.text, str)
|
||||
else:
|
||||
is_single = isinstance(self.input_ids[0], int)
|
||||
self.is_single = is_single
|
||||
|
||||
if is_single:
|
||||
if self.rid is None:
|
||||
self.rid = uuid.uuid4().hex
|
||||
self.sampling_params = {"max_new_tokens": 0}
|
||||
else:
|
||||
# support select operation
|
||||
self.batch_size = (
|
||||
len(self.text) if self.text is not None else len(self.input_ids)
|
||||
)
|
||||
if self.rid is None:
|
||||
self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
|
||||
else:
|
||||
if not isinstance(self.rid, list):
|
||||
raise ValueError("The rid should be a list.")
|
||||
self.sampling_params = [
|
||||
{"max_new_tokens": 0} for _ in range(self.batch_size)
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizedEmbeddingReqInput:
|
||||
rid: str
|
||||
input_text: str
|
||||
input_ids: List[int]
|
||||
sampling_params: SamplingParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenIDOut:
|
||||
rids: List[str]
|
||||
@@ -187,6 +239,14 @@ class BatchStrOut:
|
||||
finished_reason: List[BaseFinishReason]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchEmbeddingOut:
|
||||
rids: List[str]
|
||||
embeddings: List[List[float]]
|
||||
meta_info: List[Dict]
|
||||
finished_reason: List[BaseFinishReason]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlushCacheReq:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user