[router][grpc] Fix proto3 default value mismatches and cleanup unused fields (#11283)
This commit is contained in:
@@ -14,6 +14,7 @@ from concurrent import futures
|
||||
from typing import AsyncIterator, Dict, Optional, Tuple
|
||||
|
||||
import grpc
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||
@@ -483,28 +484,52 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
elif grpc_params.HasField("structural_tag"):
|
||||
structural_tag = grpc_params.structural_tag
|
||||
|
||||
# Handle optional parameters conversion
|
||||
custom_params = (
|
||||
MessageToDict(grpc_params.custom_params)
|
||||
if grpc_params.HasField("custom_params")
|
||||
else None
|
||||
)
|
||||
max_new_tokens = (
|
||||
grpc_params.max_new_tokens
|
||||
if grpc_params.HasField("max_new_tokens")
|
||||
else None
|
||||
)
|
||||
stream_interval = (
|
||||
grpc_params.stream_interval
|
||||
if grpc_params.HasField("stream_interval")
|
||||
else None
|
||||
)
|
||||
logit_bias = dict(grpc_params.logit_bias) if grpc_params.logit_bias else None
|
||||
stop = list(grpc_params.stop) if grpc_params.stop else None
|
||||
stop_token_ids = (
|
||||
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
|
||||
)
|
||||
|
||||
return SGLSamplingParams(
|
||||
temperature=grpc_params.temperature or 1.0,
|
||||
top_p=grpc_params.top_p or 1.0,
|
||||
top_k=grpc_params.top_k or -1,
|
||||
min_p=grpc_params.min_p or 0.0,
|
||||
frequency_penalty=grpc_params.frequency_penalty or 0.0,
|
||||
presence_penalty=grpc_params.presence_penalty or 0.0,
|
||||
repetition_penalty=grpc_params.repetition_penalty or 1.0,
|
||||
max_new_tokens=grpc_params.max_new_tokens or 128,
|
||||
min_new_tokens=grpc_params.min_new_tokens or 0,
|
||||
stop=list(grpc_params.stop) if grpc_params.stop else [],
|
||||
stop_token_ids=(
|
||||
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
|
||||
),
|
||||
temperature=grpc_params.temperature,
|
||||
top_p=grpc_params.top_p,
|
||||
top_k=grpc_params.top_k,
|
||||
min_p=grpc_params.min_p,
|
||||
frequency_penalty=grpc_params.frequency_penalty,
|
||||
presence_penalty=grpc_params.presence_penalty,
|
||||
repetition_penalty=grpc_params.repetition_penalty,
|
||||
max_new_tokens=max_new_tokens,
|
||||
min_new_tokens=grpc_params.min_new_tokens,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
skip_special_tokens=grpc_params.skip_special_tokens,
|
||||
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
|
||||
no_stop_trim=grpc_params.no_stop_trim,
|
||||
regex=regex,
|
||||
json_schema=json_schema,
|
||||
ebnf=ebnf_grammar,
|
||||
structural_tag=structural_tag,
|
||||
n=grpc_params.n or 1,
|
||||
n=grpc_params.n,
|
||||
ignore_eos=grpc_params.ignore_eos,
|
||||
stream_interval=stream_interval,
|
||||
logit_bias=logit_bias,
|
||||
custom_params=custom_params,
|
||||
)
|
||||
|
||||
def _convert_output_logprobs_to_proto(
|
||||
|
||||
@@ -27,6 +27,11 @@ service SglangScheduler {
|
||||
// =====================
|
||||
|
||||
// Sampling parameters matching SGLang's SamplingParams
|
||||
//
|
||||
// IMPORTANT: Do not use SamplingParams::default() directly!
|
||||
// The proto3 defaults (0 for numeric fields) do NOT match the semantic defaults
|
||||
// (temperature=1.0, top_p=1.0, top_k=-1, etc.). Always construct with explicit values
|
||||
// or use the conversion functions in sglang_scheduler.rs / grpc_server.py.
|
||||
message SamplingParams {
|
||||
float temperature = 1;
|
||||
float top_p = 2;
|
||||
@@ -50,24 +55,18 @@ message SamplingParams {
|
||||
string structural_tag = 16;
|
||||
}
|
||||
|
||||
// LoRA adapter
|
||||
string lora_path = 17;
|
||||
|
||||
// Speculative decoding
|
||||
int32 n = 18; // Number of samples
|
||||
|
||||
// Token healing
|
||||
bool token_healing = 19;
|
||||
int32 n = 17; // Number of samples
|
||||
|
||||
// Additional parameters
|
||||
int32 min_new_tokens = 20;
|
||||
bool ignore_eos = 21;
|
||||
bool no_stop_trim = 22;
|
||||
int32 stream_interval = 23;
|
||||
map<string, float> logit_bias = 24;
|
||||
int32 min_new_tokens = 18;
|
||||
bool ignore_eos = 19;
|
||||
bool no_stop_trim = 20;
|
||||
optional int32 stream_interval = 21;
|
||||
map<string, float> logit_bias = 22;
|
||||
|
||||
// Custom parameters for extensibility
|
||||
google.protobuf.Struct custom_params = 25;
|
||||
google.protobuf.Struct custom_params = 23;
|
||||
}
|
||||
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -11,7 +11,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
class SamplingParams(_message.Message):
|
||||
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
|
||||
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "n", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
|
||||
class LogitBiasEntry(_message.Message):
|
||||
__slots__ = ("key", "value")
|
||||
KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
@@ -35,9 +35,7 @@ class SamplingParams(_message.Message):
|
||||
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
|
||||
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
|
||||
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
|
||||
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
|
||||
N_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
|
||||
MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
|
||||
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
|
||||
@@ -60,16 +58,14 @@ class SamplingParams(_message.Message):
|
||||
json_schema: str
|
||||
ebnf_grammar: str
|
||||
structural_tag: str
|
||||
lora_path: str
|
||||
n: int
|
||||
token_healing: bool
|
||||
min_new_tokens: int
|
||||
ignore_eos: bool
|
||||
no_stop_trim: bool
|
||||
stream_interval: int
|
||||
logit_bias: _containers.ScalarMap[str, float]
|
||||
custom_params: _struct_pb2.Struct
|
||||
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
|
||||
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., n: _Optional[int] = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
|
||||
|
||||
class DisaggregatedParams(_message.Message):
|
||||
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
|
||||
|
||||
Reference in New Issue
Block a user