From 963175d5c06a053399b2cf361295372ff1158e46 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 2 Oct 2025 14:35:16 -0700 Subject: [PATCH] [router][grpc] Support streaming for v1/chat/completions (#11179) --- .../srt/entrypoints/grpc_request_manager.py | 2 +- python/sglang/srt/entrypoints/grpc_server.py | 3 +- python/sglang/srt/grpc/sglang_scheduler.proto | 6 + .../sglang/srt/grpc/sglang_scheduler_pb2.py | 112 +-- .../sglang/srt/grpc/sglang_scheduler_pb2.pyi | 12 +- sgl-router/benches/request_processing.rs | 1 - sgl-router/src/proto/sglang_scheduler.proto | 6 + sgl-router/src/protocols/spec.rs | 6 +- sgl-router/src/reasoning_parser/README.md | 12 +- sgl-router/src/reasoning_parser/factory.rs | 32 +- sgl-router/src/reasoning_parser/mod.rs | 2 +- sgl-router/src/routers/grpc/pd_router.rs | 4 +- sgl-router/src/routers/grpc/router.rs | 770 ++++++++++++++++-- sgl-router/src/server.rs | 6 +- .../tool_parser/parsers/deepseek_parser.rs | 9 +- .../tool_parser/parsers/glm4_moe_parser.rs | 9 +- .../src/tool_parser/parsers/gpt_oss_parser.rs | 5 - sgl-router/src/tool_parser/parsers/helpers.rs | 42 + .../src/tool_parser/parsers/json_parser.rs | 15 +- .../src/tool_parser/parsers/kimik2_parser.rs | 9 +- .../src/tool_parser/parsers/llama_parser.rs | 14 +- .../src/tool_parser/parsers/mistral_parser.rs | 13 +- .../tool_parser/parsers/pythonic_parser.rs | 4 +- .../src/tool_parser/parsers/qwen_parser.rs | 13 +- .../src/tool_parser/parsers/step3_parser.rs | 9 +- sgl-router/src/tool_parser/tests.rs | 8 +- sgl-router/src/tool_parser/traits.rs | 6 + sgl-router/src/tool_parser/types.rs | 7 +- sgl-router/tests/chat_template_integration.rs | 1 - sgl-router/tests/chat_template_loading.rs | 2 - 30 files changed, 912 insertions(+), 228 deletions(-) diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py index 71d77bfc2..1296e810a 100644 --- a/python/sglang/srt/entrypoints/grpc_request_manager.py +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -578,7 +578,7 @@ class GrpcRequestManager: batch_out.cached_tokens[i] if batch_out.cached_tokens else 0 ), "finish_reason": ( - str(batch_out.finished_reasons[i]) + batch_out.finished_reasons[i] if batch_out.finished_reasons[i] else None ), diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index 55712c177..c143158bb 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -112,7 +112,6 @@ def _launch_scheduler_process_only( pp_rank, None, writer, - None, ), ) @@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) cached_tokens=meta_info.get("cached_tokens", 0), output_logprobs=output_logprobs_proto, input_logprobs=input_logprobs_proto, + index=output.get("index", 0), ), ) @@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) cached_tokens=meta_info.get("cached_tokens", 0), output_logprobs=output_logprobs_proto, input_logprobs=input_logprobs_proto, + index=output.get("index", 0), **matched_stop_kwargs, ), ) diff --git a/python/sglang/srt/grpc/sglang_scheduler.proto b/python/sglang/srt/grpc/sglang_scheduler.proto index 1d52e54c5..152e65cb5 100644 --- a/python/sglang/srt/grpc/sglang_scheduler.proto +++ b/python/sglang/srt/grpc/sglang_scheduler.proto @@ -179,6 +179,9 @@ message GenerateStreamChunk { // Input logprobs (if requested) - only in first chunk InputLogProbs input_logprobs = 7; + + // Index for 
ordering when n>1 (for parallel request multiplexing) + uint32 index = 8; } message GenerateComplete { @@ -207,6 +210,9 @@ message GenerateComplete { // Input logprobs if requested (for prompt tokens) InputLogProbs input_logprobs = 10; + + // Index for ordering when n>1 (for parallel request multiplexing) + uint32 index = 11; } message GenerateError { diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py index 2dc5ac321..10127567b 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.py +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py @@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 
\x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x86\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\"\x8c\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbsB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 
\x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f 
\x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x95\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x08 \x01(\r\"\x9b\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 
\x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x0b \x01(\rB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 
\x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -53,59 +53,59 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_GENERATERESPONSE']._serialized_start=1835 _globals['_GENERATERESPONSE']._serialized_end=2062 _globals['_GENERATESTREAMCHUNK']._serialized_start=2065 - _globals['_GENERATESTREAMCHUNK']._serialized_end=2327 - _globals['_GENERATECOMPLETE']._serialized_start=2330 - _globals['_GENERATECOMPLETE']._serialized_end=2726 - _globals['_GENERATEERROR']._serialized_start=2728 - _globals['_GENERATEERROR']._serialized_end=2803 - _globals['_OUTPUTLOGPROBS']._serialized_start=2805 - _globals['_OUTPUTLOGPROBS']._serialized_end=2922 - _globals['_INPUTLOGPROBS']._serialized_start=2925 - _globals['_INPUTLOGPROBS']._serialized_end=3083 - _globals['_INPUTTOKENLOGPROB']._serialized_start=3085 - _globals['_INPUTTOKENLOGPROB']._serialized_end=3134 - _globals['_TOPLOGPROBS']._serialized_start=3136 - _globals['_TOPLOGPROBS']._serialized_end=3184 - _globals['_HIDDENSTATES']._serialized_start=3186 - _globals['_HIDDENSTATES']._serialized_end=3249 - _globals['_EMBEDREQUEST']._serialized_start=3252 - _globals['_EMBEDREQUEST']._serialized_end=3582 - _globals['_EMBEDRESPONSE']._serialized_start=3585 - _globals['_EMBEDRESPONSE']._serialized_end=3742 - _globals['_EMBEDCOMPLETE']._serialized_start=3745 - _globals['_EMBEDCOMPLETE']._serialized_end=3908 - _globals['_EMBEDDING']._serialized_start=3910 - _globals['_EMBEDDING']._serialized_end=3952 - _globals['_EMBEDERROR']._serialized_start=3954 - _globals['_EMBEDERROR']._serialized_end=4014 - _globals['_HEALTHCHECKREQUEST']._serialized_start=4016 - _globals['_HEALTHCHECKREQUEST']._serialized_end=4094 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=4096 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=4151 - _globals['_ABORTREQUEST']._serialized_start=4153 - _globals['_ABORTREQUEST']._serialized_end=4203 - _globals['_ABORTRESPONSE']._serialized_start=4205 - _globals['_ABORTRESPONSE']._serialized_end=4254 - _globals['_LOADLORAREQUEST']._serialized_start=4256 - _globals['_LOADLORAREQUEST']._serialized_end=4329 - _globals['_LOADLORARESPONSE']._serialized_start=4331 - _globals['_LOADLORARESPONSE']._serialized_end=4403 - _globals['_UNLOADLORAREQUEST']._serialized_start=4405 - _globals['_UNLOADLORAREQUEST']._serialized_end=4444 - _globals['_UNLOADLORARESPONSE']._serialized_start=4446 - _globals['_UNLOADLORARESPONSE']._serialized_end=4500 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4502 - 
_globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4621 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4623 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4680 - _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4682 - _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4727 - _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4729 - _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4795 - _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4797 - _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4862 - _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4864 - _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4924 - _globals['_SGLANGSCHEDULER']._serialized_start=4927 - _globals['_SGLANGSCHEDULER']._serialized_end=5309 + _globals['_GENERATESTREAMCHUNK']._serialized_end=2342 + _globals['_GENERATECOMPLETE']._serialized_start=2345 + _globals['_GENERATECOMPLETE']._serialized_end=2756 + _globals['_GENERATEERROR']._serialized_start=2758 + _globals['_GENERATEERROR']._serialized_end=2833 + _globals['_OUTPUTLOGPROBS']._serialized_start=2835 + _globals['_OUTPUTLOGPROBS']._serialized_end=2952 + _globals['_INPUTLOGPROBS']._serialized_start=2955 + _globals['_INPUTLOGPROBS']._serialized_end=3113 + _globals['_INPUTTOKENLOGPROB']._serialized_start=3115 + _globals['_INPUTTOKENLOGPROB']._serialized_end=3164 + _globals['_TOPLOGPROBS']._serialized_start=3166 + _globals['_TOPLOGPROBS']._serialized_end=3214 + _globals['_HIDDENSTATES']._serialized_start=3216 + _globals['_HIDDENSTATES']._serialized_end=3279 + _globals['_EMBEDREQUEST']._serialized_start=3282 + _globals['_EMBEDREQUEST']._serialized_end=3612 + _globals['_EMBEDRESPONSE']._serialized_start=3615 + _globals['_EMBEDRESPONSE']._serialized_end=3772 + _globals['_EMBEDCOMPLETE']._serialized_start=3775 + _globals['_EMBEDCOMPLETE']._serialized_end=3938 + _globals['_EMBEDDING']._serialized_start=3940 + _globals['_EMBEDDING']._serialized_end=3982 + _globals['_EMBEDERROR']._serialized_start=3984 + _globals['_EMBEDERROR']._serialized_end=4044 + _globals['_HEALTHCHECKREQUEST']._serialized_start=4046 + _globals['_HEALTHCHECKREQUEST']._serialized_end=4124 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=4126 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=4181 + _globals['_ABORTREQUEST']._serialized_start=4183 + _globals['_ABORTREQUEST']._serialized_end=4233 + _globals['_ABORTRESPONSE']._serialized_start=4235 + _globals['_ABORTRESPONSE']._serialized_end=4284 + _globals['_LOADLORAREQUEST']._serialized_start=4286 + _globals['_LOADLORAREQUEST']._serialized_end=4359 + _globals['_LOADLORARESPONSE']._serialized_start=4361 + _globals['_LOADLORARESPONSE']._serialized_end=4433 + _globals['_UNLOADLORAREQUEST']._serialized_start=4435 + _globals['_UNLOADLORAREQUEST']._serialized_end=4474 + _globals['_UNLOADLORARESPONSE']._serialized_start=4476 + _globals['_UNLOADLORARESPONSE']._serialized_end=4530 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4532 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4651 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4653 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4710 + _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4712 + _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4757 + _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4759 + _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4825 + _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4827 + _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4892 + 
_globals['_SETINTERNALSTATERESPONSE']._serialized_start=4894 + _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4954 + _globals['_SGLANGSCHEDULER']._serialized_start=4957 + _globals['_SGLANGSCHEDULER']._serialized_end=5339 # @@protoc_insertion_point(module_scope) diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi index 167b68d96..53559ebfd 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi @@ -160,7 +160,7 @@ class GenerateResponse(_message.Message): def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... class GenerateStreamChunk(_message.Message): - __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs") + __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index") TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] @@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message): OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + INDEX_FIELD_NUMBER: _ClassVar[int] token_ids: _containers.RepeatedScalarFieldContainer[int] prompt_tokens: int completion_tokens: int @@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message): output_logprobs: OutputLogProbs hidden_states: _containers.RepeatedScalarFieldContainer[float] input_logprobs: InputLogProbs - def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ... + index: int + def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ... 
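
The new `index` field on `GenerateStreamChunk` is what lets the router multiplex several parallel generations (n > 1) over a single gRPC stream and still attribute each chunk to the right choice. A minimal, illustrative sketch of that demultiplexing — the `Chunk` type and all names here are stand-ins, not the router's actual types:

```rust
// Hedged sketch: per-index accumulation of interleaved stream chunks,
// mirroring how process_streaming_chunks keys its per-index maps.
use std::collections::HashMap;

// Stand-in for the decoded text of one GenerateStreamChunk.
struct Chunk {
    index: u32,
    text: String,
}

fn demux(chunks: Vec<Chunk>) -> HashMap<u32, String> {
    let mut buffers: HashMap<u32, String> = HashMap::new();
    for chunk in chunks {
        // Each parallel generation (choice) accumulates into its own buffer.
        buffers.entry(chunk.index).or_default().push_str(&chunk.text);
    }
    buffers
}

fn main() {
    let interleaved = vec![
        Chunk { index: 0, text: "Hello".into() },
        Chunk { index: 1, text: "Bonjour".into() },
        Chunk { index: 0, text: ", world".into() },
    ];
    let by_choice = demux(interleaved);
    assert_eq!(by_choice[&0], "Hello, world");
    assert_eq!(by_choice[&1], "Bonjour");
}
```
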
class GenerateComplete(_message.Message): - __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs") + __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index") OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] FINISH_REASON_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] @@ -189,6 +191,7 @@ class GenerateComplete(_message.Message): MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int] INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + INDEX_FIELD_NUMBER: _ClassVar[int] output_ids: _containers.RepeatedScalarFieldContainer[int] finish_reason: str prompt_tokens: int @@ -199,7 +202,8 @@ class GenerateComplete(_message.Message): matched_token_id: int matched_stop_str: str input_logprobs: InputLogProbs - def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ... + index: int + def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ... 
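
`GenerateComplete` carries the same `index` so the final chunk of each choice can be attributed too. Relatedly, the `router.rs` changes further down add `generate_tool_call_id`, which picks between two tool-call ID formats; a self-contained sketch of just that logic — the body mirrors the patch, while the `main` harness is illustrative only:

```rust
// Sketch of the two ID formats produced by generate_tool_call_id in
// router.rs below. Requires the uuid crate with the v4 feature.
use uuid::Uuid;

fn generate_tool_call_id(
    model: &str,
    tool_name: &str,
    tool_index: usize,
    history_count: usize,
) -> String {
    if model.to_lowercase().contains("kimi") {
        // KimiK2 needs globally unique indices across the message history.
        format!("functions.{}:{}", tool_name, history_count + tool_index)
    } else {
        // Standard OpenAI-style ID: "call_" plus 24 hex characters.
        format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
    }
}

fn main() {
    assert_eq!(
        generate_tool_call_id("kimi-k2", "get_weather", 0, 2),
        "functions.get_weather:2"
    );
    assert!(generate_tool_call_id("qwen3", "get_weather", 0, 0).starts_with("call_"));
}
```
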
class GenerateError(_message.Message): __slots__ = ("message", "http_status_code", "details") diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 5c8aa389d..2a1163deb 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -192,7 +192,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest { content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)), name: None, tool_calls: None, - function_call: None, reasoning_content: None, }); } diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto index 1d52e54c5..152e65cb5 100644 --- a/sgl-router/src/proto/sglang_scheduler.proto +++ b/sgl-router/src/proto/sglang_scheduler.proto @@ -179,6 +179,9 @@ message GenerateStreamChunk { // Input logprobs (if requested) - only in first chunk InputLogProbs input_logprobs = 7; + + // Index for ordering when n>1 (for parallel request multiplexing) + uint32 index = 8; } message GenerateComplete { @@ -207,6 +210,9 @@ message GenerateComplete { // Input logprobs if requested (for prompt tokens) InputLogProbs input_logprobs = 10; + + // Index for ordering when n>1 (for parallel request multiplexing) + uint32 index = 11; } message GenerateError { diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index 994f2c434..fc4b9854b 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -72,8 +72,6 @@ pub enum ChatMessage { name: Option, #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, /// Reasoning content for O1-style models (SGLang extension) #[serde(skip_serializing_if = "Option::is_none")] reasoning_content: Option, @@ -140,8 +138,6 @@ pub struct ChatMessageDelta { pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, /// Reasoning content delta for O1-style models (SGLang extension) #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_content: Option, @@ -473,6 +469,8 @@ pub struct ChatStreamChoice { #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, } // Completions API request types (v1/completions) - DEPRECATED but still supported diff --git a/sgl-router/src/reasoning_parser/README.md b/sgl-router/src/reasoning_parser/README.md index 763028f0c..5db6f11f1 100644 --- a/sgl-router/src/reasoning_parser/README.md +++ b/sgl-router/src/reasoning_parser/README.md @@ -44,7 +44,7 @@ graph TB end subgraph Factory Layer - MID --> PF[ParserFactory] + MID --> PF[ReasoningParserFactory] PF --> REG[ParserRegistry] REG --> PM[Pattern Matching] PM --> PP[Parser Pool] @@ -93,7 +93,7 @@ graph TB ```mermaid sequenceDiagram participant C as Client - participant F as ParserFactory + participant F as ReasoningParserFactory participant R as Registry participant P as Parser Pool participant BP as BaseParser @@ -206,7 +206,7 @@ classDiagram +new() Self } - class ParserFactory { + class ReasoningParserFactory { -registry: ParserRegistry +new() Self +get_pooled(model_id: &str) PooledParser @@ -240,7 +240,7 @@ classDiagram Step3Parser o-- 
BaseReasoningParser BaseReasoningParser o-- ParserConfig - ParserFactory o-- ParserRegistry + ReasoningParserFactory o-- ParserRegistry ParserRegistry o-- ReasoningParser ``` @@ -302,7 +302,7 @@ classDiagram - Delegate to get_pooled_parser - Case-insensitive comparison -**ParserFactory Methods**: +**ReasoningParserFactory Methods**: 1. **`new()`**: - Register all built-in parsers @@ -437,7 +437,7 @@ impl ReasoningParser for MyModelParser { **Step 2: Register in Factory** ```rust -// In factory.rs ParserFactory::new() +// In factory.rs ReasoningParserFactory::new() registry.register_parser("mymodel", || { Box::new(MyModelParser::new()) }); diff --git a/sgl-router/src/reasoning_parser/factory.rs b/sgl-router/src/reasoning_parser/factory.rs index 884221710..f7ea9f3fa 100644 --- a/sgl-router/src/reasoning_parser/factory.rs +++ b/sgl-router/src/reasoning_parser/factory.rs @@ -128,11 +128,11 @@ impl Default for ParserRegistry { /// Factory for creating reasoning parsers based on model type. #[derive(Clone)] -pub struct ParserFactory { +pub struct ReasoningParserFactory { registry: ParserRegistry, } -impl ParserFactory { +impl ReasoningParserFactory { /// Create a new factory with default parsers registered. pub fn new() -> Self { let registry = ParserRegistry::new(); @@ -237,7 +237,7 @@ impl ParserFactory { } } -impl Default for ParserFactory { +impl Default for ReasoningParserFactory { fn default() -> Self { Self::new() } @@ -249,35 +249,35 @@ mod tests { #[test] fn test_factory_creates_deepseek_r1() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let parser = factory.create("deepseek-r1-distill").unwrap(); assert_eq!(parser.model_type(), "deepseek_r1"); } #[test] fn test_factory_creates_qwen3() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let parser = factory.create("qwen3-7b").unwrap(); assert_eq!(parser.model_type(), "qwen3"); } #[test] fn test_factory_creates_kimi() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let parser = factory.create("kimi-chat").unwrap(); assert_eq!(parser.model_type(), "kimi"); } #[test] fn test_factory_fallback_to_passthrough() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let parser = factory.create("unknown-model").unwrap(); assert_eq!(parser.model_type(), "passthrough"); } #[test] fn test_case_insensitive_matching() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let parser1 = factory.create("DeepSeek-R1").unwrap(); let parser2 = factory.create("QWEN3").unwrap(); let parser3 = factory.create("Kimi").unwrap(); @@ -289,21 +289,21 @@ mod tests { #[test] fn test_step3_model() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let step3 = factory.create("step3-model").unwrap(); assert_eq!(step3.model_type(), "step3"); } #[test] fn test_glm45_model() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let glm45 = factory.create("glm45-v2").unwrap(); assert_eq!(glm45.model_type(), "glm45"); } #[test] fn test_pooled_parser_reuse() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); // Get the same parser twice - should be the same instance let parser1 = factory.get_pooled("deepseek-r1"); @@ -321,7 +321,7 @@ mod tests { fn test_pooled_parser_concurrent_access() { use std::thread; - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let 
parser = factory.get_pooled("deepseek-r1"); // Spawn multiple threads that use the same parser @@ -347,7 +347,7 @@ mod tests { #[test] fn test_pool_clearing() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); // Get a pooled parser let parser1 = factory.get_pooled("deepseek-r1"); @@ -364,7 +364,7 @@ mod tests { #[test] fn test_passthrough_parser_pooling() { - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); // Unknown models should get passthrough parser let parser1 = factory.get_pooled("unknown-model-1"); @@ -383,7 +383,7 @@ mod tests { use std::thread; use std::time::Instant; - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let num_threads = 100; let requests_per_thread = 50; let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"]; @@ -527,7 +527,7 @@ mod tests { fn test_concurrent_pool_modifications() { use std::thread; - let factory = ParserFactory::new(); + let factory = ReasoningParserFactory::new(); let mut handles = vec![]; // Thread 1: Continuously get parsers diff --git a/sgl-router/src/reasoning_parser/mod.rs b/sgl-router/src/reasoning_parser/mod.rs index 95ffcbc4f..8cc7e8357 100644 --- a/sgl-router/src/reasoning_parser/mod.rs +++ b/sgl-router/src/reasoning_parser/mod.rs @@ -2,7 +2,7 @@ pub mod factory; pub mod parsers; pub mod traits; -pub use factory::{ParserFactory, ParserRegistry, PooledParser}; +pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory}; pub use parsers::{ BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser, QwenThinkingParser, Step3Parser, diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index d60a771a4..135260f06 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -4,7 +4,7 @@ use crate::config::types::RetryConfig; use crate::core::{WorkerRegistry, WorkerType}; use crate::metrics::RouterMetrics; use crate::policies::PolicyRegistry; -use crate::reasoning_parser::ParserFactory; +use crate::reasoning_parser::ReasoningParserFactory; use crate::routers::RouterTrait; use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ToolParserFactory; @@ -24,7 +24,7 @@ pub struct GrpcPDRouter { worker_registry: Arc, policy_registry: Arc, tokenizer: Arc, - reasoning_parser_factory: ParserFactory, + reasoning_parser_factory: ReasoningParserFactory, tool_parser_factory: ToolParserFactory, dp_aware: bool, diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 9b749b52c..0c63ff66e 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -7,10 +7,14 @@ use async_trait::async_trait; use axum::{ body::Body, extract::Request, - http::{HeaderMap, StatusCode}, + http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, Json, }; +use bytes::Bytes; +use std::io; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, warn}; use crate::config::types::RetryConfig; @@ -21,11 +25,12 @@ use crate::policies::PolicyRegistry; use crate::protocols::spec::ChatMessage; use crate::protocols::spec::{ ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest, - ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice, + 
ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, CompletionRequest, + EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, RerankRequest, + ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, Usage, }; -use crate::reasoning_parser::ParserFactory; +use crate::reasoning_parser::{ParserResult, ReasoningParserFactory}; use crate::routers::RouterTrait; use crate::server::AppContext; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; @@ -34,7 +39,7 @@ use crate::tokenizer::stop::{ }; use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::HuggingFaceTokenizer; -use crate::tool_parser::ToolParserFactory; +use crate::tool_parser::{StreamingParseResult, ToolParserFactory}; use proto::generate_response::Response::{Chunk, Complete, Error}; use serde_json::{json, Map, Value}; use std::time::{Instant, SystemTime, UNIX_EPOCH}; @@ -50,12 +55,13 @@ pub struct ProcessedMessages { } /// gRPC router implementation for SGLang +#[derive(Clone)] #[allow(dead_code)] pub struct GrpcRouter { worker_registry: Arc, policy_registry: Arc, tokenizer: Arc, - reasoning_parser_factory: ParserFactory, + reasoning_parser_factory: ReasoningParserFactory, tool_parser_factory: ToolParserFactory, dp_aware: bool, api_key: Option, @@ -776,10 +782,11 @@ impl GrpcRouter { } /// Parse tool calls using model-specific parser - async fn parse_with_model_parser( + async fn parse_tool_calls( &self, processed_text: &str, model: &str, + history_tool_calls_count: usize, ) -> (Option>, String) { // Get pooled parser for this model let pooled_parser = self.tool_parser_factory.get_pooled(model); @@ -810,16 +817,26 @@ impl GrpcRouter { let spec_tool_calls = parsed_tool_calls .into_iter() - .map(|tc| ToolCall { - id: tc.id, - tool_type: "function".to_string(), - function: FunctionCallResponse { - name: tc.function.name, - arguments: Some( - serde_json::to_string(&tc.function.arguments) - .unwrap_or_else(|_| "{}".to_string()), - ), - }, + .enumerate() + .map(|(index, tc)| { + // Generate ID for this tool call + let id = Self::generate_tool_call_id( + model, + &tc.function.name, + index, + history_tool_calls_count, + ); + ToolCall { + id, + tool_type: "function".to_string(), + function: FunctionCallResponse { + name: tc.function.name, + arguments: Some( + serde_json::to_string(&tc.function.arguments) + .unwrap_or_else(|_| "{}".to_string()), + ), + }, + } }) .collect(); (Some(spec_tool_calls), normal_text) @@ -920,6 +937,47 @@ impl GrpcRouter { builder.build() } + /// Count the number of tool calls in the request message history + /// This is used for KimiK2 format which needs globally unique indices + fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize { + request + .messages + .iter() + .filter_map(|msg| { + if let ChatMessage::Assistant { tool_calls, .. } = msg { + tool_calls.as_ref().map(|calls| calls.len()) + } else { + None + } + }) + .sum() + } + + /// Generate a tool call ID based on model format + /// + /// # Arguments + /// * `model` - Model name to determine ID format + /// * `tool_name` - Name of the tool being called + /// * `tool_index` - Index of this tool call within the current message + /// * `history_count` - Number of tool calls in previous messages + /// + /// # Returns + /// A unique ID string. 
KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}` + fn generate_tool_call_id( + model: &str, + tool_name: &str, + tool_index: usize, + history_count: usize, + ) -> String { + if model.to_lowercase().contains("kimi") { + // KimiK2 format: functions.{name}:{global_index} + format!("functions.{}:{}", tool_name, history_count + tool_index) + } else { + // Standard OpenAI format: call_{24-char-uuid} + format!("call_{}", &Uuid::new_v4().simple().to_string()[..24]) + } + } + /// Process a chunk of tokens through the stop decoder fn process_chunk_tokens( stop_decoder: &mut StopSequenceDecoder, @@ -953,6 +1011,230 @@ impl GrpcRouter { (chunk_text, false) // Return text and continue processing } + /// Helper: Process reasoning content in streaming mode + /// Returns (modified_delta, optional_reasoning_chunk) + fn process_reasoning_stream( + &self, + delta: &str, + index: u32, + reasoning_parsers: &mut HashMap< + u32, + Arc>>, + >, + request_id: &str, + model: &str, + created: u64, + ) -> (String, Option) { + // Get or create parser for this index + reasoning_parsers + .entry(index) + .or_insert_with(|| self.reasoning_parser_factory.get_pooled(model)); + + if let Some(pooled_parser) = reasoning_parsers.get(&index) { + let parse_result = { + let mut parser = pooled_parser.lock().unwrap(); + parser.parse_reasoning_streaming_incremental(delta) + }; + + match parse_result { + Ok(ParserResult { + reasoning_text, + normal_text, + }) => { + let chunk = if !reasoning_text.is_empty() { + Some(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: Some(reasoning_text), + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }) + } else { + None + }; + return (normal_text, chunk); + } + Err(e) => { + warn!("Reasoning parsing error: {}", e); + } + } + } + + (delta.to_string(), None) + } + + /// Helper: Process tool calls in streaming mode + /// Returns (should_skip_content, chunks_to_emit) + #[allow(clippy::too_many_arguments)] + async fn process_tool_calls_stream( + &self, + delta: &str, + index: u32, + tool_parsers: &mut HashMap< + u32, + Arc>>, + >, + has_tool_calls: &mut HashMap, + tools: &[crate::protocols::spec::Tool], + request_id: &str, + model: &str, + created: u64, + history_tool_calls_count: usize, + ) -> (bool, Vec) { + let mut chunks = Vec::new(); + + // Get or create parser for this index + tool_parsers + .entry(index) + .or_insert_with(|| self.tool_parser_factory.get_pooled(model)); + + if let Some(pooled_parser) = tool_parsers.get(&index) { + let mut parser = pooled_parser.lock().await; + match parser.parse_incremental(delta, tools).await { + Ok(StreamingParseResult { normal_text, calls }) => { + // Emit normal text if present + if !normal_text.is_empty() { + chunks.push(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(normal_text), + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }); + } + + // Emit tool call chunks + 
for tool_call_item in calls { + has_tool_calls.insert(index, true); + + let tool_call_id = if let Some(ref name) = tool_call_item.name { + Some(Self::generate_tool_call_id( + model, + name, + tool_call_item.tool_index, + history_tool_calls_count, + )) + } else { + None + }; + + let tool_call_delta = ToolCallDelta { + index: tool_call_item.tool_index as u32, + id: tool_call_id, + tool_type: if tool_call_item.name.is_some() { + Some("function".to_string()) + } else { + None + }, + function: Some(FunctionCallDelta { + name: tool_call_item.name, + arguments: if !tool_call_item.parameters.is_empty() { + Some(tool_call_item.parameters) + } else { + None + }, + }), + }; + + chunks.push(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: Some(vec![tool_call_delta]), + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }); + } + + // If we emitted chunks, skip regular content + return (!chunks.is_empty(), chunks); + } + Err(e) => { + warn!("Tool call parsing error: {}", e); + } + } + } + + (false, chunks) + } + + /// Helper: Create content chunk + fn create_content_chunk( + content: String, + index: u32, + request_id: &str, + model: &str, + created: u64, + logprobs: Option, + ) -> ChatCompletionStreamResponse { + ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(content), + tool_calls: None, + reasoning_content: None, + }, + logprobs, + finish_reason: None, + matched_stop: None, + }], + usage: None, + } + } + + /// Helper: Format response as SSE chunk + fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String { + format!( + "data: {}\n\n", + serde_json::to_string(response).unwrap_or_default() + ) + } + /// Submit request and handle streaming response for chat completions route async fn handle_streaming_chat( &self, @@ -960,14 +1242,13 @@ impl GrpcRouter { request: proto::GenerateRequest, original_request: &ChatCompletionRequest, ) -> Response { - let mut stop_decoder = self.create_stop_decoder( - original_request.stop.as_ref(), - original_request.stop_token_ids.as_ref(), - original_request.skip_special_tokens, - original_request.no_stop_trim, - ); + let request_id = request.request_id.clone(); + let model = original_request.model.clone(); - // Process streaming tokens + // Create channel for SSE streaming + let (tx, rx) = mpsc::unbounded_channel::>(); + + // Start the gRPC stream let mut grpc_stream = match client.generate(request).await { Ok(stream) => stream, Err(e) => { @@ -980,49 +1261,414 @@ impl GrpcRouter { } }; - let mut decoded_text = String::new(); + let stop_params = ( + original_request.stop.clone(), + original_request.stop_token_ids.clone(), + original_request.skip_special_tokens, + original_request.no_stop_trim, + ); + // Spawn processing task + let self_clone = self.clone(); + let original_request_clone = original_request.clone(); + tokio::spawn(async move { + let result = Self::process_streaming_chunks( + &self_clone, + &mut grpc_stream, + request_id, + model, + stop_params, + 
original_request_clone, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!( + "data: {}\n\n", + json!({ + "error": { + "message": e, + "type": "internal_error" + } + }) + ); + let _ = tx.send(Ok(Bytes::from(error_chunk))); + } + + // Send DONE marker + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + }); + + // Create response with SSE headers + let stream = UnboundedReceiverStream::new(rx); + let mut response = Response::new(Body::from_stream(stream)); + *response.status_mut() = StatusCode::OK; + response + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + response + .headers_mut() + .insert("Cache-Control", HeaderValue::from_static("no-cache")); + response + .headers_mut() + .insert("Connection", HeaderValue::from_static("keep-alive")); + response + } + + /// Process streaming chunks and send SSE events + async fn process_streaming_chunks( + router: &GrpcRouter, + grpc_stream: &mut (impl tokio_stream::Stream> + + Unpin), + request_id: String, + model: String, + stop_params: (Option, Option>, bool, bool), + original_request: ChatCompletionRequest, + tx: &mpsc::UnboundedSender>, + ) -> Result<(), String> { + // Extract request parameters + let separate_reasoning = original_request.separate_reasoning; + let tool_choice = &original_request.tool_choice; + let tools = &original_request.tools; + let history_tool_calls_count = Self::get_history_tool_calls_count(&original_request); + let stream_options = &original_request.stream_options; + + // Phase 1: Initialize state tracking (per-index for n>1 support) + let mut is_firsts: HashMap = HashMap::new(); + let mut stream_buffers: HashMap = HashMap::new(); + let mut finish_reasons: HashMap = HashMap::new(); + let mut matched_stops: HashMap> = HashMap::new(); + let mut prompt_tokens: HashMap = HashMap::new(); + let mut completion_tokens: HashMap = HashMap::new(); + let mut cached_tokens: HashMap = HashMap::new(); + + // Parser state (lazy initialization per index) + type PooledReasoningParser = + Arc>>; + let mut reasoning_parsers: HashMap = HashMap::new(); + + type PooledToolParser = Arc>>; + let mut tool_parsers: HashMap = HashMap::new(); + let mut has_tool_calls: HashMap = HashMap::new(); + + // Create stop decoder + let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params; + let mut stop_decoder = router.create_stop_decoder( + stop.as_ref(), + stop_token_ids.as_ref(), + skip_special_tokens, + no_stop_trim, + ); + + let created = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + // Phase 2: Main streaming loop while let Some(response) = grpc_stream.next().await { - let gen_response = match response { - Ok(resp) => resp, - Err(e) => { - error!("Stream error: {}", e); - break; - } - }; + let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; match gen_response.response { Some(Chunk(chunk)) => { - // Process tokens and check if we should stop - let (chunk_text, should_stop) = + let index = chunk.index; + + // Process tokens through stop decoder + let (chunk_text, _should_stop) = Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids); - decoded_text.push_str(&chunk_text); - if should_stop { - break; + + if chunk_text.is_empty() { + continue; + } + + // Process logprobs if present + let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs { + match router.convert_proto_to_openai_logprobs(proto_logprobs) { + Ok(logprobs) => Some(logprobs), + Err(e) => { + warn!("Failed to 
process logprobs: {}", e); + None + } + } + } else { + None + }; + + // Initialize stream buffer if first time + let stream_buffer = stream_buffers.entry(index).or_default(); + + // Send first chunk with role + if is_firsts.get(&index).copied().unwrap_or(true) { + let first_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk)))) + .map_err(|_| "Failed to send first chunk".to_string())?; + is_firsts.insert(index, false); + } + + // Calculate delta + let mut delta = chunk_text; + stream_buffer.push_str(&delta); + + // Reasoning content handling + if separate_reasoning { + let (normal_text, reasoning_chunk) = router.process_reasoning_stream( + &delta, + index, + &mut reasoning_parsers, + &request_id, + &model, + created, + ); + if let Some(chunk) = reasoning_chunk { + tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk)))) + .map_err(|_| "Failed to send reasoning chunk".to_string())?; + } + delta = normal_text; + } + + // Tool call handling + let tool_choice_enabled = + !matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None))); + + if tool_choice_enabled && tools.is_some() { + let (should_skip, tool_chunks) = router + .process_tool_calls_stream( + &delta, + index, + &mut tool_parsers, + &mut has_tool_calls, + tools.as_ref().unwrap(), + &request_id, + &model, + created, + history_tool_calls_count, + ) + .await; + + for chunk in tool_chunks { + tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk)))) + .map_err(|_| "Failed to send tool call chunk".to_string())?; + } + + if should_skip { + continue; + } + } + + // Regular content emission + if !delta.is_empty() { + let content_chunk = Self::create_content_chunk( + delta, + index, + &request_id, + &model, + created, + choice_logprobs, + ); + tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk)))) + .map_err(|_| "Failed to send content chunk".to_string())?; } - continue; } - Some(Complete(_complete)) => { + Some(Complete(complete)) => { // Flush any remaining text if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() { if !text.is_empty() { - decoded_text.push_str(&text); - debug!("Flushed text: {}", text); + let index = complete.index; + let stream_buffer = stream_buffers.entry(index).or_default(); + stream_buffer.push_str(&text); + + let content_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(text), + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&content_chunk) + .map_err(|e| format!("Failed to serialize content chunk: {}", e))?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send flushed content".to_string())?; } } + + // Store metadata + let index = complete.index; + prompt_tokens.insert(index, complete.prompt_tokens as u32); + 
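+                        // Completion metadata is recorded per choice index: with n>1
+                        // each choice finishes independently, and the per-index token
+                        // counts are summed later for the final usage chunk.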
completion_tokens.insert(index, complete.completion_tokens as u32); + cached_tokens.insert(index, complete.cached_tokens as u32); + finish_reasons.insert(index, complete.finish_reason.clone()); + + // Extract matched_stop + let matched_stop_value = match &complete.matched_stop { + Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(Value::Number(serde_json::Number::from(*token_id))) + } + Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { + Some(Value::String(stop_str.clone())) + } + None => None, + }; + matched_stops.insert(index, matched_stop_value); + break; } Some(Error(error)) => { - error!("Generation error: {}", error.message); - break; + return Err(error.message); } None => continue, } } - // TODO: Replace with proper SSE streaming response - // For now, return the complete decoded text - (StatusCode::OK, format!("Decoded text: {}", decoded_text)).into_response() + // Phase 3: Check unstreamed tool args + // Check if parsers have any remaining arguments that haven't been streamed yet + for (index, parser) in &tool_parsers { + let parser_guard = parser.lock().await; + if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() { + for tool_call_item in unstreamed_items { + let tool_call_delta = ToolCallDelta { + index: tool_call_item.tool_index as u32, + id: None, + tool_type: None, // No type for argument deltas + function: Some(FunctionCallDelta { + name: None, // No name for argument deltas + arguments: if !tool_call_item.parameters.is_empty() { + Some(tool_call_item.parameters) + } else { + None + }, + }), + }; + + let tool_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index: *index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: Some(vec![tool_call_delta]), + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&tool_chunk) + .map_err(|e| format!("Failed to serialize tool chunk: {}", e))?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send unstreamed tool args".to_string())?; + } + } + } + + // Phase 4: Finish reason chunks + for (index, finish_reason) in finish_reasons.iter() { + let final_finish_reason = + if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" { + "tool_calls".to_string() + } else { + finish_reason.clone() + }; + + let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone()); + + let finish_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index: *index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: Some(final_finish_reason), + matched_stop: matched_stop_value, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&finish_chunk) + .map_err(|e| format!("Failed to serialize finish chunk: {}", e))?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send finish chunk".to_string())?; + } + + // Phase 5: Usage chunk + if let Some(stream_opts) = stream_options { + if 
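+            // The usage-only chunk is opt-in via stream_options.include_usage and,
+            // following OpenAI streaming semantics, is emitted after all choices
+            // finish, with an empty `choices` array.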
stream_opts.include_usage.unwrap_or(false) { + let total_prompt: u32 = prompt_tokens.values().sum(); + let total_completion: u32 = completion_tokens.values().sum(); + + let usage_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![], + usage: Some(Usage { + prompt_tokens: total_prompt, + completion_tokens: total_completion, + total_tokens: total_prompt + total_completion, + completion_tokens_details: None, + }), + }; + + let sse_chunk = serde_json::to_string(&usage_chunk) + .map_err(|e| format!("Failed to serialize usage chunk: {}", e))?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send usage chunk".to_string())?; + } + } + + Ok(()) } /// Submit request and handle non-streaming response for chat completions route @@ -1082,10 +1728,17 @@ impl GrpcRouter { } // Process each response into a ChatChoice + let history_tool_calls_count = Self::get_history_tool_calls_count(original_request); let mut choices = Vec::new(); for (index, complete) in all_responses.iter().enumerate() { match self - .process_single_choice(complete, index, original_request, &mut stop_decoder) + .process_single_choice( + complete, + index, + original_request, + &mut stop_decoder, + history_tool_calls_count, + ) .await { Ok(choice) => choices.push(choice), @@ -1216,11 +1869,12 @@ impl GrpcRouter { decoded_text.push_str(&t); } - let output_ids = complete.output_ids.clone(); + let output_ids = std::mem::take(&mut complete.output_ids); + let finish_reason = std::mem::take(&mut complete.finish_reason); // Build base meta_info using json! macro let mut meta_info = json!({ - "finish_reason": complete.finish_reason.clone(), + "finish_reason": finish_reason, "prompt_tokens": complete.prompt_tokens, "completion_tokens": complete.completion_tokens, "cached_tokens": complete.cached_tokens, @@ -1269,9 +1923,13 @@ impl GrpcRouter { }) .collect(); - // Build ChatLogProbsContent for each token - for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() { - let token_text = token_texts.get(i).cloned().unwrap_or_default(); + // Build ChatLogProbsContent for each token (consume iterator to avoid clones) + for (i, (&logprob, token_text)) in proto_logprobs + .token_logprobs + .iter() + .zip(token_texts.into_iter()) + .enumerate() + { let bytes = Some(token_text.as_bytes().to_vec()); // Build top_logprobs for this position @@ -1324,6 +1982,7 @@ impl GrpcRouter { index: usize, original_request: &ChatCompletionRequest, stop_decoder: &mut StopSequenceDecoder, + history_tool_calls_count: usize, ) -> Result { stop_decoder.reset(); // Decode tokens @@ -1401,7 +2060,11 @@ impl GrpcRouter { self.parse_json_schema_response(&processed_text, &original_request.tool_choice); } else { (tool_calls, processed_text) = self - .parse_with_model_parser(&processed_text, &original_request.model) + .parse_tool_calls( + &processed_text, + &original_request.model, + history_tool_calls_count, + ) .await; } } @@ -1686,7 +2349,6 @@ mod tests { content: Some("Assistant response".to_string()), name: None, tool_calls: None, - function_call: None, reasoning_content: None, }]; diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 2dd39e279..4dcc71fb1 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -15,7 +15,7 @@ use crate::{ }, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, }, - reasoning_parser::ParserFactory, + 
reasoning_parser::ReasoningParserFactory, routers::{router_manager::RouterManager, RouterTrait}, service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, @@ -45,7 +45,7 @@ pub struct AppContext { pub router_config: RouterConfig, pub rate_limiter: Arc, pub tokenizer: Option>, - pub reasoning_parser_factory: Option, + pub reasoning_parser_factory: Option, pub tool_parser_factory: Option, pub worker_registry: Arc, pub policy_registry: Arc, @@ -79,7 +79,7 @@ impl AppContext { tokenizer_factory::create_tokenizer(&tokenizer_path) .map_err(|e| format!("Failed to create tokenizer: {e}"))?, ); - let reasoning_parser_factory = Some(ParserFactory::new()); + let reasoning_parser_factory = Some(ReasoningParserFactory::new()); let tool_parser_factory = Some(ToolParserFactory::new()); (tokenizer, reasoning_parser_factory, tool_parser_factory) diff --git a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs index 94364e3a1..2f28aae06 100644 --- a/sgl-router/src/tool_parser/parsers/deepseek_parser.rs +++ b/sgl-router/src/tool_parser/parsers/deepseek_parser.rs @@ -123,12 +123,7 @@ impl DeepSeekParser { let arguments = serde_json::to_string(&args) .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; - // Generate ID - let id = format!("deepseek_call_{}", uuid::Uuid::new_v4()); - Ok(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: func_name.to_string(), arguments, @@ -320,4 +315,8 @@ impl ToolParser for DeepSeekParser { fn detect_format(&self, text: &str) -> bool { self.has_tool_markers(text) } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs index 86c54ee6b..7d134f4eb 100644 --- a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs +++ b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs @@ -129,12 +129,7 @@ impl Glm4MoeParser { let arguments_str = serde_json::to_string(&arguments) .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; - // Generate ID - let id = format!("glm4_call_{}", uuid::Uuid::new_v4()); - Ok(Some(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: func_name.to_string(), arguments: arguments_str, @@ -321,4 +316,8 @@ impl ToolParser for Glm4MoeParser { fn detect_format(&self, text: &str) -> bool { self.has_tool_markers(text) } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs index 73a43efc6..5769315a3 100644 --- a/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs +++ b/sgl-router/src/tool_parser/parsers/gpt_oss_parser.rs @@ -113,12 +113,7 @@ impl ToolParser for GptOssParser { } }; - // Generate unique ID - let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4()); - tools.push(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: function_name, arguments, diff --git a/sgl-router/src/tool_parser/parsers/helpers.rs b/sgl-router/src/tool_parser/parsers/helpers.rs index 42ab4e416..46dcd71c6 100644 --- a/sgl-router/src/tool_parser/parsers/helpers.rs +++ b/sgl-router/src/tool_parser/parsers/helpers.rs @@ -14,6 +14,48 @@ 
pub fn get_tool_indices(tools: &[Tool]) -> HashMap { .collect() } +/// Get unstreamed tool call arguments +/// Returns tool call items for arguments that have been parsed but not yet streamed +/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk +pub fn get_unstreamed_args( + prev_tool_call_arr: &[Value], + streamed_args_for_tool: &[String], +) -> Option> { + // Check if we have tool calls being tracked + if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() { + return None; + } + + // Get the last tool call that was being processed + let tool_index = prev_tool_call_arr.len() - 1; + if tool_index >= streamed_args_for_tool.len() { + return None; + } + + // Get expected vs actual arguments + let expected_args = prev_tool_call_arr[tool_index].get("arguments")?; + let expected_str = serde_json::to_string(expected_args).ok()?; + let actual_str = &streamed_args_for_tool[tool_index]; + + // Check if there are remaining arguments to send + let remaining = if expected_str.starts_with(actual_str) { + &expected_str[actual_str.len()..] + } else { + return None; + }; + + if remaining.is_empty() { + return None; + } + + // Return the remaining arguments as a ToolCallItem + Some(vec![ToolCallItem { + tool_index, + name: None, // No name for argument deltas + parameters: remaining.to_string(), + }]) +} + /// Check if a buffer ends with a partial occurrence of a token /// Returns Some(length) if there's a partial match, None otherwise pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option { diff --git a/sgl-router/src/tool_parser/parsers/json_parser.rs b/sgl-router/src/tool_parser/parsers/json_parser.rs index 0ea2e85f0..3a005e2cd 100644 --- a/sgl-router/src/tool_parser/parsers/json_parser.rs +++ b/sgl-router/src/tool_parser/parsers/json_parser.rs @@ -8,7 +8,7 @@ use crate::tool_parser::{ parsers::helpers, partial_json::PartialJson, traits::ToolParser, - types::{FunctionCall, StreamingParseResult, ToolCall}, + types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, }; /// JSON format parser for tool calls @@ -136,16 +136,7 @@ impl JsonParser { let arguments = serde_json::to_string(args) .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; - // Generate a unique ID if not provided - let id = obj - .get("id") - .and_then(|v| v.as_str()) - .map(String::from) - .unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4())); - Ok(Some(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: name.to_string(), arguments, @@ -274,4 +265,8 @@ impl ToolParser for JsonParser { let trimmed = text.trim(); (trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#) } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs index 44fede1ea..123c4d0f5 100644 --- a/sgl-router/src/tool_parser/parsers/kimik2_parser.rs +++ b/sgl-router/src/tool_parser/parsers/kimik2_parser.rs @@ -131,12 +131,7 @@ impl ToolParser for KimiK2Parser { // Try to parse JSON arguments match serde_json::from_str::(function_args) { Ok(_) => { - // Generate unique ID - let id = format!("kimi_call_{}", uuid::Uuid::new_v4()); - tools.push(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: func_name, arguments: function_args.to_string(), @@ -339,4 +334,8 @@ impl 
ToolParser for KimiK2Parser { fn detect_format(&self, text: &str) -> bool { self.has_tool_markers(text) || text.contains("<|tool_call_begin|>") } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/parsers/llama_parser.rs b/sgl-router/src/tool_parser/parsers/llama_parser.rs index 37b49b40a..5634aa81e 100644 --- a/sgl-router/src/tool_parser/parsers/llama_parser.rs +++ b/sgl-router/src/tool_parser/parsers/llama_parser.rs @@ -1,6 +1,5 @@ use async_trait::async_trait; use serde_json::Value; -use uuid; use crate::protocols::spec::Tool; @@ -84,16 +83,7 @@ impl LlamaParser { let arguments = serde_json::to_string(parameters) .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; - // Generate a unique ID for Llama calls - let id = obj - .get("id") - .and_then(|v| v.as_str()) - .map(String::from) - .unwrap_or_else(|| format!("llama_call_{}", uuid::Uuid::new_v4())); - Ok(Some(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: name.to_string(), arguments, @@ -243,4 +233,8 @@ impl ToolParser for LlamaParser { text.contains("<|python_tag|>") || (text.trim_start().starts_with('{') && text.contains(r#""name""#)) } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/parsers/mistral_parser.rs b/sgl-router/src/tool_parser/parsers/mistral_parser.rs index ae5d3511e..2148a966f 100644 --- a/sgl-router/src/tool_parser/parsers/mistral_parser.rs +++ b/sgl-router/src/tool_parser/parsers/mistral_parser.rs @@ -146,16 +146,7 @@ impl MistralParser { let arguments = serde_json::to_string(args) .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; - // Generate unique ID - let id = obj - .get("id") - .and_then(|v| v.as_str()) - .map(String::from) - .unwrap_or_else(|| format!("mistral_call_{}", uuid::Uuid::new_v4())); - Ok(Some(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: name.to_string(), arguments, @@ -266,4 +257,8 @@ impl ToolParser for MistralParser { fn detect_format(&self, text: &str) -> bool { self.has_tool_markers(text) } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs index 5505b12d7..8eeaecd41 100644 --- a/sgl-router/src/tool_parser/parsers/pythonic_parser.rs +++ b/sgl-router/src/tool_parser/parsers/pythonic_parser.rs @@ -244,7 +244,7 @@ fn parse_python_expression(source: &str) -> ToolParserResult { } } -fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult { +fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult { match expr { Expr::Call(call_expr) => { if !call_expr.args.is_empty() { @@ -277,8 +277,6 @@ fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult { let arguments_string = serde_json::to_string(&arguments_json)?; Ok(ToolCall { - id: format!("call-{}", index + 1), - r#type: "function".to_string(), function: FunctionCall { name: function_name, arguments: arguments_string, diff --git a/sgl-router/src/tool_parser/parsers/qwen_parser.rs b/sgl-router/src/tool_parser/parsers/qwen_parser.rs index 230c6e39b..f6de474f4 100644 --- a/sgl-router/src/tool_parser/parsers/qwen_parser.rs +++ 
b/sgl-router/src/tool_parser/parsers/qwen_parser.rs @@ -88,16 +88,7 @@ impl QwenParser { let arguments = serde_json::to_string(args) .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; - // Generate unique ID - let id = obj - .get("id") - .and_then(|v| v.as_str()) - .map(String::from) - .unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4())); - Ok(Some(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: name.to_string(), arguments, @@ -255,4 +246,8 @@ impl ToolParser for QwenParser { fn detect_format(&self, text: &str) -> bool { self.has_tool_markers(text) } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/parsers/step3_parser.rs b/sgl-router/src/tool_parser/parsers/step3_parser.rs index 6135c3366..319b243ac 100644 --- a/sgl-router/src/tool_parser/parsers/step3_parser.rs +++ b/sgl-router/src/tool_parser/parsers/step3_parser.rs @@ -400,12 +400,7 @@ impl Step3Parser { let arguments_str = serde_json::to_string(¶meters) .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; - // Generate ID - let id = format!("step3_call_{}", uuid::Uuid::new_v4()); - Ok(Some(ToolCall { - id, - r#type: "function".to_string(), function: FunctionCall { name: func_name.to_string(), arguments: arguments_str, @@ -561,4 +556,8 @@ impl ToolParser for Step3Parser { fn detect_format(&self, text: &str) -> bool { self.has_tool_markers(text) } + + fn get_unstreamed_tool_args(&self) -> Option> { + helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) + } } diff --git a/sgl-router/src/tool_parser/tests.rs b/sgl-router/src/tool_parser/tests.rs index 1840d42b6..d5171abdf 100644 --- a/sgl-router/src/tool_parser/tests.rs +++ b/sgl-router/src/tool_parser/tests.rs @@ -31,8 +31,6 @@ async fn test_tool_parser_factory_model_mapping() { #[test] fn test_tool_call_serialization() { let tool_call = ToolCall { - id: "call-123".to_string(), - r#type: "function".to_string(), function: FunctionCall { name: "search".to_string(), arguments: r#"{"query": "rust programming"}"#.to_string(), @@ -40,13 +38,15 @@ fn test_tool_call_serialization() { }; let json = serde_json::to_string(&tool_call).unwrap(); - assert!(json.contains("call-123")); assert!(json.contains("search")); assert!(json.contains("rust programming")); let parsed: ToolCall = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed.id, "call-123"); assert_eq!(parsed.function.name, "search"); + assert_eq!( + parsed.function.arguments, + r#"{"query": "rust programming"}"# + ); } #[test] diff --git a/sgl-router/src/tool_parser/traits.rs b/sgl-router/src/tool_parser/traits.rs index e5a6524a6..ee6d00c87 100644 --- a/sgl-router/src/tool_parser/traits.rs +++ b/sgl-router/src/tool_parser/traits.rs @@ -32,6 +32,12 @@ pub trait ToolParser: Send + Sync { fn as_token_parser(&self) -> Option<&dyn TokenToolParser> { None } + + /// Get unstreamed tool call arguments + /// Returns tool call items for arguments that have been parsed but not yet streamed + fn get_unstreamed_tool_args(&self) -> Option> { + None + } } /// Trait for partial JSON parsing diff --git a/sgl-router/src/tool_parser/types.rs b/sgl-router/src/tool_parser/types.rs index 4183ca6cb..8157a44e2 100644 --- a/sgl-router/src/tool_parser/types.rs +++ b/sgl-router/src/tool_parser/types.rs @@ -1,13 +1,8 @@ use serde::{Deserialize, Serialize}; -/// Parsed tool call from model output (OpenAI format) +/// Parsed tool call 
from model output
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct ToolCall {
-    /// Unique identifier for the tool call
-    pub id: String,
-    /// Type of tool call (currently always "function")
-    #[serde(rename = "type")]
-    pub r#type: String,
     /// Function call details
     pub function: FunctionCall,
 }
diff --git a/sgl-router/tests/chat_template_integration.rs b/sgl-router/tests/chat_template_integration.rs
index 572e539eb..ac25a3f10 100644
--- a/sgl-router/tests/chat_template_integration.rs
+++ b/sgl-router/tests/chat_template_integration.rs
@@ -181,7 +181,6 @@ fn test_chatml_template() {
             content: Some("Hi there!".to_string()),
             name: None,
             tool_calls: None,
-            function_call: None,
             reasoning_content: None,
         },
         spec::ChatMessage::User {
diff --git a/sgl-router/tests/chat_template_loading.rs b/sgl-router/tests/chat_template_loading.rs
index e43b857a9..b3a5a3e70 100644
--- a/sgl-router/tests/chat_template_loading.rs
+++ b/sgl-router/tests/chat_template_loading.rs
@@ -68,7 +68,6 @@ mod tests {
                 content: Some("Hi there".to_string()),
                 name: None,
                 tool_calls: None,
-                function_call: None,
                 reasoning_content: None,
             },
         ];
@@ -213,7 +212,6 @@ mod tests {
                 content: Some("World".to_string()),
                 name: None,
                 tool_calls: None,
-                function_call: None,
                 reasoning_content: None,
             },
         ];
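
Example: a minimal, self-contained sketch of the contract behind the new
helpers::get_unstreamed_args — compare the fully serialized arguments of the
last tracked tool call against what has already been streamed, and emit only
the remainder. The concrete values and the bare serde_json setup here are
illustrative assumptions, not code from this patch:

    use serde_json::json;

    fn main() {
        // The last tool call the parser has fully reconstructed so far.
        let prev_tool_call = json!({ "name": "search", "arguments": { "q": "rust" } });
        // Argument text that was already emitted as streaming deltas.
        let streamed_so_far = r#"{"q":"#;

        // Serialize the expected arguments and strip the already-streamed prefix;
        // anything left over is the final argument delta to flush.
        let expected = serde_json::to_string(&prev_tool_call["arguments"]).unwrap();
        let remaining = expected
            .strip_prefix(streamed_so_far)
            .filter(|r| !r.is_empty());

        assert_eq!(remaining, Some(r#""rust"}"#));
    }

With these inputs the remainder is `"rust"}`, which corresponds to what Phase 3
of process_streaming_chunks flushes as a trailing tool-call argument delta when
the model produces its final arguments in the last chunk.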