diff --git a/python/sglang/environ.py b/python/sglang/environ.py
index ce2e39032..12470ba9a 100644
--- a/python/sglang/environ.py
+++ b/python/sglang/environ.py
@@ -197,6 +197,11 @@ class Envs:
     SGLANG_SYNC_TOKEN_IDS_ACROSS_TP = EnvBool(False)
     SGLANG_ENABLE_COLOCATED_BATCH_GEN = EnvBool(False)

+    # Deterministic inference
+    SGLANG_ENABLE_DETERMINISTIC_INFERENCE = EnvBool(False)
+    SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE = EnvInt(4096)
+    SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE = EnvInt(2048)
+
     # fmt: on

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index b761c8423..aaa8b520b 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -31,6 +31,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
 from sglang.srt.utils import (
+    get_int_env_var,
     is_flashinfer_available,
     is_sm100_supported,
     next_power_of_2,
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+
 if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
@@ -123,12 +125,33 @@ class FlashInferAttnBackend(AttentionBackend):
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        # When deterministic inference is enabled, tensor cores should be used for decode
+        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
+        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+        self.prefill_split_tile_size = None
+        self.decode_split_tile_size = None
+        self.disable_cuda_graph_kv_split = False
+        if self.enable_deterministic:
+            self.decode_use_tensor_cores = True
+            self.prefill_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+            self.decode_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
+            )
+            self.disable_cuda_graph_kv_split = True
+            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
+            global_workspace_size = global_config.flashinfer_workspace_size
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
@@ -219,6 +242,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=forward_batch.spec_info,
+                fixed_split_size=self.decode_split_tile_size,
+                disable_split_kv=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         elif forward_batch.forward_mode.is_draft_extend():
@@ -258,7 +283,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged = False
                 extend_no_prefix = False
             else:
-                use_ragged = True
+                use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

             self.indices_updater_prefill.update(
@@ -271,6 +296,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
+                fixed_split_size=self.prefill_split_tile_size,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -347,6 +373,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -439,6 +467,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
         elif forward_mode.is_target_verify():
             self.indices_updater_prefill.update(
@@ -646,6 +676,8 @@ class FlashInferIndicesUpdaterDecode:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -661,6 +693,8 @@ class FlashInferIndicesUpdaterDecode:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -672,6 +706,8 @@ class FlashInferIndicesUpdaterDecode:
             None,
             spec_info,
             seq_lens_cpu,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )

     def update_sliding_window(
@@ -685,6 +721,8 @@ class FlashInferIndicesUpdaterDecode:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         assert self.sliding_window_size is not None
         for wrapper_id in range(2):
@@ -735,6 +773,8 @@ class FlashInferIndicesUpdaterDecode:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -771,6 +811,8 @@ class FlashInferIndicesUpdaterDecode:
         ],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -825,6 +867,10 @@ class FlashInferIndicesUpdaterDecode:
                 data_type=self.data_type,
                 q_data_type=self.q_data_type,
                 non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
             )

         if locally_override:
@@ -876,6 +922,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

@@ -893,6 +940,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -916,6 +964,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.qo_indptr[0],
             use_ragged,
             spec_info,
+            fixed_split_size=fixed_split_size,
         )

     def update_sliding_window(
@@ -931,6 +980,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -979,6 +1029,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
+        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1024,6 +1075,7 @@ class FlashInferIndicesUpdaterPrefill:
             Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
         ],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1094,6 +1146,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_data_type=self.data_type,
             custom_mask=custom_mask,
             non_blocking=True,
+            fixed_split_size=fixed_split_size,
         )

@@ -1327,6 +1380,8 @@ def fast_decode_plan(
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
     non_blocking: bool = True,
+    fixed_split_size: Optional[int] = None,
+    disable_split_kv: bool = False,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
@@ -1352,6 +1407,9 @@ def fast_decode_plan(
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

+    # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
+    if fixed_split_size is None:
+        fixed_split_size = -1
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
@@ -1433,8 +1491,8 @@ def fast_decode_plan(
                     head_dim,
                     False,  # causal
                     window_left,
-                    -1,
-                    False,
+                    fixed_split_size,
+                    disable_split_kv,
                 )
             except Exception as e:
                 raise RuntimeError(f"Error in standard plan: {e}")
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 59489cdb8..b766ff60b 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -14,6 +14,7 @@
 """Fused operators for normalization layers."""

 import logging
+import os
 from typing import Optional, Tuple, Union

 import torch
@@ -80,6 +81,8 @@ class RMSNorm(CustomOp):
         )
         if _use_aiter:
             self._forward_method = self.forward_aiter
+        if os.environ.get("SGLANG_ENABLE_DETERMINISTIC_INFERENCE", "0") == "1":
+            self._forward_method = self.forward_native

     def forward_cuda(
         self,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 9eaebad7a..06dea43d4 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -111,6 +111,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_symm_mem",
     "enable_custom_logit_processor",
     "disaggregation_mode",
+    "enable_deterministic_inference",
 ]

 # Put some global args for easy access
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 3e8877faf..a59dffd75 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -541,7 +541,9 @@ class PrefillAdder:
         return self.budget_state()

-    def add_one_req(self, req: Req, has_chunked_req: bool):
+    def add_one_req(
+        self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
+    ):
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
@@ -600,6 +602,17 @@ class PrefillAdder:
             if trunc_len <= 0:
                 return AddReqResult.OTHER

+            # When truncation_align_size is set, the truncated (chunked) prefill length must be a multiple of it.
+            # A typical use case is deterministic inference with the flashinfer attention backend,
+            # which requires the prefill prefix length to be a multiple of the attention split size.
+            if truncation_align_size is not None:
+                if trunc_len < truncation_align_size:
+                    return AddReqResult.OTHER
+                else:
+                    trunc_len = truncation_align_size * (
+                        trunc_len // truncation_align_size
+                    )
+
             # Chunked prefill
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index fc5525afa..44dbc7d54 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -172,6 +172,7 @@ from sglang.srt.utils import (
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_int_env_var,
     get_zmq_socket,
     is_cpu,
     kill_itself_when_parent_died,
@@ -565,6 +566,17 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()

+        # Init the prefill KV split size when deterministic inference is enabled with the flashinfer attention backend
+        if (
+            self.server_args.enable_deterministic_inference
+            and self.server_args.attention_backend == "flashinfer"
+        ):
+            self.truncation_align_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+        else:
+            self.truncation_align_size = None
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -1846,7 +1858,11 @@ class Scheduler(
                 continue

             req.init_next_round_input(self.tree_cache)
-            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
+            res = adder.add_one_req(
+                req,
+                has_chunked_req=(self.chunked_req is not None),
+                truncation_align_size=self.truncation_align_size,
+            )

             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6384532cd..210b21349 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -406,6 +406,12 @@ class ModelRunner:
             )
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

+        # Enable batch invariant mode
+        if server_args.enable_deterministic_inference:
+            from batch_invariant_ops import enable_batch_invariant_mode
+
+            enable_batch_invariant_mode()
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 6ba8a7777..7fb48a286 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -75,6 +75,7 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         global_server_args_dict = cls._get_global_server_args_dict()
+        enable_deterministic = global_server_args_dict["enable_deterministic_inference"]
         reqs = batch.reqs
         device = batch.device
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c1fb677b3..331897ae4 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -118,6 +118,8 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]

 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]

+DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
+
 # Allow external code to add more choices
 def add_load_format_choices(choices):
@@ -437,6 +439,9 @@ class ServerArgs:
     max_mamba_cache_size: Optional[int] = None
     mamba_ssm_dtype: str = "float32"

+    # For deterministic inference
+    enable_deterministic_inference: bool = False
+
     # Deprecated arguments
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
@@ -980,6 +985,29 @@ class ServerArgs:
                 "Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
             )

+        # Deterministic inference
+        os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = (
+            "1" if self.enable_deterministic_inference else "0"
+        )
+        if self.enable_deterministic_inference:
+            # Check the batch_invariant_ops dependency
+            import importlib.util
+
+            if not importlib.util.find_spec("batch_invariant_ops"):
+                raise ValueError(
+                    "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
+                )
+
+            # Check related settings
+            self.disable_radix_cache = True
+            logger.warning(
+                "Radix cache is currently disabled for deterministic inference. It will be supported in the future."
+            )
+            if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
+                raise ValueError(
+                    f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
+                )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -2470,6 +2498,13 @@ class ServerArgs:
             help="Number of sm partition groups.",
         )

+        # For deterministic inference
+        parser.add_argument(
+            "--enable-deterministic-inference",
+            action="store_true",
+            help="Enable deterministic inference mode with batch invariant ops.",
+        )
+
         # Deprecated arguments
         parser.add_argument(
             "--enable-ep-moe",
diff --git a/python/sglang/test/test_deterministic.py b/python/sglang/test/test_deterministic.py
new file mode 100644
index 000000000..7404d201f
--- /dev/null
+++ b/python/sglang/test/test_deterministic.py
@@ -0,0 +1,283 @@
+"""
+Batch the same prompts across varying batch sizes, and test whether the results stay consistent across trials.
+
+Usage:
+python3 -m sglang.test.test_deterministic --n-trials <n> --test-mode <single|mixed|prefix> [--profile]
+"""
+
+import argparse
+import dataclasses
+import json
+import os
+import random
+from typing import List
+
+import requests
+
+from sglang.profiler import run_profile
+
+PROMPT_1 = "Tell me about Richard Feynman: "
+PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
+dirpath = os.path.dirname(__file__)
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+    LONG_PROMPT = f.read()
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    host: str = "localhost"
+    port: int = 30000
+    batch_size: int = 1
+    temperature: float = 0.0
+    max_new_tokens: int = 100
+    frequency_penalty: float = 0.0
+    presence_penalty: float = 0.0
+    return_logprob: bool = False
+    stream: bool = False
+    profile: bool = False
+    profile_steps: int = 3
+    profile_by_stage: bool = False
+    test_mode: str = "single"
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--host", type=str, default=BenchArgs.host)
+        parser.add_argument("--port", type=int, default=BenchArgs.port)
+        parser.add_argument("--n-trials", type=int, default=50)
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument(
+            "--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
+        )
+        parser.add_argument(
+            "--frequency-penalty", type=float, default=BenchArgs.frequency_penalty
+        )
+        parser.add_argument(
+            "--presence-penalty", type=float, default=BenchArgs.presence_penalty
+        )
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument("--stream", action="store_true")
+        parser.add_argument(
+            "--test-mode",
+            type=str,
+            default=BenchArgs.test_mode,
+            choices=["single", "mixed", "prefix"],
+        )
+        parser.add_argument("--profile", action="store_true")
+        parser.add_argument(
+            "--profile-steps", type=int, default=BenchArgs.profile_steps
+        )
+        parser.add_argument("--profile-by-stage", action="store_true")
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def send_single(
+    args,
+    batch_size: int,
+    profile: bool = False,
+    profile_steps: int = 3,
+    profile_by_stage: bool = False,
+):
+
+    base_url = f"http://{args.host}:{args.port}"
+    prompt = [PROMPT_1] * batch_size
+
+    json_data = {
+        "text": prompt,
+        "sampling_params": {
+            "temperature": args.temperature,
+            "max_new_tokens": args.max_new_tokens,
+            "frequency_penalty": args.frequency_penalty,
+            "presence_penalty": args.presence_penalty,
+        },
+        "return_logprob": args.return_logprob,
+        "stream": args.stream,
+    }
+
+    if profile:
+        run_profile(
+            base_url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
+        )
+
+    response = requests.post(
+        f"{base_url}/generate",
+        json=json_data,
+        stream=args.stream,
+    )
+
+    # Check the HTTP status before parsing, so an error body is printed instead of raising while indexing into it.
+    if response.status_code != 200:
+        print(response.json())
+        return -1
+
+    if args.stream:
+        for chunk in response.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                ret = json.loads(chunk[5:].strip("\n"))
+    else:
+        ret = response.json()
+        ret = ret[0]
+
+    return ret["text"]
+
+
+def send_mixed(args, batch_size: int):
+    num_long_prompt = 0 if batch_size <= 10 else random.randint(1, 10)
+    num_prompt_1 = random.randint(1, batch_size - num_long_prompt)
+    num_prompt_2 = batch_size - num_prompt_1 - num_long_prompt
+
+    json_data = {
+        "text": [PROMPT_1] * num_prompt_1
+        + [PROMPT_2] * num_prompt_2
+        + [LONG_PROMPT] * num_long_prompt,
+        "sampling_params": {
+            "temperature": args.temperature,
+            "max_new_tokens": args.max_new_tokens,
+            "frequency_penalty": args.frequency_penalty,
+            "presence_penalty": args.presence_penalty,
+        },
+        "return_logprob": args.return_logprob,
+        "stream": args.stream,
+    }
+
+    response = requests.post(
f"http://{args.host}:{args.port}/generate", + json=json_data, + stream=args.stream, + ) + ret = response.json() + if response.status_code != 200: + print(ret) + return -1, -1, -1 + + prompt_1_ret = [ret[i]["text"] for i in range(num_prompt_1)] + prompt_2_ret = [ + ret[i]["text"] for i in range(num_prompt_1, num_prompt_1 + num_prompt_2) + ] + long_prompt_ret = [ + ret[i]["text"] + for i in range( + num_prompt_1 + num_prompt_2, num_prompt_1 + num_prompt_2 + num_long_prompt + ) + ] + + return prompt_1_ret, prompt_2_ret, long_prompt_ret + + +def send_prefix(args, batch_size: int, prompts: List[str]): + requests.post(f"http://{args.host}:{args.port}/flush_cache") + + batch_data = [] + sampled_indices = [] + for _ in range(batch_size): + sampled_index = random.randint(0, len(prompts) - 1) + sampled_indices.append(sampled_index) + batch_data.append(prompts[sampled_index]) + + json_data = { + "text": batch_data, + "sampling_params": { + "temperature": args.temperature, + "max_new_tokens": args.max_new_tokens, + "frequency_penalty": args.frequency_penalty, + "presence_penalty": args.presence_penalty, + }, + "return_logprob": args.return_logprob, + "stream": args.stream, + } + + response = requests.post( + f"http://{args.host}:{args.port}/generate", + json=json_data, + stream=args.stream, + ) + ret = response.json() + if response.status_code != 200: + print(ret) + return -1, -1, -1 + + ret_dict = {i: [] for i in range(len(prompts))} + for i in range(batch_size): + ret_dict[sampled_indices[i]].append(ret[i]["text"]) + + return ret_dict + + +def test_deterministic(args): + # First do some warmups + for i in range(3): + send_single(args, 16, args.profile) + + if args.test_mode == "single": + # In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials. + texts = [] + for i in range(1, args.n_trials + 1): + batch_size = i + text = send_single(args, batch_size, args.profile) + text = text.replace("\n", " ") + print(f"Trial {i} with batch size {batch_size}: {text}") + texts.append(text) + + print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}") + elif args.test_mode == "mixed": + # In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials. + output_prompt_1 = [] + output_prompt_2 = [] + output_long_prompt = [] + for i in range(1, args.n_trials + 1): + batch_size = i + ret_prompt_1, ret_prompt_2, ret_long_prompt = send_mixed(args, batch_size) + output_prompt_1.extend(ret_prompt_1) + output_prompt_2.extend(ret_prompt_2) + output_long_prompt.extend(ret_long_prompt) + + print( + f"Testing Trial {i} with batch size {batch_size}, number of prompt 1: {len(ret_prompt_1)}, number of prompt 2: {len(ret_prompt_2)}, number of long prompt: {len(ret_long_prompt)}" + ) + + print( + f"Prompt 1: total samples: {len(output_prompt_1)}, Unique samples: {len(set(output_prompt_1))}" + ) + print( + f"Prompt 2: total samples: {len(output_prompt_2)}, Unique samples: {len(set(output_prompt_2))}" + ) + print( + f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}" + ) + + elif args.test_mode == "prefix": + # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix. 
+        len_prefix = [1, 511, 2048, 4097]
+        num_prompts = len(len_prefix)
+        outputs = {i: [] for i in range(num_prompts)}
+        prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(num_prompts)]
+        for i in range(1, args.n_trials + 1):
+            batch_size = i
+            ret_dict = send_prefix(args, batch_size, prompts)
+            msg = f"Testing Trial {i} with batch size {batch_size},"
+            for j in range(num_prompts):
+                msg += f" # prefix length {len_prefix[j]}: {len(ret_dict[j])},"
+            print(msg)
+            for j in range(num_prompts):
+                outputs[j].extend(ret_dict[j])
+
+        for i in range(num_prompts):
+            print(
+                f"Prompt {i} with prefix length {len_prefix[i]}: total samples: {len(outputs[i])}, Unique samples: {len(set(outputs[i]))}"
+            )
+
+    else:
+        raise ValueError(f"Invalid test mode: {args.test_mode}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    test_deterministic(args)
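
As a quick end-to-end illustration of what this patch targets (not part of the diff above), the following minimal sketch mirrors the single mode of sglang.test.test_deterministic: it sends the same greedy request at a few batch sizes and checks that the first completion never changes. The host, port, and launch command are assumptions; it presumes a server started with the flashinfer attention backend, the new --enable-deterministic-inference flag, and the batch_invariant_ops package installed.

# deterministic_smoke_check.py -- hypothetical helper, not part of this patch.
# Assumed launch command (model path elided):
#   python3 -m sglang.launch_server ... --attention-backend flashinfer --enable-deterministic-inference
# Host/port below are placeholders.
import requests

URL = "http://localhost:30000/generate"
PROMPT = "Tell me about Richard Feynman: "


def first_completion(batch_size: int) -> str:
    # Greedy sampling; only the first sequence in the batch is compared.
    payload = {
        "text": [PROMPT] * batch_size,
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 64},
    }
    resp = requests.post(URL, json=payload)
    resp.raise_for_status()
    return resp.json()[0]["text"]


if __name__ == "__main__":
    # With batch-invariant kernels and fixed attention split sizes, the same
    # prompt should decode identically no matter how many requests share the batch.
    outputs = {bs: first_completion(bs) for bs in (1, 8, 32)}
    assert len(set(outputs.values())) == 1, outputs
    print("Deterministic across batch sizes; sample output:", outputs[1][:80])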