Let bench_one_batch_server use sharegpt data to make expert distribution more natural (#5573)
@@ -471,6 +471,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 def get_tokenizer(
     pretrained_model_name_or_path: str,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    assert (
+        pretrained_model_name_or_path is not None
+        and pretrained_model_name_or_path != ""
+    )
     if pretrained_model_name_or_path.endswith(
         ".json"
     ) or pretrained_model_name_or_path.endswith(".model"):
@@ -832,6 +836,7 @@ def sample_random_requests(
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
     random_sample: bool = True,
+    return_text: bool = True,
 ) -> List[DatasetRow]:
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
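The new flag defaults to True, so existing callers keep receiving decoded text. A hypothetical call for the token-ID path might look like the sketch below; only tokenizer, dataset_path, random_sample, and return_text appear in this diff, while the remaining argument names (input_len, output_len, num_prompts, range_ratio) are inferred from the hunk bodies and may differ in the real signature:

    rows = sample_random_requests(
        input_len=1024,
        output_len=128,
        num_prompts=64,
        range_ratio=0.5,
        tokenizer=tokenizer,
        dataset_path="ShareGPT_V3_unfiltered_cleaned_split.json",
        random_sample=True,
        return_text=False,  # keep raw token IDs instead of decoded text
    )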
@@ -892,10 +897,12 @@ def sample_random_requests(
             else:
                 ratio = (input_lens[i] + prompt_len - 1) // prompt_len
                 input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
-            prompt = tokenizer.decode(input_ids)
+            input_content = input_ids
+            if return_text:
+                input_content = tokenizer.decode(input_content)
             input_requests.append(
                 DatasetRow(
-                    prompt=prompt,
+                    prompt=input_content,
                     prompt_len=int(input_lens[i]),
                     output_len=int(output_lens[i]),
                 )
@@ -905,15 +912,15 @@ def sample_random_requests(
         offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
         input_requests = []
         for i in range(num_prompts):
-            prompt = tokenizer.decode(
-                [
-                    (offsets[i] + i + j) % tokenizer.vocab_size
-                    for j in range(input_lens[i])
-                ]
-            )
+            input_content = [
+                (offsets[i] + i + j) % tokenizer.vocab_size
+                for j in range(input_lens[i])
+            ]
+            if return_text:
+                input_content = tokenizer.decode(input_content)
             input_requests.append(
                 DatasetRow(
-                    prompt=prompt,
+                    prompt=input_content,
                     prompt_len=int(input_lens[i]),
                     output_len=int(output_lens[i]),
                 )
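Taken together, the hunks thread one pattern through both sampling branches: build the prompt as raw token IDs and decode only when return_text is set. A minimal self-contained sketch of that pattern follows; build_row is a hypothetical helper, not part of the diff, and the DatasetRow dataclass simply mirrors the fields used above:

    from dataclasses import dataclass
    from typing import List, Union

    @dataclass
    class DatasetRow:
        prompt: Union[str, List[int]]  # decoded text, or raw token IDs
        prompt_len: int
        output_len: int

    def build_row(
        tokenizer,
        input_ids: List[int],
        output_len: int,
        return_text: bool = True,
    ) -> DatasetRow:
        # Keep the token IDs as-is; decode back to text only when asked.
        input_content: Union[str, List[int]] = input_ids
        if return_text:
            input_content = tokenizer.decode(input_content)
        return DatasetRow(
            prompt=input_content,
            prompt_len=len(input_ids),
            output_len=output_len,
        )

Returning IDs avoids a decode/re-encode round trip, so the tokens the server receives match the sampled ones exactly; presumably that is what lets bench_one_batch_server reproduce the natural expert-routing distribution the commit title describes.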