router: Support parallel sampling num > 1 in grpc_server and non-stream handling (#10929)

This commit is contained in:
Chang Su
2025-09-25 20:03:35 -07:00
committed by GitHub
parent 3e95aa1a09
commit 37158f2018
8 changed files with 281 additions and 135 deletions

View File

@@ -103,6 +103,7 @@ impl SglangSchedulerClient {
logprob_start_len: -1,
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
return_hidden_states: body.return_hidden_states,
stream: body.stream,
..Default::default()
};
@@ -367,14 +368,14 @@ mod tests {
#[test]
fn test_generate_stream_chunk() {
let chunk = proto::GenerateStreamChunk {
token_id: 1234,
token_ids: vec![1234, 5678],
prompt_tokens: 5,
completion_tokens: 2,
cached_tokens: 3,
..Default::default()
};
assert_eq!(chunk.token_id, 1234);
assert_eq!(chunk.token_ids, vec![1234, 5678]);
assert_eq!(chunk.prompt_tokens, 5);
assert_eq!(chunk.completion_tokens, 2);
assert_eq!(chunk.cached_tokens, 3);

View File

@@ -122,6 +122,9 @@ message GenerateRequest {
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
}
message TokenizedInput {
@@ -163,8 +166,8 @@ message GenerateResponse {
}
message GenerateStreamChunk {
// Generated token
int32 token_id = 1;
// Generated tokens (incremental chunk)
repeated int32 token_ids = 1;
// Cumulative counts
int32 prompt_tokens = 2;

View File

@@ -203,6 +203,7 @@ impl GrpcRouter {
debug!("Selected worker: {}", worker.url());
// Step 2: Get gRPC client for worker (fail fast if can't connect)
// TODO(CatherineSue): manage gRPC connection in worker. (it should be simpler here)
let client = match self.get_or_create_grpc_client(worker.url()).await {
Ok(c) => c,
Err(e) => {
@@ -249,7 +250,7 @@ impl GrpcRouter {
// Step 6: Build the base gRPC request
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let base_request = match client.build_generate_request(
let request = match client.build_generate_request(
request_id,
body,
processed_messages.text.clone(),
@@ -268,11 +269,11 @@ impl GrpcRouter {
}
};
// Step 7: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_chat(client, base_request, body).await
self.handle_streaming_chat(client, request, body).await
} else {
self.handle_non_streaming_chat(client, base_request, body)
.await
self.handle_non_streaming_chat(client, request, body).await
}
}