router: Support parallel sampling num > 1 in grpc_server and non-stream handling (#10929)
This commit is contained in:
@@ -103,6 +103,7 @@ impl SglangSchedulerClient {
|
||||
logprob_start_len: -1,
|
||||
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
|
||||
return_hidden_states: body.return_hidden_states,
|
||||
stream: body.stream,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -367,14 +368,14 @@ mod tests {
|
||||
#[test]
|
||||
fn test_generate_stream_chunk() {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
token_id: 1234,
|
||||
token_ids: vec![1234, 5678],
|
||||
prompt_tokens: 5,
|
||||
completion_tokens: 2,
|
||||
cached_tokens: 3,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(chunk.token_id, 1234);
|
||||
assert_eq!(chunk.token_ids, vec![1234, 5678]);
|
||||
assert_eq!(chunk.prompt_tokens, 5);
|
||||
assert_eq!(chunk.completion_tokens, 2);
|
||||
assert_eq!(chunk.cached_tokens, 3);
|
||||
|
||||
@@ -122,6 +122,9 @@ message GenerateRequest {
|
||||
|
||||
// For load balancing
|
||||
int32 dp_balance_id = 17;
|
||||
|
||||
// Whether client wants streaming response
|
||||
bool stream = 18;
|
||||
}
|
||||
|
||||
message TokenizedInput {
|
||||
@@ -163,8 +166,8 @@ message GenerateResponse {
|
||||
}
|
||||
|
||||
message GenerateStreamChunk {
|
||||
// Generated token
|
||||
int32 token_id = 1;
|
||||
// Generated tokens (incremental chunk)
|
||||
repeated int32 token_ids = 1;
|
||||
|
||||
// Cumulative counts
|
||||
int32 prompt_tokens = 2;
|
||||
|
||||
@@ -203,6 +203,7 @@ impl GrpcRouter {
|
||||
debug!("Selected worker: {}", worker.url());
|
||||
|
||||
// Step 2: Get gRPC client for worker (fail fast if can't connect)
|
||||
// TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here)
|
||||
let client = match self.get_or_create_grpc_client(worker.url()).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
@@ -249,7 +250,7 @@ impl GrpcRouter {
|
||||
|
||||
// Step 6: Build the base gRPC request
|
||||
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let base_request = match client.build_generate_request(
|
||||
let request = match client.build_generate_request(
|
||||
request_id,
|
||||
body,
|
||||
processed_messages.text.clone(),
|
||||
@@ -268,11 +269,11 @@ impl GrpcRouter {
|
||||
}
|
||||
};
|
||||
|
||||
// Step 7: Handle streaming vs non-streaming
|
||||
if body.stream {
|
||||
self.handle_streaming_chat(client, base_request, body).await
|
||||
self.handle_streaming_chat(client, request, body).await
|
||||
} else {
|
||||
self.handle_non_streaming_chat(client, base_request, body)
|
||||
.await
|
||||
self.handle_non_streaming_chat(client, request, body).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user