Improve the code style: add more comments and remove unused imports (#1139)
This commit is contained in:
@@ -17,7 +17,6 @@ limitations under the License.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
@@ -126,8 +125,6 @@ class DetokenizerManager:
|
|||||||
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trim stop str
|
|
||||||
# TODO(lmzheng): handle the case where multiple stop strs are hit
|
|
||||||
output_strs = []
|
output_strs = []
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
s = self.decode_status[recv_obj.rids[i]]
|
s = self.decode_status[recv_obj.rids[i]]
|
||||||
@@ -144,6 +141,7 @@ class DetokenizerManager:
|
|||||||
|
|
||||||
output_strs.append(s.decoded_text + new_text)
|
output_strs.append(s.decoded_text + new_text)
|
||||||
|
|
||||||
|
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
|
||||||
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
|
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
|
||||||
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
|
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
|
||||||
if pos != -1:
|
if pos != -1:
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ import uuid
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
|
|
||||||
@@ -43,9 +41,9 @@ class GenerateReqInput:
|
|||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Whether to return logprobs.
|
# Whether to return logprobs.
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
# The start location of the prompt for return_logprob.
|
# If return_logprob is set, the start location in the prompt for returning logprobs.
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
# The number of top logprobs to return.
|
# If return_logprob is set, the number of top logprobs to return at each position.
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
# Whether to detokenize tokens into text in the returned logprobs.
|
# Whether to detokenize tokens into text in the returned logprobs.
|
||||||
return_text_in_logprobs: bool = False
|
return_text_in_logprobs: bool = False
|
||||||
@@ -155,16 +153,27 @@ class GenerateReqInput:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TokenizedGenerateReqInput:
|
class TokenizedGenerateReqInput:
|
||||||
|
# The request id
|
||||||
rid: str
|
rid: str
|
||||||
|
# The input text
|
||||||
input_text: str
|
input_text: str
|
||||||
|
# The input token ids
|
||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
|
# The pixel values for input images
|
||||||
pixel_values: List[float]
|
pixel_values: List[float]
|
||||||
|
# The hash of input images
|
||||||
image_hash: int
|
image_hash: int
|
||||||
|
# The image size
|
||||||
image_size: List[int]
|
image_size: List[int]
|
||||||
|
# The sampling parameters
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
|
# Whether to return the logprobs
|
||||||
return_logprob: bool
|
return_logprob: bool
|
||||||
|
# If return_logprob is set, the start location in the prompt for returning logprobs.
|
||||||
logprob_start_len: int
|
logprob_start_len: int
|
||||||
|
# If return_logprob is set, the number of top logprobs to return at each position.
|
||||||
top_logprobs_num: int
|
top_logprobs_num: int
|
||||||
|
# Whether to stream output
|
||||||
stream: bool
|
stream: bool
|
||||||
|
|
||||||
|
|
||||||
@@ -215,15 +224,21 @@ class EmbeddingReqInput:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TokenizedEmbeddingReqInput:
|
class TokenizedEmbeddingReqInput:
|
||||||
|
# The request id
|
||||||
rid: str
|
rid: str
|
||||||
|
# The input text
|
||||||
input_text: str
|
input_text: str
|
||||||
|
# The input token ids
|
||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenIDOut:
|
class BatchTokenIDOut:
|
||||||
|
# The request ids
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
|
# The version id to sync decode status with in detokenizer_manager
|
||||||
vids: List[int]
|
vids: List[int]
|
||||||
decoded_texts: List[str]
|
decoded_texts: List[str]
|
||||||
decode_ids: List[int]
|
decode_ids: List[int]
|
||||||
@@ -236,17 +251,25 @@ class BatchTokenIDOut:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchStrOut:
|
class BatchStrOut:
|
||||||
|
# The request ids
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
|
# The output decoded strings
|
||||||
output_strs: List[str]
|
output_strs: List[str]
|
||||||
|
# The meta info
|
||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
|
# The finish reasons
|
||||||
finished_reason: List[BaseFinishReason]
|
finished_reason: List[BaseFinishReason]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchEmbeddingOut:
|
class BatchEmbeddingOut:
|
||||||
|
# The request ids
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
|
# The output embeddings
|
||||||
embeddings: List[List[float]]
|
embeddings: List[List[float]]
|
||||||
|
# The meta info
|
||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
|
# The finish reasons
|
||||||
finished_reason: List[BaseFinishReason]
|
finished_reason: List[BaseFinishReason]
|
||||||
|
|
||||||
|
|
||||||
@@ -257,9 +280,5 @@ class FlushCacheReq:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AbortReq:
|
class AbortReq:
|
||||||
|
# The request id
|
||||||
rid: str
|
rid: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DetokenizeReqInput:
|
|
||||||
input_ids: List[int]
|
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ from typing import Dict, List, Optional, Union
|
|||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import psutil
|
|
||||||
import requests
|
import requests
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import uvloop
|
import uvloop
|
||||||
|
|||||||
Reference in New Issue
Block a user