diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md index cdc53da61..77d7c9f82 100644 --- a/docs/references/sampling_params.md +++ b/docs/references/sampling_params.md @@ -32,6 +32,20 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False + # Whether to log metrics for this request (e.g. health_generate calls do not log metrics) + log_metrics: bool = True + + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None + # LoRA related + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + + # Session info for continual prompting + session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None ``` The `sampling_params` follows this format @@ -90,6 +104,14 @@ repetition_penalty: float = 1.0, # difficult to infer the correct token ID by given `stop` strings. # Must be 0 <= value < max_new_tokens. Setting to 0 (default) will disable this penalty. min_new_tokens: int = 0, + + +## Custom Parameters for Custom Logit Processor. +# A dictionary of custom parameters for the custom logit processor. +# The custom logit processor takes a list of dictionaries as input, where each +# dictionary holds the custom parameters for one token in a batch of the input. +# See also python/sglang/srt/sampling/custom_logit_processor.py +custom_params: Optional[Dict[str, Any]] = None, ``` ## Examples @@ -253,3 +275,49 @@ response = requests.post( ) print(response.json()) ``` +### Custom Logit Processor +Launch a server with the `--enable-custom-logit-processor` flag enabled.
+``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor +``` + +Define a custom logit processor that will always sample a specific token id. +```python +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor + +class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. + """ + + def __call__(self, logits, custom_param_list): + # Check that the number of logits matches the number of custom parameters + assert logits.shape[0] == len(custom_param_list) + key = "token_id" + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits +``` + +Send a request: +```python +import requests + +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "The capital of France is", + "custom_logit_processor": DeterministicLogitProcessor().to_str(), + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": 32, + "custom_params": {"token_id": 5}, + }, + }, +) +print(response.json()) +``` diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 918323983..eee9b6722 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -69,8 +69,10 @@ class GenerateReqInput: # Session info for continual prompting session_params: Optional[Union[List[Dict], Dict]] = None - # Custom logit processor (serialized function) - custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string.
+ custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None def normalize_batch_and_arguments(self): if ( @@ -248,8 +250,9 @@ class TokenizedGenerateReqInput: # Session info for continual prompting session_params: Optional[SessionParams] = None - # Custom logit processor (serialized function) - # TODO (hpguo): Add an example and update doc string here + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. custom_logit_processor: Optional[str] = None