diff --git a/python/sglang/srt/grpc/sglang_scheduler.proto b/python/sglang/srt/grpc/sglang_scheduler.proto index f52f50d2a..9eeb0f00f 100644 --- a/python/sglang/srt/grpc/sglang_scheduler.proto +++ b/python/sglang/srt/grpc/sglang_scheduler.proto @@ -36,9 +36,9 @@ message SamplingParams { float presence_penalty = 6; float repetition_penalty = 7; - int32 max_new_tokens = 8; + optional int32 max_new_tokens = 8; repeated string stop = 9; - repeated int32 stop_token_ids = 10; + repeated uint32 stop_token_ids = 10; bool skip_special_tokens = 11; bool spaces_between_special_tokens = 12; @@ -98,7 +98,7 @@ message GenerateRequest { bool return_logprob = 5; int32 logprob_start_len = 6; int32 top_logprobs_num = 7; - repeated int32 token_ids_logprob = 8; + repeated uint32 token_ids_logprob = 8; bool return_hidden_states = 9; // For disaggregated serving @@ -129,7 +129,7 @@ message GenerateRequest { message TokenizedInput { string original_text = 1; // For reference - repeated int32 input_ids = 2; + repeated uint32 input_ids = 2; } message MultimodalInputs { @@ -167,7 +167,7 @@ message GenerateResponse { message GenerateStreamChunk { // Generated tokens (incremental chunk) - repeated int32 token_ids = 1; + repeated uint32 token_ids = 1; // Cumulative counts int32 prompt_tokens = 2; @@ -183,7 +183,7 @@ message GenerateStreamChunk { message GenerateComplete { // Final output - repeated int32 output_ids = 1; + repeated uint32 output_ids = 1; // Finish reason enum FinishReason { diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py index 1142104aa..6544f6afb 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.py +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py @@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc9\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x16\n\x0emax_new_tokens\x18\x08 \x01(\x05\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\x05\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraint\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xf9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\x05\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\x12\x0e\n\x06stream\x18\x12 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\x05\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xbb\x01\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12\x31\n\x08logprobs\x18\x05 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\"\x81\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\x05\x12K\n\rfinish_reason\x18\x02 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x35\n\x0c\x61ll_logprobs\x18\x06 \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xf9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\x12\x0e\n\x06stream\x18\x12 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xbb\x01\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12\x31\n\x08logprobs\x18\x05 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\"\x81\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12K\n\rfinish_reason\x18\x02 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x35\n\x0c\x61ll_logprobs\x18\x06 \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -39,71 +39,71 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001' _globals['_SAMPLINGPARAMS']._serialized_start=113 - _globals['_SAMPLINGPARAMS']._serialized_end=826 - _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=764 - _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=812 - _globals['_DISAGGREGATEDPARAMS']._serialized_start=828 - _globals['_DISAGGREGATEDPARAMS']._serialized_end=921 - _globals['_GENERATEREQUEST']._serialized_start=924 - _globals['_GENERATEREQUEST']._serialized_end=1557 - _globals['_TOKENIZEDINPUT']._serialized_start=1559 - _globals['_TOKENIZEDINPUT']._serialized_end=1617 - _globals['_MULTIMODALINPUTS']._serialized_start=1620 - _globals['_MULTIMODALINPUTS']._serialized_end=1831 - _globals['_GENERATERESPONSE']._serialized_start=1834 - _globals['_GENERATERESPONSE']._serialized_end=2061 - _globals['_GENERATESTREAMCHUNK']._serialized_start=2064 - _globals['_GENERATESTREAMCHUNK']._serialized_end=2251 - _globals['_GENERATECOMPLETE']._serialized_start=2254 - _globals['_GENERATECOMPLETE']._serialized_end=2639 - _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2563 - _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2639 - _globals['_GENERATEERROR']._serialized_start=2641 - _globals['_GENERATEERROR']._serialized_end=2716 - _globals['_LOGPROBS']._serialized_start=2719 - _globals['_LOGPROBS']._serialized_end=2851 - _globals['_TOPLOGPROBS']._serialized_start=2853 - _globals['_TOPLOGPROBS']._serialized_end=2922 - _globals['_HIDDENSTATES']._serialized_start=2924 - _globals['_HIDDENSTATES']._serialized_end=2987 - _globals['_EMBEDREQUEST']._serialized_start=2990 - _globals['_EMBEDREQUEST']._serialized_end=3320 - _globals['_EMBEDRESPONSE']._serialized_start=3323 - _globals['_EMBEDRESPONSE']._serialized_end=3480 - _globals['_EMBEDCOMPLETE']._serialized_start=3483 - _globals['_EMBEDCOMPLETE']._serialized_end=3646 - _globals['_EMBEDDING']._serialized_start=3648 - _globals['_EMBEDDING']._serialized_end=3690 - _globals['_EMBEDERROR']._serialized_start=3692 - _globals['_EMBEDERROR']._serialized_end=3752 - _globals['_HEALTHCHECKREQUEST']._serialized_start=3754 - _globals['_HEALTHCHECKREQUEST']._serialized_end=3832 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=3834 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=3889 - _globals['_ABORTREQUEST']._serialized_start=3891 - _globals['_ABORTREQUEST']._serialized_end=3941 - _globals['_ABORTRESPONSE']._serialized_start=3943 - _globals['_ABORTRESPONSE']._serialized_end=3992 - _globals['_LOADLORAREQUEST']._serialized_start=3994 - _globals['_LOADLORAREQUEST']._serialized_end=4067 - _globals['_LOADLORARESPONSE']._serialized_start=4069 - _globals['_LOADLORARESPONSE']._serialized_end=4141 - _globals['_UNLOADLORAREQUEST']._serialized_start=4143 - _globals['_UNLOADLORAREQUEST']._serialized_end=4182 - _globals['_UNLOADLORARESPONSE']._serialized_start=4184 - _globals['_UNLOADLORARESPONSE']._serialized_end=4238 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4240 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4359 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4361 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4418 - _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4420 - _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4465 - _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4467 - _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4533 - _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4535 - _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4600 - _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4602 - _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4662 - _globals['_SGLANGSCHEDULER']._serialized_start=4665 - _globals['_SGLANGSCHEDULER']._serialized_end=5047 + _globals['_SAMPLINGPARAMS']._serialized_end=850 + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=769 + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=817 + _globals['_DISAGGREGATEDPARAMS']._serialized_start=852 + _globals['_DISAGGREGATEDPARAMS']._serialized_end=945 + _globals['_GENERATEREQUEST']._serialized_start=948 + _globals['_GENERATEREQUEST']._serialized_end=1581 + _globals['_TOKENIZEDINPUT']._serialized_start=1583 + _globals['_TOKENIZEDINPUT']._serialized_end=1641 + _globals['_MULTIMODALINPUTS']._serialized_start=1644 + _globals['_MULTIMODALINPUTS']._serialized_end=1855 + _globals['_GENERATERESPONSE']._serialized_start=1858 + _globals['_GENERATERESPONSE']._serialized_end=2085 + _globals['_GENERATESTREAMCHUNK']._serialized_start=2088 + _globals['_GENERATESTREAMCHUNK']._serialized_end=2275 + _globals['_GENERATECOMPLETE']._serialized_start=2278 + _globals['_GENERATECOMPLETE']._serialized_end=2663 + _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2587 + _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2663 + _globals['_GENERATEERROR']._serialized_start=2665 + _globals['_GENERATEERROR']._serialized_end=2740 + _globals['_LOGPROBS']._serialized_start=2743 + _globals['_LOGPROBS']._serialized_end=2875 + _globals['_TOPLOGPROBS']._serialized_start=2877 + _globals['_TOPLOGPROBS']._serialized_end=2946 + _globals['_HIDDENSTATES']._serialized_start=2948 + _globals['_HIDDENSTATES']._serialized_end=3011 + _globals['_EMBEDREQUEST']._serialized_start=3014 + _globals['_EMBEDREQUEST']._serialized_end=3344 + _globals['_EMBEDRESPONSE']._serialized_start=3347 + _globals['_EMBEDRESPONSE']._serialized_end=3504 + _globals['_EMBEDCOMPLETE']._serialized_start=3507 + _globals['_EMBEDCOMPLETE']._serialized_end=3670 + _globals['_EMBEDDING']._serialized_start=3672 + _globals['_EMBEDDING']._serialized_end=3714 + _globals['_EMBEDERROR']._serialized_start=3716 + _globals['_EMBEDERROR']._serialized_end=3776 + _globals['_HEALTHCHECKREQUEST']._serialized_start=3778 + _globals['_HEALTHCHECKREQUEST']._serialized_end=3856 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=3858 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=3913 + _globals['_ABORTREQUEST']._serialized_start=3915 + _globals['_ABORTREQUEST']._serialized_end=3965 + _globals['_ABORTRESPONSE']._serialized_start=3967 + _globals['_ABORTRESPONSE']._serialized_end=4016 + _globals['_LOADLORAREQUEST']._serialized_start=4018 + _globals['_LOADLORAREQUEST']._serialized_end=4091 + _globals['_LOADLORARESPONSE']._serialized_start=4093 + _globals['_LOADLORARESPONSE']._serialized_end=4165 + _globals['_UNLOADLORAREQUEST']._serialized_start=4167 + _globals['_UNLOADLORAREQUEST']._serialized_end=4206 + _globals['_UNLOADLORARESPONSE']._serialized_start=4208 + _globals['_UNLOADLORARESPONSE']._serialized_end=4262 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4264 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4383 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4385 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4442 + _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4444 + _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4489 + _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4491 + _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4557 + _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4559 + _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4624 + _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4626 + _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4686 + _globals['_SGLANGSCHEDULER']._serialized_start=4689 + _globals['_SGLANGSCHEDULER']._serialized_end=5071 # @@protoc_insertion_point(module_scope) diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 0b270e8b4..36a980235 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -20,7 +20,7 @@ pub struct SglangSchedulerClient { impl SglangSchedulerClient { /// Create a new client and connect to the scheduler - pub async fn connect(endpoint: &str) -> Result> { + pub async fn connect(endpoint: &str) -> Result> { debug!("Connecting to SGLang scheduler at {}", endpoint); // Convert grpc:// to http:// for tonic @@ -41,10 +41,11 @@ impl SglangSchedulerClient { } /// Submit a generation request (returns streaming response) - pub async fn generate_stream( + pub async fn generate( &mut self, req: proto::GenerateRequest, - ) -> Result, Box> { + ) -> Result, Box> + { let request = Request::new(req); let response = self.client.generate(request).await?; Ok(response.into_inner()) @@ -53,7 +54,7 @@ impl SglangSchedulerClient { /// Perform health check pub async fn health_check( &mut self, - ) -> Result> { + ) -> Result> { debug!("Sending health check request"); let request = Request::new(proto::HealthCheckRequest { tokenized: Some(proto::TokenizedInput { @@ -72,7 +73,7 @@ impl SglangSchedulerClient { &mut self, request_id: String, reason: String, - ) -> Result<(), Box> { + ) -> Result<(), Box> { let request = Request::new(proto::AbortRequest { request_id, reason }); self.client.abort(request).await?; @@ -85,7 +86,7 @@ impl SglangSchedulerClient { request_id: String, body: &ChatCompletionRequest, processed_text: String, - token_ids: Vec, + token_ids: Vec, multimodal_inputs: Option, tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value) ) -> Result { @@ -153,6 +154,8 @@ impl SglangSchedulerClient { stop: stop_sequences, stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), skip_special_tokens, + ignore_eos: request.ignore_eos, + no_stop_trim: request.no_stop_trim, n: request.n.unwrap_or(1) as i32, constraint: self.build_constraint(request, tool_call_constraint)?, ..Default::default() diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto index 2892caec2..9eeb0f00f 100644 --- a/sgl-router/src/proto/sglang_scheduler.proto +++ b/sgl-router/src/proto/sglang_scheduler.proto @@ -38,7 +38,7 @@ message SamplingParams { optional int32 max_new_tokens = 8; repeated string stop = 9; - repeated int32 stop_token_ids = 10; + repeated uint32 stop_token_ids = 10; bool skip_special_tokens = 11; bool spaces_between_special_tokens = 12; @@ -98,7 +98,7 @@ message GenerateRequest { bool return_logprob = 5; int32 logprob_start_len = 6; int32 top_logprobs_num = 7; - repeated int32 token_ids_logprob = 8; + repeated uint32 token_ids_logprob = 8; bool return_hidden_states = 9; // For disaggregated serving @@ -129,7 +129,7 @@ message GenerateRequest { message TokenizedInput { string original_text = 1; // For reference - repeated int32 input_ids = 2; + repeated uint32 input_ids = 2; } message MultimodalInputs { @@ -167,7 +167,7 @@ message GenerateResponse { message GenerateStreamChunk { // Generated tokens (incremental chunk) - repeated int32 token_ids = 1; + repeated uint32 token_ids = 1; // Cumulative counts int32 prompt_tokens = 2; @@ -183,7 +183,7 @@ message GenerateStreamChunk { message GenerateComplete { // Final output - repeated int32 output_ids = 1; + repeated uint32 output_ids = 1; // Finish reason enum FinishReason { diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index c6de7075f..0e0f6667d 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -313,7 +313,7 @@ pub struct ChatCompletionRequest { /// Specific token IDs to use as stop conditions #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, + pub stop_token_ids: Option>, /// Skip trimming stop tokens from output #[serde(default)] @@ -564,7 +564,7 @@ pub struct CompletionRequest { /// Specific token IDs to use as stop conditions #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, + pub stop_token_ids: Option>, /// Skip trimming stop tokens from output #[serde(default)] @@ -1864,7 +1864,7 @@ pub struct SamplingParams { #[serde(skip_serializing_if = "Option::is_none")] pub min_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, + pub stop_token_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub no_stop_trim: Option, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index aaaab25c1..e169c5f65 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -17,19 +17,20 @@ use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::metrics::RouterMetrics; use crate::policies::PolicyRegistry; use crate::protocols::spec::ChatMessage; -use crate::protocols::spec::{ChatCompletionRequest, StringOrArray}; use crate::protocols::spec::{ - CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams, - ResponsesRequest, Tool, ToolChoice, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, + ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice, }; use crate::reasoning_parser::ParserFactory; use crate::routers::RouterTrait; use crate::server::AppContext; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; +use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder}; use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::HuggingFaceTokenizer; use crate::tool_parser::ParserRegistry; use serde_json::Value; +use tokio_stream::StreamExt; use uuid::Uuid; // Data structures for processing @@ -182,7 +183,7 @@ impl GrpcRouter { request_id, body, processed_messages.text.clone(), - token_ids.into_iter().map(|id| id as i32).collect(), + token_ids, processed_messages.multimodal_inputs, tool_call_constraint, // Pass the full tuple (type, value) ) { @@ -479,28 +480,225 @@ impl GrpcRouter { None } - /// Placeholder for streaming handler (to be implemented in Phase 2) - async fn handle_streaming_chat( + /// Create a StopSequenceDecoder from the chat completion request + fn create_stop_decoder( &self, - _client: SglangSchedulerClient, - _request: proto::GenerateRequest, - _original_request: &ChatCompletionRequest, - ) -> Response { - (StatusCode::NOT_IMPLEMENTED, "Streaming not yet implemented").into_response() + original_request: &ChatCompletionRequest, + ) -> crate::tokenizer::stop::StopSequenceDecoder { + // Extract stop sequences from request + let stop_sequences: Vec = match &original_request.stop { + Some(StringOrArray::String(s)) => vec![s.clone()], + Some(StringOrArray::Array(arr)) => arr.clone(), + None => vec![], + }; + + // Build stop sequence decoder + let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone()) + .skip_special_tokens(original_request.skip_special_tokens); + + // Add stop sequences (visible if no_stop_trim is true, hidden otherwise) + for seq in stop_sequences { + builder = if original_request.no_stop_trim { + builder.visible_stop_sequence(seq) + } else { + builder.stop_sequence(seq) + }; + } + + // Add stop token IDs (visible if no_stop_trim is true, hidden otherwise) + if let Some(stop_token_ids) = &original_request.stop_token_ids { + for &token_id in stop_token_ids { + builder = if original_request.no_stop_trim { + builder.visible_stop_token(token_id) + } else { + builder.stop_token(token_id) + }; + } + } + + builder.build() } - /// Placeholder for non-streaming handler (to be implemented in Phase 3) + /// Process a chunk of tokens through the stop decoder + fn process_chunk_tokens( + stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder, + token_ids: &[u32], + ) -> (String, bool) { + let mut chunk_text = String::new(); + + for &token_id in token_ids { + match stop_decoder.process_token(token_id).unwrap_or_else(|e| { + debug!( + "Error processing token {}: {}. Treating as Held.", + token_id, e + ); + SequenceDecoderOutput::Held + }) { + SequenceDecoderOutput::Text(text) => { + chunk_text.push_str(&text); + } + SequenceDecoderOutput::StoppedWithText(text) => { + chunk_text.push_str(&text); + return (chunk_text, true); // Return text and signal to stop + } + SequenceDecoderOutput::Stopped => { + return (chunk_text, true); // Return text and signal to stop + } + SequenceDecoderOutput::Held => { + // Text held for potential stop sequence match + } + } + } + (chunk_text, false) // Return text and continue processing + } + + /// Submit request and handle streaming response for chat completions route + async fn handle_streaming_chat( + &self, + mut client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &ChatCompletionRequest, + ) -> Response { + let mut stop_decoder = self.create_stop_decoder(original_request); + + // Process streaming tokens + let mut grpc_stream = match client.generate(request).await { + Ok(stream) => stream, + Err(e) => { + error!("Failed to start generation: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Generation failed: {}", e), + ) + .into_response(); + } + }; + + let mut decoded_text = String::new(); + + while let Some(response) = grpc_stream.next().await { + let gen_response = match response { + Ok(resp) => resp, + Err(e) => { + error!("Stream error: {}", e); + break; + } + }; + + match gen_response.response { + Some(proto::generate_response::Response::Chunk(chunk)) => { + // Process tokens and check if we should stop + let (chunk_text, should_stop) = + Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids); + decoded_text.push_str(&chunk_text); + if should_stop { + break; + } + continue; + } + Some(proto::generate_response::Response::Complete(_complete)) => { + // Flush any remaining text + if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() { + if !text.is_empty() { + decoded_text.push_str(&text); + debug!("Flushed text: {}", text); + } + } + break; + } + Some(proto::generate_response::Response::Error(error)) => { + error!("Generation error: {}", error.message); + break; + } + None => continue, + } + } + + // TODO: Replace with proper SSE streaming response + // For now, return the complete decoded text + (StatusCode::OK, format!("Decoded text: {}", decoded_text)).into_response() + } + + /// Submit request and handle non-streaming response for chat completions route async fn handle_non_streaming_chat( &self, - _client: SglangSchedulerClient, - _request: proto::GenerateRequest, - _original_request: &ChatCompletionRequest, + mut client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &ChatCompletionRequest, ) -> Response { - ( - StatusCode::NOT_IMPLEMENTED, - "Non-streaming not yet implemented", - ) - .into_response() + let mut stop_decoder = self.create_stop_decoder(original_request); + + // Small helpers to log + return a uniform 500 + let fail_str = |msg: &'static str| -> Response { + error!("{}", msg); + (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response() + }; + let fail_fmt = |prefix: &str, e: &dyn std::fmt::Display| -> Response { + error!("{}{}", prefix, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("{}{}", prefix, e), + ) + .into_response() + }; + + // Start generation + let mut stream = match client.generate(request).await { + Ok(s) => s, + Err(e) => return fail_fmt("Failed to start generation: ", &e), + }; + + // Get the single Complete response + let gen_response = match stream.next().await { + Some(Ok(r)) => r, + Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e), + None => return fail_str("No response from server"), + }; + + // Extract the expected variant early + let complete = match gen_response.response { + Some(proto::generate_response::Response::Complete(c)) => c, + Some(proto::generate_response::Response::Error(err)) => { + error!("Generation failed: {}", err.message); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Generation failed: {}", err.message), + ) + .into_response(); + } + Some(proto::generate_response::Response::Chunk(_)) => { + return fail_str("Unexpected chunk response for non-streaming request") + } + None => return fail_str("Empty response from server"), + }; + + // Decode tokens + let outputs = match stop_decoder.process_tokens(&complete.output_ids) { + Ok(o) => o, + Err(e) => return fail_fmt("Failed to process tokens: ", &e), + }; + + // Accumulate text with early breaks + let mut final_text = String::new(); + for output in outputs { + match output { + SequenceDecoderOutput::Text(t) => final_text.push_str(&t), + SequenceDecoderOutput::StoppedWithText(t) => { + final_text.push_str(&t); + break; + } + SequenceDecoderOutput::Stopped => break, + SequenceDecoderOutput::Held => {} + } + } + + // Flush remaining text + if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { + final_text.push_str(&t); + } + + // TODO: Create proper OpenAI-compatible response + (StatusCode::OK, format!("Final text: {}", final_text)).into_response() } }