From 1aea19f64b06cee64368a6f0488af1fb2a39e328 Mon Sep 17 00:00:00 2001
From: Rin Intachuen <113603872+RinRin-32@users.noreply.github.com>
Date: Mon, 25 Nov 2024 19:35:04 -0500
Subject: [PATCH] Input_embeds support (#2052)

---
 docs/references/sampling_params.md            |  10 +-
 python/sglang/srt/managers/io_struct.py       |  31 ++++-
 python/sglang/srt/managers/schedule_batch.py  |  21 ++++
 python/sglang/srt/managers/scheduler.py       |   8 ++
 .../sglang/srt/managers/tokenizer_manager.py  |  17 ++-
 .../srt/model_executor/forward_batch_info.py  |   4 +
 .../sglang/srt/model_executor/model_runner.py |  14 ++-
 test/srt/run_suite.py                         |   1 +
 test/srt/test_input_embeddings.py             | 114 ++++++++++++++++++
 9 files changed, 205 insertions(+), 15 deletions(-)
 create mode 100644 test/srt/test_input_embeddings.py

diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md
index d144b059d..147e6c2ab 100644
--- a/docs/references/sampling_params.md
+++ b/docs/references/sampling_params.md
@@ -11,21 +11,23 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can specify either text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob.
+    # If returning logprobs, the start location in the prompt.
     # By default, this value is "-1", which means it will only return logprobs for output tokens.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return.
+    # If returning logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 9541b2d18..8b1f88fa2 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -29,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can specify either text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
@@ -60,10 +62,16 @@ class GenerateReqInput:
     ] = None
 
     def normalize_batch_and_arguments(self):
-        if (self.text is None and self.input_ids is None) or (
-            self.text is not None and self.input_ids is not None
+        if (
+            self.text is None and self.input_ids is None and self.input_embeds is None
+        ) or (
+            sum(
+                x is not None for x in (self.text, self.input_ids, self.input_embeds)
+            ) > 1
         ):
-            raise ValueError("Either text or input_ids should be provided.")
+            raise ValueError(
+                "Exactly one of text, input_ids, or input_embeds should be provided."
+            )
 
         # Derive the batch size
         if self.text is not None:
@@ -73,13 +81,22 @@ class GenerateReqInput:
         else:
             self.is_single = False
             self.batch_size = len(self.text)
-        else:
+            self.input_embeds = None
+        elif self.input_ids is not None:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
                 self.is_single = False
                 self.batch_size = len(self.input_ids)
+            self.input_embeds = None
+        else:
+            if isinstance(self.input_embeds[0][0], float):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_embeds)
 
         # Handle parallel sampling
         # When parallel sampling is used, we always treat the input as a batch.
@@ -202,6 +218,8 @@ class TokenizedGenerateReqInput:
 
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
+    # The input embeds
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # Session id info for continual prompting
     session_id: Optional[str] = None
 
@@ -218,6 +236,8 @@ class EmbeddingReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
+    # Dummy input embeds for compatibility
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 97dec49c2..4d1bbece2 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -178,6 +178,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info
@@ -191,6 +192,7 @@
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -448,6 +450,7 @@ class ScheduleBatch:
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -631,6 +634,9 @@
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        input_embeds = []
+
+        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -649,6 +655,11 @@
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
 
+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(
@@ -671,6 +682,12 @@
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
 
@@ -1053,6 +1070,7 @@
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )
 
     def copy(self):
@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The input embeds
+    input_embeds: Optional[torch.Tensor] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 2ae705422..0994aeb59 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -526,12 +526,20 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Check if input_embeds is present and create dummy input_ids
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer
             if recv_req.session_id is not None:
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index be58c939f..001ecc1eb 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -201,8 +201,18 @@ class TokenizerManager:
     ):
         """Tokenize one request."""
         # Tokenize
+        input_embeds = None
         input_text = obj.text
-        if obj.input_ids is None:
+        if obj.input_embeds is not None:
+            if not self.server_args.disable_radix_cache:
+                raise ValueError(
+                    "input_embeds is provided while disable_radix_cache is False. "
+                    "Please add `--disable-radix-cache` when you launch the server "
+                    "if you want to use input_embeds as inputs."
+                )
+            input_embeds = obj.input_embeds
+            input_ids = obj.input_ids
+        elif obj.input_ids is None:
             input_ids = self.tokenizer.encode(input_text)
         else:
             input_ids = obj.input_ids
@@ -219,7 +229,7 @@
         session_id = obj.session[0] if obj.session else None
         session_rid = obj.session[1] if obj.session else None
 
-        if len(input_ids) >= self.context_len:
+        if input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
@@ -242,7 +252,8 @@
             logprob_start_len,
             top_logprobs_num,
             obj.stream,
-            obj.lora_path,
+            lora_path=obj.lora_path,
+            input_embeds=input_embeds,
             session_id=session_id,
             session_rid=session_rid,
         )
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index e1e27752d..2fe841bb2 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -130,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None
 
+    # For input embeddings
+    input_embeds: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
 
@@ -231,6 +234,7 @@
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )
 
         if ret.global_num_tokens is not None:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index b83271f43..7c1c51a8f 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -606,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 27fe6d7d3..560c77c2a 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -14,6 +14,7 @@ suites = {
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
+        "test_input_embeddings.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_metrics.py",
diff --git a/test/srt/test_input_embeddings.py b/test/srt/test_input_embeddings.py
new file mode 100644
index 000000000..b57b61dad
--- /dev/null
+++ b/test/srt/test_input_embeddings.py
@@ -0,0 +1,114 @@
+import json
+import unittest
+
+import requests
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestInputEmbeds(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
+        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--disable-radix-cache"],
+        )
+        cls.texts = [
+            "The capital of France is",
+            "What is the best time of year to visit Japan for cherry blossoms?",
+        ]
+
+    def generate_input_embeddings(self, text):
+        """Generate input embeddings for a given text."""
+        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
+        embeddings = self.ref_model.get_input_embeddings()(input_ids)
+        return embeddings.squeeze().tolist()  # Convert tensor to a list for API use
+
+    def send_request(self, payload):
+        """Send a POST request to the API and return the response."""
+        response = requests.post(
+            self.base_url + "/generate",
+            json=payload,
+            timeout=30,  # Set a reasonable timeout for the API request
+        )
+        if response.status_code == 200:
+            return response.json()
+        return {
+            "error": f"Request failed with status {response.status_code}: {response.text}"
+        }
+
+    def test_text_based_response(self):
+        """Print API response using text-based input."""
+        for text in self.texts:
+            payload = {
+                "model": self.model,
+                "text": text,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            response = self.send_request(payload)
+            print(
+                f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
+            )
+
+    def test_embedding_based_response(self):
+        """Print API response using input embeddings."""
+        for text in self.texts:
+            embeddings = self.generate_input_embeddings(text)
+            payload = {
+                "model": self.model,
+                "input_embeds": embeddings,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            response = self.send_request(payload)
+            print(
+                f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
+            )
+
+    def test_compare_text_vs_embedding(self):
+        """Print responses for both text-based and embedding-based inputs."""
+        for text in self.texts:
+            # Text-based payload
+            text_payload = {
+                "model": self.model,
+                "text": text,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            # Embedding-based payload
+            embeddings = self.generate_input_embeddings(text)
+            embed_payload = {
+                "model": self.model,
+                "input_embeds": embeddings,
+                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
+            }
+            # Get responses
+            text_response = self.send_request(text_payload)
+            embed_response = self.send_request(embed_payload)
+            # Print responses
+            print(
+                f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
+            )
+            print(
+                f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
+            )
+            self.assertEqual(text_response["text"], embed_response["text"])
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid, include_self=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
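
Reviewer note: for anyone who wants to exercise the new path by hand, below is a minimal client-side sketch of the same flow the test drives: embed a prompt with a reference model's embedding table, then POST the nested list to `/generate` as `input_embeds`. The model name and server URL are placeholders, not part of this patch, and the server must be launched with `--disable-radix-cache`, since the tokenizer manager rejects `input_embeds` otherwise.

    # Minimal sketch; assumes an SGLang server is already running at
    # localhost:30000 (launched with --disable-radix-cache) and that MODEL
    # matches the model the server serves. The model name is a placeholder.
    import requests
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL = "meta-llama/Llama-3.2-1B"  # placeholder
    URL = "http://localhost:30000/generate"

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForCausalLM.from_pretrained(MODEL)

    # Turn a prompt into embeddings using the reference model's embedding table.
    input_ids = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
    with torch.no_grad():
        embeds = model.get_input_embeddings()(input_ids)  # [1, seq_len, hidden_size]

    # A single prompt is sent as a [seq_len, hidden_size] nested list.
    payload = {
        "input_embeds": embeds.squeeze(0).tolist(),
        "sampling_params": {"temperature": 0, "max_new_tokens": 50},
    }
    print(requests.post(URL, json=payload, timeout=30).json()["text"])

With a greedy sampling setup (temperature 0), this should print the same completion as sending the raw text, which is exactly what test_compare_text_vs_embedding asserts.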