From 7ff740a6cead8128c6249a6b178f5d8b66ee6fbf Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Fri, 3 Oct 2025 01:48:15 +0800 Subject: [PATCH] Remove dp balance metadata and minimul token balance. (#11170) --- python/sglang/srt/entrypoints/engine.py | 1 - python/sglang/srt/grpc/sglang_scheduler.proto | 5 +- .../sglang/srt/grpc/sglang_scheduler_pb2.py | 128 +++++++++--------- .../sglang/srt/grpc/sglang_scheduler_pb2.pyi | 6 +- .../srt/managers/data_parallel_controller.py | 53 ++------ python/sglang/srt/managers/io_struct.py | 5 - python/sglang/srt/managers/scheduler.py | 9 +- .../srt/managers/scheduler_metrics_mixin.py | 96 ------------- python/sglang/srt/managers/utils.py | 43 ------ sgl-router/src/proto/sglang_scheduler.proto | 5 +- test/srt/test_dp_attention.py | 42 ------ 11 files changed, 77 insertions(+), 316 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 840ed332d..96485583f 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -812,7 +812,6 @@ def _launch_subprocesses( pp_rank, None, writer, - None, ), ) diff --git a/python/sglang/srt/grpc/sglang_scheduler.proto b/python/sglang/srt/grpc/sglang_scheduler.proto index ee0135b1f..1d52e54c5 100644 --- a/python/sglang/srt/grpc/sglang_scheduler.proto +++ b/python/sglang/srt/grpc/sglang_scheduler.proto @@ -120,11 +120,8 @@ message GenerateRequest { // Data parallel routing int32 data_parallel_rank = 16; - // For load balancing - int32 dp_balance_id = 17; - // Whether client wants streaming response - bool stream = 18; + bool stream = 17; } message TokenizedInput { diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py index 209326b1f..2dc5ac321 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.py +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py @@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xf9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\x12\x0e\n\x06stream\x18\x12 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x86\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\"\x8c\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbsB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x86\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\"\x8c\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbsB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -45,67 +45,67 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_DISAGGREGATEDPARAMS']._serialized_start=852 _globals['_DISAGGREGATEDPARAMS']._serialized_end=945 _globals['_GENERATEREQUEST']._serialized_start=948 - _globals['_GENERATEREQUEST']._serialized_end=1581 - _globals['_TOKENIZEDINPUT']._serialized_start=1583 - _globals['_TOKENIZEDINPUT']._serialized_end=1641 - _globals['_MULTIMODALINPUTS']._serialized_start=1644 - _globals['_MULTIMODALINPUTS']._serialized_end=1855 - _globals['_GENERATERESPONSE']._serialized_start=1858 - _globals['_GENERATERESPONSE']._serialized_end=2085 - _globals['_GENERATESTREAMCHUNK']._serialized_start=2088 - _globals['_GENERATESTREAMCHUNK']._serialized_end=2350 - _globals['_GENERATECOMPLETE']._serialized_start=2353 - _globals['_GENERATECOMPLETE']._serialized_end=2749 - _globals['_GENERATEERROR']._serialized_start=2751 - _globals['_GENERATEERROR']._serialized_end=2826 - _globals['_OUTPUTLOGPROBS']._serialized_start=2828 - _globals['_OUTPUTLOGPROBS']._serialized_end=2945 - _globals['_INPUTLOGPROBS']._serialized_start=2948 - _globals['_INPUTLOGPROBS']._serialized_end=3106 - _globals['_INPUTTOKENLOGPROB']._serialized_start=3108 - _globals['_INPUTTOKENLOGPROB']._serialized_end=3157 - _globals['_TOPLOGPROBS']._serialized_start=3159 - _globals['_TOPLOGPROBS']._serialized_end=3207 - _globals['_HIDDENSTATES']._serialized_start=3209 - _globals['_HIDDENSTATES']._serialized_end=3272 - _globals['_EMBEDREQUEST']._serialized_start=3275 - _globals['_EMBEDREQUEST']._serialized_end=3605 - _globals['_EMBEDRESPONSE']._serialized_start=3608 - _globals['_EMBEDRESPONSE']._serialized_end=3765 - _globals['_EMBEDCOMPLETE']._serialized_start=3768 - _globals['_EMBEDCOMPLETE']._serialized_end=3931 - _globals['_EMBEDDING']._serialized_start=3933 - _globals['_EMBEDDING']._serialized_end=3975 - _globals['_EMBEDERROR']._serialized_start=3977 - _globals['_EMBEDERROR']._serialized_end=4037 - _globals['_HEALTHCHECKREQUEST']._serialized_start=4039 - _globals['_HEALTHCHECKREQUEST']._serialized_end=4117 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=4119 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=4174 - _globals['_ABORTREQUEST']._serialized_start=4176 - _globals['_ABORTREQUEST']._serialized_end=4226 - _globals['_ABORTRESPONSE']._serialized_start=4228 - _globals['_ABORTRESPONSE']._serialized_end=4277 - _globals['_LOADLORAREQUEST']._serialized_start=4279 - _globals['_LOADLORAREQUEST']._serialized_end=4352 - _globals['_LOADLORARESPONSE']._serialized_start=4354 - _globals['_LOADLORARESPONSE']._serialized_end=4426 - _globals['_UNLOADLORAREQUEST']._serialized_start=4428 - _globals['_UNLOADLORAREQUEST']._serialized_end=4467 - _globals['_UNLOADLORARESPONSE']._serialized_start=4469 - _globals['_UNLOADLORARESPONSE']._serialized_end=4523 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4525 - _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4644 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4646 - _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4703 - _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4705 - _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4750 - _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4752 - _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4818 - _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4820 - _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4885 - _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4887 - _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4947 - _globals['_SGLANGSCHEDULER']._serialized_start=4950 - _globals['_SGLANGSCHEDULER']._serialized_end=5332 + _globals['_GENERATEREQUEST']._serialized_end=1558 + _globals['_TOKENIZEDINPUT']._serialized_start=1560 + _globals['_TOKENIZEDINPUT']._serialized_end=1618 + _globals['_MULTIMODALINPUTS']._serialized_start=1621 + _globals['_MULTIMODALINPUTS']._serialized_end=1832 + _globals['_GENERATERESPONSE']._serialized_start=1835 + _globals['_GENERATERESPONSE']._serialized_end=2062 + _globals['_GENERATESTREAMCHUNK']._serialized_start=2065 + _globals['_GENERATESTREAMCHUNK']._serialized_end=2327 + _globals['_GENERATECOMPLETE']._serialized_start=2330 + _globals['_GENERATECOMPLETE']._serialized_end=2726 + _globals['_GENERATEERROR']._serialized_start=2728 + _globals['_GENERATEERROR']._serialized_end=2803 + _globals['_OUTPUTLOGPROBS']._serialized_start=2805 + _globals['_OUTPUTLOGPROBS']._serialized_end=2922 + _globals['_INPUTLOGPROBS']._serialized_start=2925 + _globals['_INPUTLOGPROBS']._serialized_end=3083 + _globals['_INPUTTOKENLOGPROB']._serialized_start=3085 + _globals['_INPUTTOKENLOGPROB']._serialized_end=3134 + _globals['_TOPLOGPROBS']._serialized_start=3136 + _globals['_TOPLOGPROBS']._serialized_end=3184 + _globals['_HIDDENSTATES']._serialized_start=3186 + _globals['_HIDDENSTATES']._serialized_end=3249 + _globals['_EMBEDREQUEST']._serialized_start=3252 + _globals['_EMBEDREQUEST']._serialized_end=3582 + _globals['_EMBEDRESPONSE']._serialized_start=3585 + _globals['_EMBEDRESPONSE']._serialized_end=3742 + _globals['_EMBEDCOMPLETE']._serialized_start=3745 + _globals['_EMBEDCOMPLETE']._serialized_end=3908 + _globals['_EMBEDDING']._serialized_start=3910 + _globals['_EMBEDDING']._serialized_end=3952 + _globals['_EMBEDERROR']._serialized_start=3954 + _globals['_EMBEDERROR']._serialized_end=4014 + _globals['_HEALTHCHECKREQUEST']._serialized_start=4016 + _globals['_HEALTHCHECKREQUEST']._serialized_end=4094 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=4096 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=4151 + _globals['_ABORTREQUEST']._serialized_start=4153 + _globals['_ABORTREQUEST']._serialized_end=4203 + _globals['_ABORTRESPONSE']._serialized_start=4205 + _globals['_ABORTRESPONSE']._serialized_end=4254 + _globals['_LOADLORAREQUEST']._serialized_start=4256 + _globals['_LOADLORAREQUEST']._serialized_end=4329 + _globals['_LOADLORARESPONSE']._serialized_start=4331 + _globals['_LOADLORARESPONSE']._serialized_end=4403 + _globals['_UNLOADLORAREQUEST']._serialized_start=4405 + _globals['_UNLOADLORAREQUEST']._serialized_end=4444 + _globals['_UNLOADLORARESPONSE']._serialized_start=4446 + _globals['_UNLOADLORARESPONSE']._serialized_end=4500 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4502 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4621 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4623 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4680 + _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4682 + _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4727 + _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4729 + _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4795 + _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4797 + _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4862 + _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4864 + _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4924 + _globals['_SGLANGSCHEDULER']._serialized_start=4927 + _globals['_SGLANGSCHEDULER']._serialized_end=5309 # @@protoc_insertion_point(module_scope) diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi index 096b8391d..167b68d96 100644 --- a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi @@ -82,7 +82,7 @@ class DisaggregatedParams(_message.Message): def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ... class GenerateRequest(_message.Message): - __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream") + __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream") REQUEST_ID_FIELD_NUMBER: _ClassVar[int] TOKENIZED_FIELD_NUMBER: _ClassVar[int] MM_INPUTS_FIELD_NUMBER: _ClassVar[int] @@ -99,7 +99,6 @@ class GenerateRequest(_message.Message): INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int] LORA_ID_FIELD_NUMBER: _ClassVar[int] DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] - DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int] STREAM_FIELD_NUMBER: _ClassVar[int] request_id: str tokenized: TokenizedInput @@ -117,9 +116,8 @@ class GenerateRequest(_message.Message): input_embeds: _containers.RepeatedScalarFieldContainer[float] lora_id: str data_parallel_rank: int - dp_balance_id: int stream: bool - def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ... + def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ... class TokenizedInput(_message.Message): __slots__ = ("original_text", "input_ids") diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 781f95695..8c5912a0e 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -17,14 +17,11 @@ import faulthandler import logging import multiprocessing as mp import signal -import struct -import sys import threading import time from collections import deque from enum import Enum, auto -from multiprocessing import shared_memory -from typing import Dict, List +from typing import List import psutil import setproctitle @@ -39,7 +36,6 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.managers.utils import DPBalanceMeta from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( @@ -108,15 +104,9 @@ class DPBudget: class DataParallelController: """A controller that dispatches requests to multiple data parallel workers.""" - def __init__( - self, - server_args: ServerArgs, - port_args: PortArgs, - dp_balance_meta: DPBalanceMeta, - ) -> None: + def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: # for dp balance self.global_balance_id = 0 - self.balance_meta = dp_balance_meta # Parse args self.max_total_num_tokens = None @@ -322,7 +312,6 @@ class DataParallelController: pp_rank, dp_rank, writer, - self.balance_meta, ), ) with memory_saver_adapter.configure_subprocess(): @@ -370,31 +359,11 @@ class DataParallelController: if self.maybe_external_dp_rank_routing(req): return - # This variable corresponds to the balance_id in TokenizedGenerateReqInput. - # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received). - def get_next_global_balance_id() -> int: - INT32_MAX = 2147483647 - current_id = self.global_balance_id - self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX - return current_id - - req.dp_balance_id = get_next_global_balance_id() - with self.balance_meta.mutex: - # 1. local_tokens represents the tokens currently inferring on the worker, - # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler. - onfly_info = self.balance_meta.get_shared_onfly() - local_tokens = self.balance_meta.get_shared_local_tokens() - total_tokens = [ - local_token + sum(onfly_dict.values()) - for local_token, onfly_dict in zip(local_tokens, onfly_info) - ] - target_worker = total_tokens.index(min(total_tokens)) - onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids) - # 2. write the new onfly info to the shm - self.balance_meta.set_shared_onfly_info(onfly_info) - - # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}") - self.workers[target_worker].send_pyobj(req) + logger.warning( + "The 'minimum_tokens' load balancing method is deprecated for now and will introduced later." + "Fall back to 'round_robin_scheduler'" + ) + self.round_robin_scheduler(req) def event_loop(self): while True: @@ -416,12 +385,9 @@ def run_data_parallel_controller_process( faulthandler.enable() configure_logger(server_args) parent_process = psutil.Process().parent() - balance_meta = DPBalanceMeta(server_args.dp_size) try: - controller = DataParallelController( - server_args, port_args, dp_balance_meta=balance_meta - ) + controller = DataParallelController(server_args, port_args) pipe_writer.send( { "status": "ready", @@ -440,6 +406,3 @@ def run_data_parallel_controller_process( traceback = get_exception_traceback() logger.error(f"DataParallelController hit an exception: {traceback}") parent_process.send_signal(signal.SIGQUIT) - finally: - # we need to destruct mp.Manager() in balance_meta - balance_meta.destructor() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 436d62f27..791c39399 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -606,9 +606,6 @@ class TokenizedGenerateReqInput: # For data parallel rank routing data_parallel_rank: Optional[int] = None - # For dp balance - dp_balance_id: int = -1 - # Priority for the request priority: Optional[int] = None @@ -778,8 +775,6 @@ class TokenizedEmbeddingReqInput: sampling_params: SamplingParams # For data parallel rank routing data_parallel_rank: Optional[int] = None - # For dp balance - dp_balance_id: int = -1 # Priority for the request priority: Optional[int] = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8fb74093c..36867abf3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -145,7 +145,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import ( from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient -from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length +from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache @@ -271,7 +271,6 @@ class Scheduler( moe_ep_rank: int, pp_rank: int, dp_rank: Optional[int], - dp_balance_meta: Optional[DPBalanceMeta] = None, ): # Parse args self.server_args = server_args @@ -600,7 +599,6 @@ class Scheduler( # Init metrics stats self.init_metrics(tp_rank, pp_rank, dp_rank) - self.init_dp_balance(dp_balance_meta) if self.enable_kv_cache_events: self.init_kv_events(server_args.kv_events_config) @@ -1270,8 +1268,6 @@ class Scheduler( self, recv_req: TokenizedGenerateReqInput, ): - self.maybe_update_dp_balance_data(recv_req) - # Create a new request if ( recv_req.session_params is None @@ -1797,7 +1793,6 @@ class Scheduler( # Handle DP attention if need_dp_attn_preparation: - self.maybe_handle_dp_balance_data() ret = self.prepare_mlp_sync_batch(ret) return ret @@ -2803,7 +2798,6 @@ def run_scheduler_process( pp_rank: int, dp_rank: Optional[int], pipe_writer, - balance_meta: Optional[DPBalanceMeta] = None, ): # Generate the logger prefix prefix = "" @@ -2852,7 +2846,6 @@ def run_scheduler_process( moe_ep_rank, pp_rank, dp_rank, - dp_balance_meta=balance_meta, ) pipe_writer.send( { diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 7e1154dc2..2af5ab5ab 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.schedule_policy import PrefillAdder from sglang.srt.managers.scheduler import Req, ScheduleBatch -from sglang.srt.managers.utils import DPBalanceMeta from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.utils import get_bool_env_var @@ -64,16 +63,6 @@ class SchedulerMetricsMixin: labels["dp_rank"] = dp_rank self.metrics_collector = SchedulerMetricsCollector(labels=labels) - def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]): - self.balance_meta = dp_balance_meta - if ( - self.server_args.enable_dp_attention - and self.server_args.load_balance_method == "minimum_tokens" - ): - assert dp_balance_meta is not None - - self.recv_dp_balance_id_this_term = [] - def init_kv_events(self: Scheduler, kv_events_config: Optional[str]): if self.enable_kv_cache_events: self.kv_event_publisher = EventPublisherFactory.create( @@ -319,91 +308,6 @@ class SchedulerMetricsMixin: batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch) - def maybe_update_dp_balance_data( - self: Scheduler, recv_req: TokenizedGenerateReqInput - ): - if ( - self.server_args.enable_dp_attention - and self.server_args.load_balance_method == "minimum_tokens" - ): - self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id) - - def maybe_handle_dp_balance_data(self: Scheduler): - if ( - self.server_args.load_balance_method == "minimum_tokens" - and self.forward_ct % 40 == 0 - ): - holding_tokens = self.get_load().num_tokens - - new_recv_dp_balance_id_list, holding_token_list = ( - self.gather_dp_balance_info(holding_tokens) - ) - - self.recv_dp_balance_id_this_term.clear() - if self.tp_rank == 0: # only first worker write info - self.write_shared_dp_balance_info( - new_recv_dp_balance_id_list, holding_token_list - ) - - def gather_dp_balance_info( - self: Scheduler, holding_tokens_list - ) -> Union[None, List[List[int]]]: - """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance""" - recv_list = self.recv_dp_balance_id_this_term - assert len(recv_list) <= 511, ( - "The number of requests received this round is too large. " - "Please increase gather_tensor_size and onfly_info_size." - ) - # The maximum size of the tensor used for gathering data from all workers. - gather_tensor_size = 512 - - # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids - recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32) - recv_tensor[0] = holding_tokens_list - recv_tensor[1] = len(recv_list) # The first element is the length of the list. - recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32) - - if self.tp_rank == 0: - gathered_list = [ - torch.zeros(gather_tensor_size, dtype=torch.int32) - for _ in range(self.balance_meta.num_workers) - ] - else: - gathered_list = None - - torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group) - - gathered_id_list_per_worker = None - if self.tp_rank == 0: - gathered_id_list_per_worker = [] - holding_tokens_list = [] - for tensor in gathered_list: - holding_tokens_list.append(tensor[0].item()) - list_length = tensor[1].item() - gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist()) - - return gathered_id_list_per_worker, holding_tokens_list - - def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens): - meta = self.balance_meta - - with meta.mutex: - onfly_list: List[Dict[int, int]] = meta.get_shared_onfly() - assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal" - # 1.Check if the rid received by each worker this round is present in onfly. - # If it is, remove the corresponding onfly item. - worker_id = 0 - for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list): - for new_recv_rid in new_recv_rids: - assert ( - new_recv_rid in on_fly_reqs - ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong" - del on_fly_reqs[new_recv_rid] - worker_id += 1 - # 2. Atomically write local_tokens and onfly into shm under the mutex - meta.set_shared_onfly_info(onfly_list) - meta.set_shared_local_tokens(local_tokens) - def calculate_utilization(self): if self.disaggregation_mode == DisaggregationMode.PREFILL: self.stats.utilization = -1 diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index e25f2628f..ccd3f0fe2 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -96,46 +96,3 @@ def get_logprob_from_pp_outputs( ] return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req - - -class DPBalanceMeta: - """ - This class will be use in scheduler and dp controller - """ - - def __init__(self, num_workers: int): - self.num_workers = num_workers - self._manager = mp.Manager() - self.mutex = self._manager.Lock() - - init_local_tokens = [0] * self.num_workers - init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)] - - self.shared_state = self._manager.Namespace() - self.shared_state.local_tokens = self._manager.list(init_local_tokens) - self.shared_state.onfly_info = self._manager.list(init_onfly_info) - - def destructor(self): - # we must destructor this class manually - self._manager.shutdown() - - def get_shared_onfly(self) -> List[Dict[int, int]]: - return [dict(d) for d in self.shared_state.onfly_info] - - def set_shared_onfly_info(self, data: List[Dict[int, int]]): - self.shared_state.onfly_info = data - - def get_shared_local_tokens(self) -> List[int]: - return list(self.shared_state.local_tokens) - - def set_shared_local_tokens(self, data: List[int]): - self.shared_state.local_tokens = data - - def __getstate__(self): - state = self.__dict__.copy() - del state["_manager"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._manager = None diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto index ee0135b1f..1d52e54c5 100644 --- a/sgl-router/src/proto/sglang_scheduler.proto +++ b/sgl-router/src/proto/sglang_scheduler.proto @@ -120,11 +120,8 @@ message GenerateRequest { // Data parallel routing int32 data_parallel_rank = 16; - // For load balancing - int32 dp_balance_id = 17; - // Whether client wants streaming response - bool stream = 18; + bool stream = 17; } message TokenizedInput { diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py index 4486dc16e..426616456 100644 --- a/test/srt/test_dp_attention.py +++ b/test/srt/test_dp_attention.py @@ -124,47 +124,5 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase): self.assertGreater(avg_spec_accept_length, 2.5) -class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "2", - "--enable-dp-attention", - "--dp", - "2", - "--enable-torch-compile", - "--torch-compile-max-bs", - "2", - "--load-balance-method", - "minimum_tokens", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) - - if __name__ == "__main__": unittest.main()