Simplify tokenizer manager (#1904)

Author: Lianmin Zheng
Date: 2024-11-03 08:38:26 -08:00
Committed by: GitHub
Parent: 916b3cdddc
Commit: c17c578108
11 changed files with 261 additions and 443 deletions


@@ -56,49 +56,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
-        self.is_single = False
+        # Derive the batch size
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.text)
         else:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_ids)
 
+        # Handle parallel sampling
+        # When parallel sampling is used, we always treat the input as a batch.
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            for sp in self.sampling_params:
-                # TODO cope with the case that the parallel_sample_num is different for different samples
-                assert self.parallel_sample_num == sp.get(
-                    "n", 1
-                ), "The parallel_sample_num should be the same for all samples in sample params."
+            assert all(
+                self.parallel_sample_num == sampling_params.get("n", 1)
+                for sampling_params in self.sampling_params
+            ), "The parallel_sample_num should be the same for all samples in sample params."
 
-        if self.parallel_sample_num > 1:
-            if self.is_single:
-                self.is_single = False
-                if self.text is not None:
-                    self.text = [self.text]
-                if self.input_ids is not None:
-                    self.input_ids = [self.input_ids]
+        if self.parallel_sample_num > 1 and self.is_single:
+            self.is_single = False
+            if self.text is not None:
+                self.text = [self.text]
+            if self.input_ids is not None:
+                self.input_ids = [self.input_ids]
 
+        # Fill in default arguments
         if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
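The renamed `normalize_batch_and_arguments` keeps the old semantics: a single prompt that requests parallel sampling (`n > 1`) is promoted to a batch, so later code only ever sees lists. A minimal sketch of that behavior, assuming the module path `sglang.srt.managers.io_struct` for the classes in this diff:

```python
# Sketch only: field names follow the diff above; the import path is assumed.
from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(text="Hello", sampling_params={"n": 2})
req.normalize_batch_and_arguments()

assert req.parallel_sample_num == 2  # derived from sampling_params["n"]
assert req.is_single is False        # promoted to a batch
assert req.text == ["Hello"]         # the single prompt is wrapped in a list
assert req.batch_size == 1           # still one original prompt
```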
@@ -114,8 +112,8 @@ class GenerateReqInput:
         if self.parallel_sample_num == 1:
             num = self.batch_size
         else:
-            # The first bs samples are used for caching the prefix for parallel sampling
-            num = self.batch_size + self.parallel_sample_num * self.batch_size
+            # Expand parallel_sample_num
+            num = self.batch_size * self.parallel_sample_num
 
         if self.image_data is None:
             self.image_data = [None] * num
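The old accounting reserved `batch_size` extra slots to cache the shared prefix during parallel sampling; the new code drops them, so a request expands to exactly one slot per generated sample:

```python
# Worked example of the slot count before and after this change.
batch_size, n = 2, 3
old_num = batch_size + n * batch_size  # 8: 2 prefix-cache slots + 6 samples
new_num = batch_size * n               # 6: one slot per generated sample
```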
@@ -128,14 +126,11 @@ class GenerateReqInput:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
-            else:
-                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
                 assert isinstance(self.rid, list), "The rid should be a list."
-                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
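In the batch branch, each scalar argument is broadcast to a list of length `num` so the manager can index every sub-request uniformly. Continuing the sketch above:

```python
# Sketch only: defaults are broadcast to lists of length num.
req = GenerateReqInput(text=["a", "b"], sampling_params={"n": 2})
req.normalize_batch_and_arguments()

num = req.batch_size * req.parallel_sample_num  # 2 * 2 = 4
assert len(req.rid) == num                      # one fresh rid per sample
assert req.return_logprob == [False] * num      # default filled in
assert req.sampling_params == [{"n": 2}] * num  # scalar dict broadcast
```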
@@ -158,6 +153,26 @@ class GenerateReqInput:
         else:
             assert self.parallel_sample_num == 1
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return GenerateReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i],
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+            return_logprob=self.return_logprob[i],
+            logprob_start_len=self.logprob_start_len[i],
+            top_logprobs_num=self.top_logprobs_num[i],
+            return_text_in_logprobs=self.return_text_in_logprobs,
+            stream=self.stream,
+            modalities=self.modalities[i] if self.modalities else None,
+            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+        )
+
 
 @dataclass
 class TokenizedGenerateReqInput:
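The new `regenerate_rid` and `__getitem__` carry the simplification: rather than maintaining separate batch code paths, the tokenizer manager can slice one sub-request out of a normalized batch and reuse the single-request path. A hypothetical caller in that style:

```python
# Hypothetical caller: iterate a normalized batch item by item.
batch = GenerateReqInput(text=["a", "b"], sampling_params={"temperature": 0.0})
batch.normalize_batch_and_arguments()

for i in range(batch.batch_size):
    single = batch[i]        # a GenerateReqInput holding only item i
    single.regenerate_rid()  # give the sub-request a fresh request id
    # ... tokenize `single` and forward it to the scheduler ...
```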
@@ -195,20 +210,29 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
+        # Derive the batch size
         if self.text is not None:
-            self.is_single = isinstance(self.text, str)
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
-            self.is_single = isinstance(self.input_ids[0], int)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)
 
+        # Fill in default arguments
         if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
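`EmbeddingReqInput` now derives `batch_size` with the same branching as `GenerateReqInput` instead of computing only `is_single`. A sketch under the same import assumption:

```python
# Sketch only: batch-size derivation for embedding requests.
from sglang.srt.managers.io_struct import EmbeddingReqInput

single = EmbeddingReqInput(text="query")
single.normalize_batch_and_arguments()
assert single.is_single and single.batch_size == 1

batch = EmbeddingReqInput(text=["doc one", "doc two"])
batch.normalize_batch_and_arguments()
assert not batch.is_single and batch.batch_size == 2
```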
@@ -216,20 +240,28 @@ class EmbeddingReqInput:
                 self.sampling_params = {}
             self.sampling_params["max_new_tokens"] = 1
         else:
-            # support select operation
-            self.batch_size = (
-                len(self.text) if self.text is not None else len(self.input_ids)
-            )
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
                 self.sampling_params[i]["max_new_tokens"] = 1
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return EmbeddingReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+        )
 
 @dataclass
 class TokenizedEmbeddingReqInput:
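The same slicing pattern applies to embedding batches, so generate and embedding requests can flow through one code path:

```python
# Sketch only: slice one embedding sub-request out of a normalized batch.
batch = EmbeddingReqInput(text=["doc one", "doc two"])
batch.normalize_batch_and_arguments()

first = batch[0]                  # EmbeddingReqInput for "doc one" only
assert first.text == "doc one"
assert first.rid == batch.rid[0]  # keeps the rid assigned to item 0
```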
@@ -243,56 +275,6 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
 
 
-RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
-
-
-@dataclass
-class RewardReqInput:
-    # The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string.
-    conv: RewardReqConv
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
-    # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
-    def post_init(self):
-        self.is_single = isinstance(self.conv[0], dict)
-
-        if self.is_single:
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
-        else:
-            # support select operation
-            self.batch_size = len(self.conv)
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * self.batch_size
-            for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
-
-
-@dataclass
-class TokenizedRewardReqInput:
-    # The request id
-    rid: str
-    # The input text
-    input_text: str
-    # The input token ids
-    input_ids: List[int]
-    # Dummy sampling params for compatibility
-    sampling_params: SamplingParams
-
 
 @dataclass
 class BatchTokenIDOut:
     # The request id