diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py index 91c1d9e31..e1c1f7270 100644 --- a/python/sglang/srt/entrypoints/grpc_request_manager.py +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns. """ import asyncio +import copy import dataclasses import logging import os @@ -11,6 +12,7 @@ import signal import sys import threading import time +import uuid from typing import Any, Dict, List, Optional, Union import grpc @@ -79,11 +81,9 @@ class GrpcReqState: last_completion_tokens: int = 1 # Streaming state - last_output_offset: int = 0 stream_finished: bool = False - # Output accumulation - text: str = "" + # Token accumulation (for non-streaming) output_ids: List[int] = dataclasses.field(default_factory=list) input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) @@ -139,8 +139,6 @@ class GrpcRequestManager: self.is_pause_cond = asyncio.Condition() # Metrics - self.request_counter = 0 - self.request_counter_lock = asyncio.Lock() self.last_receive_tstamp = time.time() # Crash dump for debugging @@ -158,22 +156,133 @@ class GrpcRequestManager: obj: TokenizedGenerateReqInput, request_id: Optional[str] = None, grpc_context: Optional[grpc.aio.ServicerContext] = None, - ) -> asyncio.Queue: + ): """ - Submit a generation request to the scheduler. - Returns a queue for streaming outputs. + Submit a generation request to the scheduler with n>1 parallel sampling support. + + This method implements the same two-phase approach as tokenizer_manager.py: + 1. Phase 1: Send prefix caching request (max_new_tokens=0) + 2. Phase 2: Send n generation requests that reuse the cached prefix + + Yields individual responses for streaming, or aggregated responses for non-streaming. """ + n = getattr(obj.sampling_params, "n", 1) + + if n <= 1: + async for response in self._handle_single_request( + obj, request_id, grpc_context + ): + yield response + return + + # N>1 handling - two-phase approach + logger.debug(f"Multiple sampling request (n={n}), using two-phase approach") + + # Generate base request ID if not provided + if request_id is None: + base_request_id = f"grpc-{uuid.uuid4().hex}" + else: + base_request_id = request_id + + # Phase 1: Cache the common prefix + logger.debug(f"Phase 1: Caching prefix for request {base_request_id}") + prefix_obj = copy.copy(obj) + prefix_obj.sampling_params = copy.copy(obj.sampling_params) + prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only + prefix_obj.sampling_params.n = 1 # Don't replicate prefix request + + # Send prefix caching request and consume response + async for _ in self._handle_single_request( + prefix_obj, f"{base_request_id}-prefix", grpc_context + ): + # Consume prefix response (usually just one chunk with finish_reason) + pass + + logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}") + + # Phase 2: Generate n parallel requests + logger.debug(f"Phase 2: Generating {n} parallel requests") + generators = [] + request_ids = [] + + for i in range(n): + # Create individual generation request + gen_obj = copy.copy(obj) + gen_obj.sampling_params = copy.copy(obj.sampling_params) + gen_obj.sampling_params.n = 1 # Each request generates 1 response + + gen_request_id = f"{base_request_id}-{i}" + request_ids.append(gen_request_id) + + # Start generation request + generators.append( + self._handle_single_request(gen_obj, gen_request_id, grpc_context) + ) + + # Handle response aggregation + is_stream = getattr(obj, "stream", False) + + if not is_stream: + # Non-streaming: collect all responses and return as batch + logger.debug(f"Non-streaming mode: collecting {n} responses") + responses = [] + for generator in generators: + async for response in generator: + responses.append(response) + yield responses # Return all responses as a batch + else: + # Streaming mode: multiplex responses with index for ordering + logger.debug(f"Streaming mode: multiplexing {n} streams") + rid_to_index = {rid: i for i, rid in enumerate(request_ids)} + + # Create async tasks for all generators + task_map = {} + for generator in generators: + task = asyncio.create_task(generator.__anext__()) + task_map[task] = generator + + # Process responses as they arrive + while task_map: + done, _ = await asyncio.wait( + task_map.keys(), return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + generator = task_map.pop(task) + try: + response = await task + + # Add index for client-side ordering + if isinstance(response, dict) and "meta_info" in response: + response_rid = response["meta_info"].get("id", "") + if response_rid in rid_to_index: + response["index"] = rid_to_index[response_rid] + + yield response + + # Create next task for this generator + next_task = asyncio.create_task(generator.__anext__()) + task_map[next_task] = generator + + except StopAsyncIteration: + # This generator is finished + pass + + async def _handle_single_request( + self, + obj: TokenizedGenerateReqInput, + request_id: Optional[str] = None, + grpc_context: Optional[grpc.aio.ServicerContext] = None, + ): + """Handle a single request - core implementation without n>1 logic.""" # Generate request ID if not provided if request_id is None: - async with self.request_counter_lock: - request_id = f"grpc-{self.request_counter}" - self.request_counter += 1 + request_id = f"grpc-{uuid.uuid4().hex}" obj.rid = request_id + # Create and register request state # TODO: support log_request - - # Create request state state = GrpcReqState( request_id=request_id, grpc_context=grpc_context, @@ -189,19 +298,51 @@ class GrpcRequestManager: state.session_id = obj.session_params.session_id state.is_session_request = True - # Register state self.rid_to_state[request_id] = state self.record_request_for_crash_dump(obj) - # Send to scheduler via ZMQ try: + # Send to scheduler - let exceptions bubble up to grpc_server.py await self._send_to_scheduler(obj) - except Exception as e: - # Clean up on failure - del self.rid_to_state[request_id] - raise RuntimeError(f"Failed to send request to scheduler: {e}") - return state.out_queue + is_stream = getattr(obj, "stream", False) + + while True: + # Client cancelled - notify scheduler and exit + if grpc_context and grpc_context.cancelled(): + await self.abort_request(request_id) + return + + try: + response = await asyncio.wait_for(state.out_queue.get(), timeout=4) + + if is_stream: + yield response + + # Non-streaming: yield final response with accumulated tokens from state + if isinstance(response, dict) and response.get("finished", False): + if not is_stream: + final_response = response.copy() + final_response["token_ids"] = state.output_ids + yield final_response + break + + except asyncio.TimeoutError: + # Timeout waiting for response - abort and cleanup + logger.warning( + f"Timeout waiting for response for request {request_id}" + ) + await self.abort_request(request_id) + return + + finally: + # Always clean up request state when exiting + self._cleanup_request_state(request_id) + + def _cleanup_request_state(self, request_id: str): + """Clean up local request state (does not notify scheduler).""" + if request_id in self.rid_to_state: + del self.rid_to_state[request_id] async def embedding_request( self, @@ -214,9 +355,7 @@ class GrpcRequestManager: """ # Generate request ID if not provided if request_id is None: - async with self.request_counter_lock: - request_id = f"grpc-embed-{self.request_counter}" - self.request_counter += 1 + request_id = f"grpc-embed-{uuid.uuid4().hex}" obj.rid = request_id @@ -355,7 +494,6 @@ class GrpcRequestManager: # Extract output for this request output_data = { "request_id": rid, - "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "", "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [], "finished": batch_out.finished_reasons[i] is not None, "meta_info": { @@ -367,6 +505,9 @@ class GrpcRequestManager: if batch_out.completion_tokens else 0 ), + "cached_tokens": ( + batch_out.cached_tokens[i] if batch_out.cached_tokens else 0 + ), "finish_reason": ( str(batch_out.finished_reasons[i]) if batch_out.finished_reasons[i] @@ -389,15 +530,10 @@ class GrpcRequestManager: ), } - # Update state - if output_data["text"]: - state.text += output_data["text"][state.last_output_offset :] - state.last_output_offset = len(output_data["text"]) - + # Update state for accumulation if output_data["token_ids"]: state.output_ids.extend(output_data["token_ids"]) - # Send to output queue await state.out_queue.put(output_data) # Handle completion diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index fa1b1143d..2c6bf62c7 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -181,20 +181,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) # Convert gRPC request to internal format tokenized_req = self._convert_generate_request(request) - # Submit to request manager - output_queue = await self.request_manager.generate_request( + # Submit to request manager (automatically handles n>1) + response_generator = self.request_manager.generate_request( obj=tokenized_req, request_id=request.request_id, grpc_context=context, ) - # Stream outputs - while True: - try: - # Get output with timeout - output = await asyncio.wait_for(output_queue.get(), timeout=4) - - # Check for errors + async for output in response_generator: + # Handle batch responses (for n>1 non-streaming) + if isinstance(output, list): + for batch_output in output: + if "error" in batch_output: + yield sglang_scheduler_pb2.GenerateResponse( + request_id=request.request_id, + error=sglang_scheduler_pb2.GenerateError( + message=batch_output["error"], + http_status_code=( + "500" if "abort" not in batch_output else "499" + ), + ), + ) + else: + # All non-error batch outputs are final responses + yield self._create_completion_response( + request.request_id, batch_output + ) + else: + # Handle single response (for streaming or n=1 non-streaming) if "error" in output: yield sglang_scheduler_pb2.GenerateResponse( request_id=request.request_id, @@ -205,27 +219,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ), ), ) - break - - # Check if finished - if output.get("finished", False): - # Send completion + elif output.get("finished", False): yield self._create_completion_response( request.request_id, output ) - break else: - # Send chunk yield self._create_chunk_response(request.request_id, output) - except asyncio.TimeoutError: - # Check if context is still active - if context.cancelled(): - # Abort the request - await self.request_manager.abort_request(request.request_id) - break - continue - except Exception as e: logger.error(f"Generate failed: {e}\n{get_exception_traceback()}") yield sglang_scheduler_pb2.GenerateResponse( @@ -403,7 +403,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) return_logprob=grpc_req.return_logprob, logprob_start_len=grpc_req.logprob_start_len or -1, top_logprobs_num=grpc_req.top_logprobs_num or 0, - stream=True, # Always stream for gRPC + stream=grpc_req.stream or False, lora_path=grpc_req.lora_id if grpc_req.lora_id else None, token_ids_logprob=( list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None @@ -480,10 +480,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) return sglang_scheduler_pb2.GenerateResponse( request_id=request_id, chunk=sglang_scheduler_pb2.GenerateStreamChunk( - token_id=output["token_ids"][-1] if output.get("token_ids") else 0, + token_ids=output.get("token_ids", []), prompt_tokens=meta_info.get("prompt_tokens", 0), completion_tokens=meta_info.get("completion_tokens", 0), - cached_tokens=0, + cached_tokens=meta_info.get("cached_tokens", 0), ), ) diff --git a/python/sglang/srt/grpc/sglang_scheduler.proto b/python/sglang/srt/grpc/sglang_scheduler.proto index e4638e7a9..f52f50d2a 100644 --- a/python/sglang/srt/grpc/sglang_scheduler.proto +++ b/python/sglang/srt/grpc/sglang_scheduler.proto @@ -122,6 +122,9 @@ message GenerateRequest { // For load balancing int32 dp_balance_id = 17; + + // Whether client wants streaming response + bool stream = 18; } message TokenizedInput { @@ -163,8 +166,8 @@ message GenerateResponse { } message GenerateStreamChunk { - // Generated token - int32 token_id = 1; + // Generated tokens (incremental chunk) + repeated int32 token_ids = 1; // Cumulative counts int32 prompt_tokens = 2; diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py index 8b05bf3fc..1142104aa 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.py +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py @@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc9\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x16\n\x0emax_new_tokens\x18\x08 \x01(\x05\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\x05\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraint\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\x05\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\x05\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xba\x01\n\x13GenerateStreamChunk\x12\x10\n\x08token_id\x18\x01 \x01(\x05\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12\x31\n\x08logprobs\x18\x05 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\"\x81\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\x05\x12K\n\rfinish_reason\x18\x02 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x35\n\x0c\x61ll_logprobs\x18\x06 \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc9\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x16\n\x0emax_new_tokens\x18\x08 \x01(\x05\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\x05\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraint\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xf9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\x05\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\x12\x0e\n\x06stream\x18\x12 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\x05\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xbb\x01\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12\x31\n\x08logprobs\x18\x05 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\"\x81\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\x05\x12K\n\rfinish_reason\x18\x02 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x35\n\x0c\x61ll_logprobs\x18\x06 \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -45,65 +45,65 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_DISAGGREGATEDPARAMS']._serialized_start=828 _globals['_DISAGGREGATEDPARAMS']._serialized_end=921 _globals['_GENERATEREQUEST']._serialized_start=924 - _globals['_GENERATEREQUEST']._serialized_end=1541 - _globals['_TOKENIZEDINPUT']._serialized_start=1543 - _globals['_TOKENIZEDINPUT']._serialized_end=1601 - _globals['_MULTIMODALINPUTS']._serialized_start=1604 - _globals['_MULTIMODALINPUTS']._serialized_end=1815 - _globals['_GENERATERESPONSE']._serialized_start=1818 - _globals['_GENERATERESPONSE']._serialized_end=2045 - _globals['_GENERATESTREAMCHUNK']._serialized_start=2048 - _globals['_GENERATESTREAMCHUNK']._serialized_end=2234 - _globals['_GENERATECOMPLETE']._serialized_start=2237 - _globals['_GENERATECOMPLETE']._serialized_end=2622 - _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2546 - _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2622 - _globals['_GENERATEERROR']._serialized_start=2624 - _globals['_GENERATEERROR']._serialized_end=2699 - _globals['_LOGPROBS']._serialized_start=2702 - _globals['_LOGPROBS']._serialized_end=2834 - _globals['_TOPLOGPROBS']._serialized_start=2836 - _globals['_TOPLOGPROBS']._serialized_end=2905 - _globals['_HIDDENSTATES']._serialized_start=2907 - _globals['_HIDDENSTATES']._serialized_end=2970 - _globals['_EMBEDREQUEST']._serialized_start=2973 - _globals['_EMBEDREQUEST']._serialized_end=3303 - _globals['_EMBEDRESPONSE']._serialized_start=3306 - _globals['_EMBEDRESPONSE']._serialized_end=3463 - _globals['_EMBEDCOMPLETE']._serialized_start=3466 - _globals['_EMBEDCOMPLETE']._serialized_end=3629 - _globals['_EMBEDDING']._serialized_start=3631 - _globals['_EMBEDDING']._serialized_end=3673 - _globals['_EMBEDERROR']._serialized_start=3675 - _globals['_EMBEDERROR']._serialized_end=3735 - _globals['_HEALTHCHECKREQUEST']._serialized_start=3737 - _globals['_HEALTHCHECKREQUEST']._serialized_end=3815 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=3817 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=3872 - _globals['_ABORTREQUEST']._serialized_start=3874 - _globals['_ABORTREQUEST']._serialized_end=3924 - _globals['_ABORTRESPONSE']._serialized_start=3926 - _globals['_ABORTRESPONSE']._serialized_end=3975 - _globals['_LOADLORAREQUEST']._serialized_start=3977 - _globals['_LOADLORAREQUEST']._serialized_end=4050 - _globals['_LOADLORARESPONSE']._serialized_start=4052 - _globals['_LOADLORARESPONSE']._serialized_end=4124 - _globals['_UNLOADLORAREQUEST']._serialized_start=4126 - _globals['_UNLOADLORAREQUEST']._serialized_end=4165 - _globals['_UNLOADLORARESPONSE']._serialized_start=4167 - _globals['_UNLOADLORARESPONSE']._serialized_end=4221 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4223 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4342 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4344 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4401 - _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4403 - _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4448 - _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4450 - _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4516 - _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4518 - _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4583 - _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4585 - _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4645 - _globals['_SGLANGSCHEDULER']._serialized_start=4648 - _globals['_SGLANGSCHEDULER']._serialized_end=5030 + _globals['_GENERATEREQUEST']._serialized_end=1557 + _globals['_TOKENIZEDINPUT']._serialized_start=1559 + _globals['_TOKENIZEDINPUT']._serialized_end=1617 + _globals['_MULTIMODALINPUTS']._serialized_start=1620 + _globals['_MULTIMODALINPUTS']._serialized_end=1831 + _globals['_GENERATERESPONSE']._serialized_start=1834 + _globals['_GENERATERESPONSE']._serialized_end=2061 + _globals['_GENERATESTREAMCHUNK']._serialized_start=2064 + _globals['_GENERATESTREAMCHUNK']._serialized_end=2251 + _globals['_GENERATECOMPLETE']._serialized_start=2254 + _globals['_GENERATECOMPLETE']._serialized_end=2639 + _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2563 + _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2639 + _globals['_GENERATEERROR']._serialized_start=2641 + _globals['_GENERATEERROR']._serialized_end=2716 + _globals['_LOGPROBS']._serialized_start=2719 + _globals['_LOGPROBS']._serialized_end=2851 + _globals['_TOPLOGPROBS']._serialized_start=2853 + _globals['_TOPLOGPROBS']._serialized_end=2922 + _globals['_HIDDENSTATES']._serialized_start=2924 + _globals['_HIDDENSTATES']._serialized_end=2987 + _globals['_EMBEDREQUEST']._serialized_start=2990 + _globals['_EMBEDREQUEST']._serialized_end=3320 + _globals['_EMBEDRESPONSE']._serialized_start=3323 + _globals['_EMBEDRESPONSE']._serialized_end=3480 + _globals['_EMBEDCOMPLETE']._serialized_start=3483 + _globals['_EMBEDCOMPLETE']._serialized_end=3646 + _globals['_EMBEDDING']._serialized_start=3648 + _globals['_EMBEDDING']._serialized_end=3690 + _globals['_EMBEDERROR']._serialized_start=3692 + _globals['_EMBEDERROR']._serialized_end=3752 + _globals['_HEALTHCHECKREQUEST']._serialized_start=3754 + _globals['_HEALTHCHECKREQUEST']._serialized_end=3832 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=3834 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=3889 + _globals['_ABORTREQUEST']._serialized_start=3891 + _globals['_ABORTREQUEST']._serialized_end=3941 + _globals['_ABORTRESPONSE']._serialized_start=3943 + _globals['_ABORTRESPONSE']._serialized_end=3992 + _globals['_LOADLORAREQUEST']._serialized_start=3994 + _globals['_LOADLORAREQUEST']._serialized_end=4067 + _globals['_LOADLORARESPONSE']._serialized_start=4069 + _globals['_LOADLORARESPONSE']._serialized_end=4141 + _globals['_UNLOADLORAREQUEST']._serialized_start=4143 + _globals['_UNLOADLORAREQUEST']._serialized_end=4182 + _globals['_UNLOADLORARESPONSE']._serialized_start=4184 + _globals['_UNLOADLORARESPONSE']._serialized_end=4238 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4240 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4359 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4361 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4418 + _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4420 + _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4465 + _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4467 + _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4533 + _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4535 + _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4600 + _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4602 + _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4662 + _globals['_SGLANGSCHEDULER']._serialized_start=4665 + _globals['_SGLANGSCHEDULER']._serialized_end=5047 # @@protoc_insertion_point(module_scope) diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi index bf383f127..561e418e8 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi @@ -83,7 +83,7 @@ class DisaggregatedParams(_message.Message): def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ... class GenerateRequest(_message.Message): - __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id") + __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream") REQUEST_ID_FIELD_NUMBER: _ClassVar[int] TOKENIZED_FIELD_NUMBER: _ClassVar[int] MM_INPUTS_FIELD_NUMBER: _ClassVar[int] @@ -101,6 +101,7 @@ class GenerateRequest(_message.Message): LORA_ID_FIELD_NUMBER: _ClassVar[int] DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int] + STREAM_FIELD_NUMBER: _ClassVar[int] request_id: str tokenized: TokenizedInput mm_inputs: MultimodalInputs @@ -118,7 +119,8 @@ class GenerateRequest(_message.Message): lora_id: str data_parallel_rank: int dp_balance_id: int - def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ... + stream: bool + def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ... class TokenizedInput(_message.Message): __slots__ = ("original_text", "input_ids") @@ -161,20 +163,20 @@ class GenerateResponse(_message.Message): def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... class GenerateStreamChunk(_message.Message): - __slots__ = ("token_id", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states") - TOKEN_ID_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states") + TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] LOGPROBS_FIELD_NUMBER: _ClassVar[int] HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] - token_id: int + token_ids: _containers.RepeatedScalarFieldContainer[int] prompt_tokens: int completion_tokens: int cached_tokens: int logprobs: LogProbs hidden_states: _containers.RepeatedScalarFieldContainer[float] - def __init__(self, token_id: _Optional[int] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ... + def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ... class GenerateComplete(_message.Message): __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states") diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 0b87f85b3..d0f3c4c26 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -103,6 +103,7 @@ impl SglangSchedulerClient { logprob_start_len: -1, top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32, return_hidden_states: body.return_hidden_states, + stream: body.stream, ..Default::default() }; @@ -367,14 +368,14 @@ mod tests { #[test] fn test_generate_stream_chunk() { let chunk = proto::GenerateStreamChunk { - token_id: 1234, + token_ids: vec![1234, 5678], prompt_tokens: 5, completion_tokens: 2, cached_tokens: 3, ..Default::default() }; - assert_eq!(chunk.token_id, 1234); + assert_eq!(chunk.token_ids, vec![1234, 5678]); assert_eq!(chunk.prompt_tokens, 5); assert_eq!(chunk.completion_tokens, 2); assert_eq!(chunk.cached_tokens, 3); diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto index b0e8b92c5..2892caec2 100644 --- a/sgl-router/src/proto/sglang_scheduler.proto +++ b/sgl-router/src/proto/sglang_scheduler.proto @@ -122,6 +122,9 @@ message GenerateRequest { // For load balancing int32 dp_balance_id = 17; + + // Whether client wants streaming response + bool stream = 18; } message TokenizedInput { @@ -163,8 +166,8 @@ message GenerateResponse { } message GenerateStreamChunk { - // Generated token - int32 token_id = 1; + // Generated tokens (incremental chunk) + repeated int32 token_ids = 1; // Cumulative counts int32 prompt_tokens = 2; diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index f4f4337b3..f630100c6 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -203,6 +203,7 @@ impl GrpcRouter { debug!("Selected worker: {}", worker.url()); // Step 2: Get gRPC client for worker (fail fast if can't connect) + // TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here) let client = match self.get_or_create_grpc_client(worker.url()).await { Ok(c) => c, Err(e) => { @@ -249,7 +250,7 @@ impl GrpcRouter { // Step 6: Build the base gRPC request let request_id = format!("chatcmpl-{}", Uuid::new_v4()); - let base_request = match client.build_generate_request( + let request = match client.build_generate_request( request_id, body, processed_messages.text.clone(), @@ -268,11 +269,11 @@ impl GrpcRouter { } }; + // Step 7: Handle streaming vs non-streaming if body.stream { - self.handle_streaming_chat(client, base_request, body).await + self.handle_streaming_chat(client, request, body).await } else { - self.handle_non_streaming_chat(client, base_request, body) - .await + self.handle_non_streaming_chat(client, request, body).await } }