diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py index 5d94fee6c..61c1af24f 100644 --- a/python/sglang/srt/entrypoints/grpc_request_manager.py +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -82,6 +82,7 @@ class GrpcReqState: # Streaming state stream_finished: bool = False + input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming # Token accumulation (for non-streaming) output_ids: List[int] = dataclasses.field(default_factory=list) @@ -516,19 +517,105 @@ class GrpcRequestManager: }, } - # Add logprobs if available + # Accumulate input logprobs (only once, usually in first chunk) + if batch_out.input_token_logprobs_val and i < len( + batch_out.input_token_logprobs_val + ): + if not state.input_token_logprobs_val: + state.input_token_logprobs_val.extend( + batch_out.input_token_logprobs_val[i] + ) + if batch_out.input_token_logprobs_idx and i < len( + batch_out.input_token_logprobs_idx + ): + state.input_token_logprobs_idx.extend( + batch_out.input_token_logprobs_idx[i] + ) + if batch_out.input_top_logprobs_val and i < len( + batch_out.input_top_logprobs_val + ): + state.input_top_logprobs_val.extend( + batch_out.input_top_logprobs_val[i] + ) + if batch_out.input_top_logprobs_idx and i < len( + batch_out.input_top_logprobs_idx + ): + state.input_top_logprobs_idx.extend( + batch_out.input_top_logprobs_idx[i] + ) + + # Send input logprobs based on mode + if state.input_token_logprobs_val: + if state.obj.stream and not state.input_logprobs_sent: + # Streaming: send input logprobs once in first chunk that has them + output_data["input_logprobs"] = { + "token_logprobs_val": state.input_token_logprobs_val, + "token_logprobs_idx": state.input_token_logprobs_idx, + "top_logprobs_val": state.input_top_logprobs_val, + "top_logprobs_idx": state.input_top_logprobs_idx, + } + state.input_logprobs_sent = True + elif not state.obj.stream and output_data["finished"]: + # Non-streaming: send input logprobs in final chunk + output_data["input_logprobs"] = { + "token_logprobs_val": state.input_token_logprobs_val, + "token_logprobs_idx": state.input_token_logprobs_idx, + "top_logprobs_val": state.input_top_logprobs_val, + "top_logprobs_idx": state.input_top_logprobs_idx, + } + + # Add output logprobs if available (RAW - no detokenization!) if batch_out.output_token_logprobs_val and i < len( batch_out.output_token_logprobs_val ): - output_data["logprobs"] = { - "tokens": batch_out.output_token_logprobs_val[i], - "top_logprobs": ( + # Accumulate in state first + state.output_token_logprobs_val.extend( + batch_out.output_token_logprobs_val[i] + ) + if batch_out.output_token_logprobs_idx and i < len( + batch_out.output_token_logprobs_idx + ): + state.output_token_logprobs_idx.extend( + batch_out.output_token_logprobs_idx[i] + ) + if batch_out.output_top_logprobs_val and i < len( + batch_out.output_top_logprobs_val + ): + state.output_top_logprobs_val.extend( batch_out.output_top_logprobs_val[i] - if batch_out.output_top_logprobs_val - and i < len(batch_out.output_top_logprobs_val) - else None - ), - } + ) + if batch_out.output_top_logprobs_idx and i < len( + batch_out.output_top_logprobs_idx + ): + state.output_top_logprobs_idx.extend( + batch_out.output_top_logprobs_idx[i] + ) + + if state.obj.stream: + # For streaming: send incremental logprobs (only new tokens in this chunk) + # NOTE: this is different than TokenizerManager, which always accumulates + def get_part(attr_name): + source_list = getattr(batch_out, attr_name, None) + return ( + source_list[i] + if source_list and i < len(source_list) + else [] + ) + + output_data["output_logprobs"] = { + "token_logprobs_val": batch_out.output_token_logprobs_val[i], + "token_logprobs_idx": get_part("output_token_logprobs_idx"), + "top_logprobs_val": get_part("output_top_logprobs_val"), + "top_logprobs_idx": get_part("output_top_logprobs_idx"), + } + elif output_data["finished"]: + # Non-streaming: send cumulative output logprobs in final chunk + output_data["output_logprobs"] = { + "token_logprobs_val": state.output_token_logprobs_val, + "token_logprobs_idx": state.output_token_logprobs_idx, + "top_logprobs_val": state.output_top_logprobs_val, + "top_logprobs_idx": state.output_top_logprobs_idx, + } # Update state for accumulation if output_data["token_ids"]: diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index b360c5068..b772f3067 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -472,11 +472,51 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ignore_eos=grpc_params.ignore_eos, ) + def _convert_logprobs_to_proto( + self, logprobs_data: Dict + ) -> Optional[sglang_scheduler_pb2.LogProbs]: + """Convert logprobs dict to proto LogProbs format (transport RAW data only).""" + if not logprobs_data: + return None + + token_logprobs_val = logprobs_data.get("token_logprobs_val", []) + token_logprobs_idx = logprobs_data.get("token_logprobs_idx", []) + top_logprobs_val = logprobs_data.get("top_logprobs_val", []) + top_logprobs_idx = logprobs_data.get("top_logprobs_idx", []) + + # Build TopLogProbs entries + top_logprobs_proto = [] + if top_logprobs_val and top_logprobs_idx: + for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx): + top_logprobs_proto.append( + sglang_scheduler_pb2.TopLogProbs( + values=val_list, + token_ids=idx_list, + ) + ) + + return sglang_scheduler_pb2.LogProbs( + token_logprobs=token_logprobs_val, + token_ids=token_logprobs_idx, + top_logprobs=top_logprobs_proto, + ) + def _create_chunk_response( self, request_id: str, output: Dict ) -> sglang_scheduler_pb2.GenerateResponse: """Create a streaming chunk response.""" meta_info = output.get("meta_info", {}) + + # Convert output logprobs if present + output_logprobs_proto = self._convert_logprobs_to_proto( + output.get("output_logprobs") + ) + + # Convert input logprobs if present (only in first chunk) + input_logprobs_proto = self._convert_logprobs_to_proto( + output.get("input_logprobs") + ) + return sglang_scheduler_pb2.GenerateResponse( request_id=request_id, chunk=sglang_scheduler_pb2.GenerateStreamChunk( @@ -484,6 +524,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) prompt_tokens=meta_info.get("prompt_tokens", 0), completion_tokens=meta_info.get("completion_tokens", 0), cached_tokens=meta_info.get("cached_tokens", 0), + output_logprobs=output_logprobs_proto, + input_logprobs=input_logprobs_proto, ), ) @@ -519,6 +561,16 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) elif isinstance(matched, str): matched_stop_kwargs["matched_stop_str"] = matched + # Convert output logprobs if present + output_logprobs_proto = self._convert_logprobs_to_proto( + output.get("output_logprobs") + ) + + # Convert input logprobs if present + input_logprobs_proto = self._convert_logprobs_to_proto( + output.get("input_logprobs") + ) + return sglang_scheduler_pb2.GenerateResponse( request_id=request_id, complete=sglang_scheduler_pb2.GenerateComplete( @@ -529,6 +581,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) "completion_tokens", len(output.get("token_ids", [])) ), cached_tokens=meta_info.get("cached_tokens", 0), + output_logprobs=output_logprobs_proto, + input_logprobs=input_logprobs_proto, **matched_stop_kwargs, ), ) diff --git a/python/sglang/srt/grpc/sglang_scheduler.proto b/python/sglang/srt/grpc/sglang_scheduler.proto index 521f43b43..be6508b5a 100644 --- a/python/sglang/srt/grpc/sglang_scheduler.proto +++ b/python/sglang/srt/grpc/sglang_scheduler.proto @@ -174,11 +174,14 @@ message GenerateStreamChunk { int32 completion_tokens = 3; int32 cached_tokens = 4; - // Logprobs (if requested) - LogProbs logprobs = 5; + // Output logprobs (if requested) - incremental for streaming + LogProbs output_logprobs = 5; // Hidden states (if requested) repeated float hidden_states = 6; + + // Input logprobs (if requested) - only in first chunk + LogProbs input_logprobs = 7; } message GenerateComplete { @@ -193,8 +196,8 @@ message GenerateComplete { int32 completion_tokens = 4; int32 cached_tokens = 5; - // All logprobs if requested - repeated LogProbs all_logprobs = 6; + // Output logprobs if requested (cumulative) + LogProbs output_logprobs = 6; // All hidden states if requested repeated HiddenStates all_hidden_states = 7; @@ -204,6 +207,9 @@ message GenerateComplete { uint32 matched_token_id = 8; string matched_stop_str = 9; } + + // Input logprobs if requested (for prompt tokens) + LogProbs input_logprobs = 10; } message GenerateError { @@ -218,15 +224,11 @@ message LogProbs { // Top logprobs at each position repeated TopLogProbs top_logprobs = 3; - - // Decoded text for tokens - repeated string token_texts = 4; } message TopLogProbs { repeated float values = 1; repeated int32 token_ids = 2; - repeated string token_texts = 3; } message HiddenStates { diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py index 0f12fcd41..2f80f83bb 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.py +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py @@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xf9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\x12\x0e\n\x06stream\x18\x12 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xbb\x01\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12\x31\n\x08logprobs\x18\x05 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\"\xc5\x02\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x35\n\x0c\x61ll_logprobs\x18\x06 \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x42\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xf9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\x12\x0e\n\x06stream\x18\x12 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xfb\x01\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12\x38\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12\x37\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\"\x81\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x38\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12\x37\n\x0einput_logprobs\x18\n \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbsB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"o\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -53,55 +53,55 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_GENERATERESPONSE']._serialized_start=1858 _globals['_GENERATERESPONSE']._serialized_end=2085 _globals['_GENERATESTREAMCHUNK']._serialized_start=2088 - _globals['_GENERATESTREAMCHUNK']._serialized_end=2275 - _globals['_GENERATECOMPLETE']._serialized_start=2278 - _globals['_GENERATECOMPLETE']._serialized_end=2603 - _globals['_GENERATEERROR']._serialized_start=2605 - _globals['_GENERATEERROR']._serialized_end=2680 - _globals['_LOGPROBS']._serialized_start=2683 - _globals['_LOGPROBS']._serialized_end=2815 - _globals['_TOPLOGPROBS']._serialized_start=2817 - _globals['_TOPLOGPROBS']._serialized_end=2886 - _globals['_HIDDENSTATES']._serialized_start=2888 - _globals['_HIDDENSTATES']._serialized_end=2951 - _globals['_EMBEDREQUEST']._serialized_start=2954 - _globals['_EMBEDREQUEST']._serialized_end=3284 - _globals['_EMBEDRESPONSE']._serialized_start=3287 - _globals['_EMBEDRESPONSE']._serialized_end=3444 - _globals['_EMBEDCOMPLETE']._serialized_start=3447 - _globals['_EMBEDCOMPLETE']._serialized_end=3610 - _globals['_EMBEDDING']._serialized_start=3612 - _globals['_EMBEDDING']._serialized_end=3654 - _globals['_EMBEDERROR']._serialized_start=3656 - _globals['_EMBEDERROR']._serialized_end=3716 - _globals['_HEALTHCHECKREQUEST']._serialized_start=3718 - _globals['_HEALTHCHECKREQUEST']._serialized_end=3796 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=3798 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=3853 - _globals['_ABORTREQUEST']._serialized_start=3855 - _globals['_ABORTREQUEST']._serialized_end=3905 - _globals['_ABORTRESPONSE']._serialized_start=3907 - _globals['_ABORTRESPONSE']._serialized_end=3956 - _globals['_LOADLORAREQUEST']._serialized_start=3958 - _globals['_LOADLORAREQUEST']._serialized_end=4031 - _globals['_LOADLORARESPONSE']._serialized_start=4033 - _globals['_LOADLORARESPONSE']._serialized_end=4105 - _globals['_UNLOADLORAREQUEST']._serialized_start=4107 - _globals['_UNLOADLORAREQUEST']._serialized_end=4146 - _globals['_UNLOADLORARESPONSE']._serialized_start=4148 - _globals['_UNLOADLORARESPONSE']._serialized_end=4202 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4204 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4323 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4325 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4382 - _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4384 - _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4429 - _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4431 - _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4497 - _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4499 - _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4564 - _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4566 - _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4626 - _globals['_SGLANGSCHEDULER']._serialized_start=4629 - _globals['_SGLANGSCHEDULER']._serialized_end=5011 + _globals['_GENERATESTREAMCHUNK']._serialized_end=2339 + _globals['_GENERATECOMPLETE']._serialized_start=2342 + _globals['_GENERATECOMPLETE']._serialized_end=2727 + _globals['_GENERATEERROR']._serialized_start=2729 + _globals['_GENERATEERROR']._serialized_end=2804 + _globals['_LOGPROBS']._serialized_start=2806 + _globals['_LOGPROBS']._serialized_end=2917 + _globals['_TOPLOGPROBS']._serialized_start=2919 + _globals['_TOPLOGPROBS']._serialized_end=2967 + _globals['_HIDDENSTATES']._serialized_start=2969 + _globals['_HIDDENSTATES']._serialized_end=3032 + _globals['_EMBEDREQUEST']._serialized_start=3035 + _globals['_EMBEDREQUEST']._serialized_end=3365 + _globals['_EMBEDRESPONSE']._serialized_start=3368 + _globals['_EMBEDRESPONSE']._serialized_end=3525 + _globals['_EMBEDCOMPLETE']._serialized_start=3528 + _globals['_EMBEDCOMPLETE']._serialized_end=3691 + _globals['_EMBEDDING']._serialized_start=3693 + _globals['_EMBEDDING']._serialized_end=3735 + _globals['_EMBEDERROR']._serialized_start=3737 + _globals['_EMBEDERROR']._serialized_end=3797 + _globals['_HEALTHCHECKREQUEST']._serialized_start=3799 + _globals['_HEALTHCHECKREQUEST']._serialized_end=3877 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=3879 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=3934 + _globals['_ABORTREQUEST']._serialized_start=3936 + _globals['_ABORTREQUEST']._serialized_end=3986 + _globals['_ABORTRESPONSE']._serialized_start=3988 + _globals['_ABORTRESPONSE']._serialized_end=4037 + _globals['_LOADLORAREQUEST']._serialized_start=4039 + _globals['_LOADLORAREQUEST']._serialized_end=4112 + _globals['_LOADLORARESPONSE']._serialized_start=4114 + _globals['_LOADLORARESPONSE']._serialized_end=4186 + _globals['_UNLOADLORAREQUEST']._serialized_start=4188 + _globals['_UNLOADLORAREQUEST']._serialized_end=4227 + _globals['_UNLOADLORARESPONSE']._serialized_start=4229 + _globals['_UNLOADLORARESPONSE']._serialized_end=4283 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4285 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4404 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4406 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4463 + _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4465 + _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4510 + _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4512 + _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4578 + _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4580 + _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4645 + _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4647 + _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4707 + _globals['_SGLANGSCHEDULER']._serialized_start=4710 + _globals['_SGLANGSCHEDULER']._serialized_end=5092 # @@protoc_insertion_point(module_scope) diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi index 7ca94db25..3578abe74 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi @@ -162,42 +162,46 @@ class GenerateResponse(_message.Message): def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... class GenerateStreamChunk(_message.Message): - __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states") + __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs") TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] - LOGPROBS_FIELD_NUMBER: _ClassVar[int] + OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] + INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] token_ids: _containers.RepeatedScalarFieldContainer[int] prompt_tokens: int completion_tokens: int cached_tokens: int - logprobs: LogProbs + output_logprobs: LogProbs hidden_states: _containers.RepeatedScalarFieldContainer[float] - def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ... + input_logprobs: LogProbs + def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ... class GenerateComplete(_message.Message): - __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str") + __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs") OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] FINISH_REASON_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] - ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int] + INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] output_ids: _containers.RepeatedScalarFieldContainer[int] finish_reason: str prompt_tokens: int completion_tokens: int cached_tokens: int - all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs] + output_logprobs: LogProbs all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates] matched_token_id: int matched_stop_str: str - def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ... + input_logprobs: LogProbs + def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ... class GenerateError(_message.Message): __slots__ = ("message", "http_status_code", "details") @@ -210,26 +214,22 @@ class GenerateError(_message.Message): def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... class LogProbs(_message.Message): - __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts") + __slots__ = ("token_logprobs", "token_ids", "top_logprobs") TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int] TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] - TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int] token_logprobs: _containers.RepeatedScalarFieldContainer[float] token_ids: _containers.RepeatedScalarFieldContainer[int] top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs] - token_texts: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ... + def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ... class TopLogProbs(_message.Message): - __slots__ = ("values", "token_ids", "token_texts") + __slots__ = ("values", "token_ids") VALUES_FIELD_NUMBER: _ClassVar[int] TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] - TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int] values: _containers.RepeatedScalarFieldContainer[float] token_ids: _containers.RepeatedScalarFieldContainer[int] - token_texts: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ... + def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ... class HiddenStates(_message.Message): __slots__ = ("values", "layer", "position") diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto index 521f43b43..be6508b5a 100644 --- a/sgl-router/src/proto/sglang_scheduler.proto +++ b/sgl-router/src/proto/sglang_scheduler.proto @@ -174,11 +174,14 @@ message GenerateStreamChunk { int32 completion_tokens = 3; int32 cached_tokens = 4; - // Logprobs (if requested) - LogProbs logprobs = 5; + // Output logprobs (if requested) - incremental for streaming + LogProbs output_logprobs = 5; // Hidden states (if requested) repeated float hidden_states = 6; + + // Input logprobs (if requested) - only in first chunk + LogProbs input_logprobs = 7; } message GenerateComplete { @@ -193,8 +196,8 @@ message GenerateComplete { int32 completion_tokens = 4; int32 cached_tokens = 5; - // All logprobs if requested - repeated LogProbs all_logprobs = 6; + // Output logprobs if requested (cumulative) + LogProbs output_logprobs = 6; // All hidden states if requested repeated HiddenStates all_hidden_states = 7; @@ -204,6 +207,9 @@ message GenerateComplete { uint32 matched_token_id = 8; string matched_stop_str = 9; } + + // Input logprobs if requested (for prompt tokens) + LogProbs input_logprobs = 10; } message GenerateError { @@ -218,15 +224,11 @@ message LogProbs { // Top logprobs at each position repeated TopLogProbs top_logprobs = 3; - - // Decoded text for tokens - repeated string token_texts = 4; } message TopLogProbs { repeated float values = 1; repeated int32 token_ids = 2; - repeated string token_texts = 3; } message HiddenStates { diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index d3d58e263..dce2ca6f7 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -730,6 +730,73 @@ impl GrpcRouter { Json(response).into_response() } + /// Convert proto LogProbs to OpenAI ChatLogProbs format + /// Note: Always decodes with skip_special_tokens=false to show actual tokens generated + fn convert_proto_to_openai_logprobs( + &self, + proto_logprobs: &proto::LogProbs, + ) -> Result { + let mut content_items = Vec::new(); + + // Decode token IDs to text (always with skip_special_tokens=false for logprobs) + let token_texts: Vec = proto_logprobs + .token_ids + .iter() + .map(|&token_id| { + self.tokenizer + .decode(&[token_id as u32], false) + .unwrap_or_else(|_| format!("", token_id)) + }) + .collect(); + + // Build ChatLogProbsContent for each token + for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() { + let token_text = token_texts.get(i).cloned().unwrap_or_default(); + let bytes = Some(token_text.as_bytes().to_vec()); + + // Build top_logprobs for this position + let mut top_logprobs = Vec::new(); + if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) { + // Decode top token IDs (always with skip_special_tokens=false) + let top_token_texts: Vec = top_logprobs_entry + .token_ids + .iter() + .map(|&tid| { + self.tokenizer + .decode(&[tid as u32], false) + .unwrap_or_else(|_| format!("", tid)) + }) + .collect(); + + for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry + .values + .iter() + .zip(top_logprobs_entry.token_ids.iter()) + .enumerate() + { + if let Some(top_token_text) = top_token_texts.get(j) { + top_logprobs.push(crate::protocols::spec::TopLogProb { + token: top_token_text.clone(), + logprob: top_logprob, + bytes: Some(top_token_text.as_bytes().to_vec()), + }); + } + } + } + + content_items.push(crate::protocols::spec::ChatLogProbsContent { + token: token_text, + logprob, + bytes, + top_logprobs, + }); + } + + Ok(crate::protocols::spec::ChatLogProbs::Detailed { + content: (!content_items.is_empty()).then_some(content_items), + }) + } + /// Process a single GenerateComplete response into a ChatChoice async fn process_single_choice( &self, @@ -855,7 +922,22 @@ impl GrpcRouter { None => None, }; - // Step 4: Build ChatCompletionMessage (proper response message type) + // Step 4: Convert output logprobs if present + // Note: complete.input_logprobs exists in proto but is not used for chat completions + // (input logprobs are only used in /v1/completions endpoint with echo=true) + let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs { + match self.convert_proto_to_openai_logprobs(proto_logprobs) { + Ok(logprobs) => Some(logprobs), + Err(e) => { + error!("Failed to convert logprobs: {}", e); + None + } + } + } else { + None + }; + + // Step 5: Build ChatCompletionMessage (proper response message type) let chat_message = ChatCompletionMessage { role: "assistant".to_string(), content: if processed_text.is_empty() { @@ -867,11 +949,11 @@ impl GrpcRouter { reasoning_content: reasoning_text, }; - // Step 5: Build ChatChoice + // Step 6: Build ChatChoice let choice = ChatChoice { index: index as u32, message: chat_message, - logprobs: None, + logprobs, finish_reason: Some(final_finish_reason_str.to_string()), matched_stop, hidden_states: None,