2024-11-22 22:16:53 +08:00
|
|
|
# Copyright 2023-2024 SGLang Team
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ==============================================================================
|
2024-06-08 02:06:52 -07:00
|
|
|
"""
|
|
|
|
|
The definition of objects transferred between different
|
|
|
|
|
processes (TokenizerManager, DetokenizerManager, Controller).
|
|
|
|
|
"""
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
import uuid
|
2024-09-15 06:36:06 -07:00
|
|
|
from dataclasses import dataclass
|
2024-10-11 17:34:25 +08:00
|
|
|
from enum import Enum
|
2024-11-25 12:32:51 -08:00
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-12-29 05:30:27 +08:00
|
|
|
import torch
|
|
|
|
|
|
2024-07-29 23:04:48 -07:00
|
|
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
2024-08-21 16:48:24 -07:00
|
|
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GenerateReqInput:
    """The input of a generation request.

    Each field may hold a single value or a batch (list) of values.
    `normalize_batch_and_arguments` must be called once after construction;
    it validates the input, derives `is_single` / `batch_size` /
    `parallel_sample_num`, and fills in per-request defaults.
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # If return logprobs, the start location in the prompt for returning logprobs.
    # By default, this value is "-1", which means it will only return logprobs for output tokens.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

    # Session id info for continual prompting
    session: Optional[
        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
    ] = None

    def normalize_batch_and_arguments(self):
        """Validate the request and normalize single/batch arguments in place.

        Raises:
            ValueError: if not exactly one of text / input_ids / input_embeds
                is provided.
        """
        # Exactly one input kind must be provided. The previous check used
        # `and` across all three and therefore only rejected the all-three
        # case, silently accepting ambiguous requests with two inputs.
        num_provided = sum(
            x is not None for x in (self.text, self.input_ids, self.input_embeds)
        )
        if num_provided != 1:
            raise ValueError(
                "Either text, input_ids or input_embeds should be provided."
            )

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        elif self.input_ids is not None:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
        else:
            if isinstance(self.input_embeds[0][0], float):
                self.is_single = True
                self.batch_size = 1
            else:
                # Bug fix: is_single was previously left unset in this branch,
                # raising AttributeError below for batched input_embeds.
                self.is_single = False
                self.batch_size = len(self.input_embeds)

        # Handle parallel sampling
        # When parallel sampling is used, we always treat the input as a batch.
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in sample params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            if self.text is not None:
                self.text = [self.text]
            if self.input_ids is not None:
                self.input_ids = [self.input_ids]

        # Fill in default arguments
        if self.is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = -1
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
                # Expand parallel_sample_num
                num = self.batch_size * self.parallel_sample_num

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                # One fresh dict per request; `[{}] * num` would alias a single
                # dict so a mutation of one entry would leak into all others.
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num
            else:
                assert self.parallel_sample_num == 1

            if self.logprob_start_len is None:
                self.logprob_start_len = [-1] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num
            else:
                assert self.parallel_sample_num == 1

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
            else:
                assert self.parallel_sample_num == 1

    def regenerate_rid(self):
        """Assign a fresh request id and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as a single-request object."""
        return GenerateReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
        )
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
@dataclass
class TokenizedGenerateReqInput:
    """The tokenized version of GenerateReqInput for one request.

    Carries one already-tokenized generation request between processes
    (see the module docstring for the processes involved).
    """

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
    return_logprob: bool
    # If return logprobs, the start location in the prompt for returning logprobs.
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # Whether to stream output
    stream: bool

    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    # Session id info for continual prompting
    session_id: Optional[str] = None
    # The request id within the session that this request continues from
    # (NOTE(review): inferred from the name — confirm against the scheduler).
    session_rid: Optional[str] = None
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-08 00:52:31 -07:00
|
|
|
@dataclass
class EmbeddingReqInput:
    """The input of an embedding request.

    Each field may hold a single value or a batch (list) of values.
    `normalize_batch_and_arguments` must be called once after construction;
    it validates the input, derives `is_single` / `batch_size`, and fills
    in defaults (embedding requests always force `max_new_tokens = 0`).
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

    def normalize_batch_and_arguments(self):
        """Validate the request and normalize single/batch arguments in place.

        Raises:
            ValueError: if neither or both of text / input_ids are provided.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        # Derive the batch size
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)

        # Fill in default arguments
        if self.is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."

            if self.sampling_params is None:
                # One fresh dict per request; `[{}] * n` would alias a single
                # dict so a mutation of one entry would leak into all others.
                self.sampling_params = [{} for _ in range(self.batch_size)]
            elif not isinstance(self.sampling_params, list):
                # A single dict applies to the whole batch. Previously this
                # case was left unexpanded and crashed on integer indexing.
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(self.batch_size)
                ]
            # Force max_new_tokens = 0 for every entry, matching the
            # unconditional assignment in the single-request path above.
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0

    def regenerate_rid(self):
        """Assign a fresh request id and return it."""
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        """Return the i-th request of a normalized batch as a single-request object."""
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
|
2024-09-27 23:32:11 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TokenizedEmbeddingReqInput:
    """The tokenized version of EmbeddingReqInput for one request."""

    # The request id
    rid: str
    # The input text
    input_text: str
    # The input token ids
    input_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
|
|
|
|
|
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
@dataclass
class BatchTokenIDOut:
    """A batch of generation outputs at the token-id level.

    All list fields are parallel: index i of every list belongs to the
    request with id ``rids[i]``.
    """

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    # The version id to sync decode status with in detokenizer_manager
    vids: List[int]
    # Text decoded so far for each request (used for incremental decoding)
    decoded_texts: List[str]
    # Token ids pending detokenization
    decode_ids: List[int]
    # Per-request read offsets into the decode stream
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init`
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-06-12 14:39:12 +08:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
@dataclass
class BatchStrOut:
    """A batch of generation outputs at the detokenized-string level.

    All list fields are parallel: index i of every list belongs to the
    request with id ``rids[i]``.
    """

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[dict]
    # The output decoded strings
    output_strs: List[str]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
    output_token_logprobs_val: List[float]
    output_token_logprobs_idx: List[int]
    input_top_logprobs_val: List[List]
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    normalized_prompt_logprob: List[float]
|
2024-01-26 13:32:59 +08:00
|
|
|
|
|
|
|
|
|
2024-08-08 00:52:31 -07:00
|
|
|
@dataclass
class BatchEmbeddingOut:
    """A batch of embedding outputs; list fields are parallel by request."""

    # The request id
    rids: List[str]
    # The finish reason
    finished_reasons: List[BaseFinishReason]
    # The output embedding
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
|
2024-08-08 00:52:31 -07:00
|
|
|
|
|
|
|
|
|
2024-01-26 13:32:59 +08:00
|
|
|
@dataclass
class FlushCacheReq:
    """A control message requesting a cache flush; carries no payload."""

    pass
|
2024-02-06 12:24:55 -08:00
|
|
|
|
2024-02-06 13:27:46 -08:00
|
|
|
|
2024-08-20 13:48:24 -07:00
|
|
|
@dataclass
class UpdateWeightFromDiskReqInput:
    """Request to reload model weights from a path on disk."""

    # The model path with the new weights
    model_path: str
    # The format to load the weights
    load_format: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class UpdateWeightFromDiskReqOutput:
    """Result of an UpdateWeightFromDiskReqInput request."""

    # Whether the weight update succeeded
    success: bool
    # Human-readable status / error message
    message: str
|
|
|
|
|
|
|
|
|
|
|
2024-12-01 23:23:18 -08:00
|
|
|
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request to receive one named weight over a distributed group."""

    # The parameter name to update
    name: str
    # The dtype of the incoming tensor, as a string
    dtype: str
    # The shape of the incoming tensor
    shape: List[int]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class UpdateWeightsFromDistributedReqOutput:
    """Result of an UpdateWeightsFromDistributedReqInput request."""

    # Whether the weight update succeeded
    success: bool
    # Human-readable status / error message
    message: str
|
|
|
|
|
|
|
|
|
|
|
2024-12-29 05:30:27 +08:00
|
|
|
@dataclass
class UpdateWeightsFromTensorReqInput:
    """Request to update one named weight with an in-memory tensor."""

    # The parameter name to update
    name: str
    # The new weight values
    tensor: torch.Tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class UpdateWeightsFromTensorReqOutput:
    """Result of an UpdateWeightsFromTensorReqInput request."""

    # Whether the weight update succeeded
    success: bool
    # Human-readable status / error message
    message: str
|
|
|
|
|
|
|
|
|
|
|
2024-12-01 23:23:18 -08:00
|
|
|
@dataclass
class InitWeightsUpdateGroupReqInput:
    """Request to initialize a distributed process group for weight updates."""

    # The master address
    master_address: str
    # The master port
    master_port: int
    # The rank offset
    rank_offset: int
    # The world size
    world_size: int
    # The group name
    group_name: str = "weight_update_group"
    # The backend
    backend: str = "nccl"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class InitWeightsUpdateGroupReqOutput:
    """Result of an InitWeightsUpdateGroupReqInput request."""

    # Whether the group initialization succeeded
    success: bool
    # Human-readable status / error message
    message: str
|
|
|
|
|
|
|
|
|
|
|
2024-11-29 23:36:38 -08:00
|
|
|
@dataclass
class GetWeightsByNameReqInput:
    """Request to fetch (a truncated view of) one named weight."""

    # The parameter name to read
    name: str
    # Maximum number of elements to return
    truncate_size: int = 100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GetWeightsByNameReqOutput:
    """Result of a GetWeightsByNameReqInput request."""

    # The (truncated) parameter values
    parameter: list
|
|
|
|
|
|
|
|
|
|
|
2024-05-17 05:49:31 -07:00
|
|
|
@dataclass
class AbortReq:
    """Request to abort an in-flight request."""

    # The request id
    rid: str
|
2024-10-11 17:34:25 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProfileReq(Enum):
    """Control message toggling the profiler on or off."""

    START_PROFILE = 1
    STOP_PROFILE = 2
|
2024-10-23 00:02:29 -07:00
|
|
|
|
|
|
|
|
|
2024-11-20 00:36:53 -08:00
|
|
|
@dataclass
class OpenSessionReqInput:
    """Request to open a session for continual prompting."""

    # Capacity reserved for the session's string length
    # (NOTE(review): exact semantics depend on the session manager — confirm).
    capacity_of_str_len: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CloseSessionReqInput:
    """Request to close an open session."""

    # The id of the session to close
    session_id: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OpenSessionReqOutput:
    """Result of an OpenSessionReqInput request."""

    # The id assigned to the newly opened session
    session_id: str
|