[router] grpc router generate endpoint support (#11070)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -1,8 +1,12 @@
|
|||||||
|
use std::convert::TryFrom;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tonic::{transport::Channel, Request};
|
use tonic::{transport::Channel, Request};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat};
|
use crate::protocols::spec::{
|
||||||
|
ChatCompletionRequest, GenerateRequest, ResponseFormat,
|
||||||
|
SamplingParams as GenerateSamplingParams, StringOrArray,
|
||||||
|
};
|
||||||
|
|
||||||
// Include the generated protobuf code
|
// Include the generated protobuf code
|
||||||
pub mod proto {
|
pub mod proto {
|
||||||
@@ -112,6 +116,37 @@ impl SglangSchedulerClient {
|
|||||||
Ok(grpc_request)
|
Ok(grpc_request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a basic GenerateRequest from the SGLang spec GenerateRequest
|
||||||
|
pub fn build_plain_generate_request(
|
||||||
|
&self,
|
||||||
|
request_id: String,
|
||||||
|
body: &GenerateRequest,
|
||||||
|
original_text: Option<String>,
|
||||||
|
token_ids: Vec<u32>,
|
||||||
|
) -> Result<proto::GenerateRequest, String> {
|
||||||
|
let sampling_params =
|
||||||
|
Self::build_sampling_params_from_plain(body.sampling_params.as_ref())?;
|
||||||
|
|
||||||
|
let grpc_request = proto::GenerateRequest {
|
||||||
|
request_id,
|
||||||
|
tokenized: Some(proto::TokenizedInput {
|
||||||
|
original_text: original_text.unwrap_or_default(),
|
||||||
|
input_ids: token_ids,
|
||||||
|
}),
|
||||||
|
sampling_params: Some(sampling_params),
|
||||||
|
return_logprob: body.return_logprob,
|
||||||
|
logprob_start_len: -1,
|
||||||
|
top_logprobs_num: 0,
|
||||||
|
token_ids_logprob: vec![],
|
||||||
|
return_hidden_states: body.return_hidden_states,
|
||||||
|
stream: body.stream,
|
||||||
|
log_metrics: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(grpc_request)
|
||||||
|
}
|
||||||
|
|
||||||
/// Build gRPC SamplingParams from OpenAI request
|
/// Build gRPC SamplingParams from OpenAI request
|
||||||
fn build_grpc_sampling_params(
|
fn build_grpc_sampling_params(
|
||||||
&self,
|
&self,
|
||||||
@@ -165,8 +200,8 @@ impl SglangSchedulerClient {
|
|||||||
/// Extract stop strings from request
|
/// Extract stop strings from request
|
||||||
fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
|
fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
|
||||||
match &request.stop {
|
match &request.stop {
|
||||||
Some(crate::protocols::spec::StringOrArray::String(s)) => vec![s.clone()],
|
Some(StringOrArray::String(s)) => vec![s.clone()],
|
||||||
Some(crate::protocols::spec::StringOrArray::Array(arr)) => arr.clone(),
|
Some(StringOrArray::Array(arr)) => arr.clone(),
|
||||||
None => vec![],
|
None => vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -218,6 +253,100 @@ impl SglangSchedulerClient {
|
|||||||
_ => Err("Multiple constraints are not allowed.".to_string()),
|
_ => Err("Multiple constraints are not allowed.".to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_single_constraint_from_plain(
|
||||||
|
params: &GenerateSamplingParams,
|
||||||
|
) -> Result<Option<proto::sampling_params::Constraint>, String> {
|
||||||
|
let mut constraints = Vec::new();
|
||||||
|
if let Some(json_schema) = ¶ms.json_schema {
|
||||||
|
constraints.push(proto::sampling_params::Constraint::JsonSchema(
|
||||||
|
json_schema.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if let Some(regex) = ¶ms.regex {
|
||||||
|
constraints.push(proto::sampling_params::Constraint::Regex(regex.clone()));
|
||||||
|
}
|
||||||
|
if let Some(ebnf) = ¶ms.ebnf {
|
||||||
|
constraints.push(proto::sampling_params::Constraint::EbnfGrammar(
|
||||||
|
ebnf.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
match constraints.len() {
|
||||||
|
0 => Ok(None),
|
||||||
|
1 => Ok(constraints.pop()),
|
||||||
|
_ => Err("Multiple structured constraints are not allowed".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_sampling_params_from_plain(
|
||||||
|
params: Option<&GenerateSamplingParams>,
|
||||||
|
) -> Result<proto::SamplingParams, String> {
|
||||||
|
let mut sampling = proto::SamplingParams {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_p: 1.0,
|
||||||
|
top_k: -1,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
n: 1,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(p) = params else {
|
||||||
|
return Ok(sampling);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Simple field mappings using a macro
|
||||||
|
macro_rules! map_field {
|
||||||
|
($field:ident) => {
|
||||||
|
if let Some(val) = p.$field {
|
||||||
|
sampling.$field = val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
map_field!(temperature);
|
||||||
|
map_field!(top_p);
|
||||||
|
map_field!(top_k);
|
||||||
|
map_field!(frequency_penalty);
|
||||||
|
map_field!(presence_penalty);
|
||||||
|
map_field!(repetition_penalty);
|
||||||
|
map_field!(min_p);
|
||||||
|
map_field!(ignore_eos);
|
||||||
|
map_field!(skip_special_tokens);
|
||||||
|
map_field!(no_stop_trim);
|
||||||
|
|
||||||
|
// Handle stop sequences
|
||||||
|
if let Some(stop) = &p.stop {
|
||||||
|
match stop {
|
||||||
|
StringOrArray::String(s) => sampling.stop.push(s.clone()),
|
||||||
|
StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle stop token IDs
|
||||||
|
if let Some(stop_token_ids) = &p.stop_token_ids {
|
||||||
|
sampling.stop_token_ids = stop_token_ids.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle max_new_tokens with conversion
|
||||||
|
if let Some(max_new_tokens) = p.max_new_tokens {
|
||||||
|
sampling.max_new_tokens =
|
||||||
|
Some(i32::try_from(max_new_tokens).map_err(|_| {
|
||||||
|
"max_new_tokens must fit into a 32-bit signed integer".to_string()
|
||||||
|
})?);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle min_tokens with conversion
|
||||||
|
if let Some(min_tokens) = p.min_tokens {
|
||||||
|
sampling.min_new_tokens = i32::try_from(min_tokens)
|
||||||
|
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle constraints (exactly one allowed)
|
||||||
|
sampling.constraint = Self::build_single_constraint_from_plain(p)?;
|
||||||
|
|
||||||
|
Ok(sampling)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -27,12 +27,15 @@ use crate::reasoning_parser::ParserFactory;
|
|||||||
use crate::routers::RouterTrait;
|
use crate::routers::RouterTrait;
|
||||||
use crate::server::AppContext;
|
use crate::server::AppContext;
|
||||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||||
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder};
|
use crate::tokenizer::stop::{
|
||||||
|
SequenceDecoderOutput, StopSequenceDecoder, StopSequenceDecoderBuilder,
|
||||||
|
};
|
||||||
use crate::tokenizer::traits::Tokenizer;
|
use crate::tokenizer::traits::Tokenizer;
|
||||||
use crate::tokenizer::HuggingFaceTokenizer;
|
use crate::tokenizer::HuggingFaceTokenizer;
|
||||||
use crate::tool_parser::ParserRegistry;
|
use crate::tool_parser::ParserRegistry;
|
||||||
use serde_json::Value;
|
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use serde_json::{json, Value};
|
||||||
|
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -124,28 +127,9 @@ impl GrpcRouter {
|
|||||||
debug!("Selected worker: {}", worker.url());
|
debug!("Selected worker: {}", worker.url());
|
||||||
|
|
||||||
// Step 2: Get gRPC client from worker
|
// Step 2: Get gRPC client from worker
|
||||||
let client = match worker.get_grpc_client().await {
|
let client = match Self::get_grpc_client_from_worker(&worker).await {
|
||||||
Ok(Some(client_arc)) => {
|
Ok(client) => client,
|
||||||
// Clone the client from inside the Arc<Mutex<>>
|
Err(response) => return response,
|
||||||
let client = client_arc.lock().await.clone();
|
|
||||||
client
|
|
||||||
}
|
|
||||||
Ok(None) => {
|
|
||||||
error!("Selected worker is not a gRPC worker");
|
|
||||||
return (
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Selected worker is not configured for gRPC",
|
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to get gRPC client from worker: {}", e);
|
|
||||||
return (
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to get gRPC client: {}", e),
|
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Step 3: Process messages and apply chat template
|
// Step 3: Process messages and apply chat template
|
||||||
@@ -209,6 +193,112 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Main route_generate implementation
|
||||||
|
async fn route_generate_impl(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
body: &GenerateRequest,
|
||||||
|
model_id: Option<&str>,
|
||||||
|
) -> Response {
|
||||||
|
debug!("Processing generate request for model: {:?}", model_id);
|
||||||
|
|
||||||
|
// Step 1: Resolve input (text, prompt, or input_ids)
|
||||||
|
let (original_text, token_ids) = match self.resolve_generate_input(body) {
|
||||||
|
Ok(res) => res,
|
||||||
|
Err(msg) => {
|
||||||
|
error!("Invalid generate request: {}", msg);
|
||||||
|
return (StatusCode::BAD_REQUEST, msg).into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("Resolved input with {} tokens", token_ids.len());
|
||||||
|
|
||||||
|
// Step 2: Select worker (fail fast if no workers available)
|
||||||
|
let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) {
|
||||||
|
Some(w) => w,
|
||||||
|
None => {
|
||||||
|
warn!("No available workers for model: {:?}", model_id);
|
||||||
|
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("Selected worker: {}", worker.url());
|
||||||
|
|
||||||
|
// Step 3: Get gRPC client from worker
|
||||||
|
let client = match Self::get_grpc_client_from_worker(&worker).await {
|
||||||
|
Ok(client) => client,
|
||||||
|
Err(response) => return response,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Step 4: Build the gRPC request
|
||||||
|
let request_id = body
|
||||||
|
.rid
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
|
||||||
|
|
||||||
|
let request = match client.build_plain_generate_request(
|
||||||
|
request_id.clone(),
|
||||||
|
body,
|
||||||
|
original_text.clone(),
|
||||||
|
token_ids,
|
||||||
|
) {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to build generate request: {}", e);
|
||||||
|
return (StatusCode::BAD_REQUEST, e).into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Step 5: Get weight version for response metadata
|
||||||
|
let weight_version = worker
|
||||||
|
.metadata()
|
||||||
|
.labels
|
||||||
|
.get("weight_version")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "default".to_string());
|
||||||
|
|
||||||
|
// Step 6: Handle streaming vs non-streaming
|
||||||
|
if body.stream {
|
||||||
|
// TODO: Implement streaming support for generate endpoint
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Streaming generate over gRPC is not supported yet",
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
self.handle_non_streaming_generate(client, request, body, request_id, weight_version)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get gRPC client from worker, returning appropriate error response on failure
|
||||||
|
async fn get_grpc_client_from_worker(
|
||||||
|
worker: &Arc<dyn Worker>,
|
||||||
|
) -> Result<SglangSchedulerClient, Response> {
|
||||||
|
let client_arc = worker
|
||||||
|
.get_grpc_client()
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("Failed to get gRPC client from worker: {}", e);
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to get gRPC client: {}", e),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
})?
|
||||||
|
.ok_or_else(|| {
|
||||||
|
error!("Selected worker is not a gRPC worker");
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"Selected worker is not configured for gRPC",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let client = client_arc.lock().await.clone();
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
/// Select a worker for the request
|
/// Select a worker for the request
|
||||||
fn select_worker_for_request(
|
fn select_worker_for_request(
|
||||||
&self,
|
&self,
|
||||||
@@ -265,7 +355,7 @@ impl GrpcRouter {
|
|||||||
Self::process_tool_call_arguments(&mut transformed_messages)?;
|
Self::process_tool_call_arguments(&mut transformed_messages)?;
|
||||||
|
|
||||||
// Convert tools to JSON values for template processing
|
// Convert tools to JSON values for template processing
|
||||||
let tools_json: Option<Vec<serde_json::Value>> = request
|
let tools_json: Option<Vec<Value>> = request
|
||||||
.tools
|
.tools
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|tools| {
|
.map(|tools| {
|
||||||
@@ -284,7 +374,7 @@ impl GrpcRouter {
|
|||||||
if let Some(reasoning_effort) = &request.reasoning_effort {
|
if let Some(reasoning_effort) = &request.reasoning_effort {
|
||||||
combined_template_kwargs.insert(
|
combined_template_kwargs.insert(
|
||||||
"reasoning_effort".to_string(),
|
"reasoning_effort".to_string(),
|
||||||
serde_json::Value::String(reasoning_effort.clone()),
|
Value::String(reasoning_effort.clone()),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -413,9 +503,9 @@ impl GrpcRouter {
|
|||||||
part.as_object()
|
part.as_object()
|
||||||
.and_then(|obj| obj.get("type")?.as_str())
|
.and_then(|obj| obj.get("type")?.as_str())
|
||||||
.and_then(|type_str| match type_str {
|
.and_then(|type_str| match type_str {
|
||||||
"image_url" => Some(serde_json::json!({"type": "image"})),
|
"image_url" => Some(json!({"type": "image"})),
|
||||||
"video_url" => Some(serde_json::json!({"type": "video"})),
|
"video_url" => Some(json!({"type": "video"})),
|
||||||
"audio_url" => Some(serde_json::json!({"type": "audio"})),
|
"audio_url" => Some(json!({"type": "audio"})),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|| part.clone())
|
.unwrap_or_else(|| part.clone())
|
||||||
@@ -456,7 +546,7 @@ impl GrpcRouter {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Parse JSON string to object (like Python json.loads)
|
// Parse JSON string to object (like Python json.loads)
|
||||||
match serde_json::from_str::<serde_json::Value>(args_str) {
|
match serde_json::from_str::<Value>(args_str) {
|
||||||
Ok(parsed) => *args = parsed,
|
Ok(parsed) => *args = parsed,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
@@ -483,13 +573,63 @@ impl GrpcRouter {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a StopSequenceDecoder from the chat completion request
|
/// Resolve the generate input into optional original text and token IDs
|
||||||
|
fn resolve_generate_input(
|
||||||
|
&self,
|
||||||
|
request: &GenerateRequest,
|
||||||
|
) -> Result<(Option<String>, Vec<u32>), String> {
|
||||||
|
if let Some(text) = &request.text {
|
||||||
|
return self
|
||||||
|
.tokenize_single_text(text)
|
||||||
|
.map(|(original, ids)| (Some(original), ids));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle input_ids - validate and convert
|
||||||
|
if let Some(input_ids) = &request.input_ids {
|
||||||
|
return match input_ids {
|
||||||
|
crate::protocols::spec::InputIds::Single(ids) => ids
|
||||||
|
.iter()
|
||||||
|
.map(|&id| u32::try_from(id))
|
||||||
|
.collect::<Result<Vec<u32>, _>>()
|
||||||
|
.map(|converted| (None, converted))
|
||||||
|
.map_err(|_| "input_ids must be non-negative".to_string()),
|
||||||
|
crate::protocols::spec::InputIds::Batch(_) => {
|
||||||
|
Err("Batch input_ids are not supported over gRPC generate yet".to_string())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
Err("Either `text` or `input_ids` must be provided".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tokenize_single_text(&self, text: &str) -> Result<(String, Vec<u32>), String> {
|
||||||
|
let encoding = self
|
||||||
|
.tokenizer
|
||||||
|
.encode(text)
|
||||||
|
.map_err(|e| format!("Tokenization failed: {}", e))?;
|
||||||
|
Ok((text.to_string(), encoding.token_ids().to_vec()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn internal_error_static(msg: &'static str) -> Response {
|
||||||
|
error!("{}", msg);
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn internal_error_message(message: String) -> Response {
|
||||||
|
error!("{}", message);
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a StopSequenceDecoder from stop parameters
|
||||||
fn create_stop_decoder(
|
fn create_stop_decoder(
|
||||||
&self,
|
&self,
|
||||||
original_request: &ChatCompletionRequest,
|
stop: Option<&StringOrArray>,
|
||||||
) -> crate::tokenizer::stop::StopSequenceDecoder {
|
stop_token_ids: Option<&Vec<u32>>,
|
||||||
// Extract stop sequences from request
|
skip_special_tokens: bool,
|
||||||
let stop_sequences: Vec<String> = match &original_request.stop {
|
no_stop_trim: bool,
|
||||||
|
) -> StopSequenceDecoder {
|
||||||
|
// Extract stop sequences
|
||||||
|
let stop_sequences: Vec<String> = match stop {
|
||||||
Some(StringOrArray::String(s)) => vec![s.clone()],
|
Some(StringOrArray::String(s)) => vec![s.clone()],
|
||||||
Some(StringOrArray::Array(arr)) => arr.clone(),
|
Some(StringOrArray::Array(arr)) => arr.clone(),
|
||||||
None => vec![],
|
None => vec![],
|
||||||
@@ -497,11 +637,11 @@ impl GrpcRouter {
|
|||||||
|
|
||||||
// Build stop sequence decoder
|
// Build stop sequence decoder
|
||||||
let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone())
|
let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone())
|
||||||
.skip_special_tokens(original_request.skip_special_tokens);
|
.skip_special_tokens(skip_special_tokens);
|
||||||
|
|
||||||
// Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
|
// Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
|
||||||
for seq in stop_sequences {
|
for seq in stop_sequences {
|
||||||
builder = if original_request.no_stop_trim {
|
builder = if no_stop_trim {
|
||||||
builder.visible_stop_sequence(seq)
|
builder.visible_stop_sequence(seq)
|
||||||
} else {
|
} else {
|
||||||
builder.stop_sequence(seq)
|
builder.stop_sequence(seq)
|
||||||
@@ -509,9 +649,9 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
|
// Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
|
||||||
if let Some(stop_token_ids) = &original_request.stop_token_ids {
|
if let Some(token_ids) = stop_token_ids {
|
||||||
for &token_id in stop_token_ids {
|
for &token_id in token_ids {
|
||||||
builder = if original_request.no_stop_trim {
|
builder = if no_stop_trim {
|
||||||
builder.visible_stop_token(token_id)
|
builder.visible_stop_token(token_id)
|
||||||
} else {
|
} else {
|
||||||
builder.stop_token(token_id)
|
builder.stop_token(token_id)
|
||||||
@@ -524,7 +664,7 @@ impl GrpcRouter {
|
|||||||
|
|
||||||
/// Process a chunk of tokens through the stop decoder
|
/// Process a chunk of tokens through the stop decoder
|
||||||
fn process_chunk_tokens(
|
fn process_chunk_tokens(
|
||||||
stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
|
stop_decoder: &mut StopSequenceDecoder,
|
||||||
token_ids: &[u32],
|
token_ids: &[u32],
|
||||||
) -> (String, bool) {
|
) -> (String, bool) {
|
||||||
let mut chunk_text = String::new();
|
let mut chunk_text = String::new();
|
||||||
@@ -562,7 +702,12 @@ impl GrpcRouter {
|
|||||||
request: proto::GenerateRequest,
|
request: proto::GenerateRequest,
|
||||||
original_request: &ChatCompletionRequest,
|
original_request: &ChatCompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let mut stop_decoder = self.create_stop_decoder(original_request);
|
let mut stop_decoder = self.create_stop_decoder(
|
||||||
|
original_request.stop.as_ref(),
|
||||||
|
original_request.stop_token_ids.as_ref(),
|
||||||
|
original_request.skip_special_tokens,
|
||||||
|
original_request.no_stop_trim,
|
||||||
|
);
|
||||||
|
|
||||||
// Process streaming tokens
|
// Process streaming tokens
|
||||||
let mut grpc_stream = match client.generate(request).await {
|
let mut grpc_stream = match client.generate(request).await {
|
||||||
@@ -589,7 +734,7 @@ impl GrpcRouter {
|
|||||||
};
|
};
|
||||||
|
|
||||||
match gen_response.response {
|
match gen_response.response {
|
||||||
Some(proto::generate_response::Response::Chunk(chunk)) => {
|
Some(Chunk(chunk)) => {
|
||||||
// Process tokens and check if we should stop
|
// Process tokens and check if we should stop
|
||||||
let (chunk_text, should_stop) =
|
let (chunk_text, should_stop) =
|
||||||
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
|
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
|
||||||
@@ -599,7 +744,7 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Some(proto::generate_response::Response::Complete(_complete)) => {
|
Some(Complete(_complete)) => {
|
||||||
// Flush any remaining text
|
// Flush any remaining text
|
||||||
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
|
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
@@ -609,7 +754,7 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Some(proto::generate_response::Response::Error(error)) => {
|
Some(Error(error)) => {
|
||||||
error!("Generation error: {}", error.message);
|
error!("Generation error: {}", error.message);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -629,26 +774,19 @@ impl GrpcRouter {
|
|||||||
request: proto::GenerateRequest,
|
request: proto::GenerateRequest,
|
||||||
original_request: &ChatCompletionRequest,
|
original_request: &ChatCompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let mut stop_decoder = self.create_stop_decoder(original_request);
|
let mut stop_decoder = self.create_stop_decoder(
|
||||||
|
original_request.stop.as_ref(),
|
||||||
// Small helpers to log + return a uniform 500
|
original_request.stop_token_ids.as_ref(),
|
||||||
let fail_str = |msg: &'static str| -> Response {
|
original_request.skip_special_tokens,
|
||||||
error!("{}", msg);
|
original_request.no_stop_trim,
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
|
);
|
||||||
};
|
|
||||||
let fail_fmt = |prefix: &str, e: &dyn std::fmt::Display| -> Response {
|
|
||||||
error!("{}{}", prefix, e);
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("{}{}", prefix, e),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start generation
|
// Start generation
|
||||||
let mut stream = match client.generate(request).await {
|
let mut stream = match client.generate(request).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => return fail_fmt("Failed to start generation: ", &e),
|
Err(e) => {
|
||||||
|
return Self::internal_error_message(format!("Failed to start generation: {}", e))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Collect all responses (for n>1 support)
|
// Collect all responses (for n>1 support)
|
||||||
@@ -656,28 +794,33 @@ impl GrpcRouter {
|
|||||||
while let Some(response) = stream.next().await {
|
while let Some(response) = stream.next().await {
|
||||||
match response {
|
match response {
|
||||||
Ok(gen_response) => match gen_response.response {
|
Ok(gen_response) => match gen_response.response {
|
||||||
Some(proto::generate_response::Response::Complete(complete)) => {
|
Some(Complete(complete)) => {
|
||||||
all_responses.push(complete);
|
all_responses.push(complete);
|
||||||
}
|
}
|
||||||
Some(proto::generate_response::Response::Error(err)) => {
|
Some(Error(err)) => {
|
||||||
error!("Generation failed for one choice: {}", err.message);
|
return Self::internal_error_message(format!(
|
||||||
return (
|
"Generation failed: {}",
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
err.message
|
||||||
format!("Generation failed: {}", err.message),
|
));
|
||||||
|
}
|
||||||
|
Some(Chunk(_)) => {
|
||||||
|
return Self::internal_error_static(
|
||||||
|
"Unexpected chunk response for non-streaming request",
|
||||||
)
|
)
|
||||||
.into_response();
|
|
||||||
}
|
}
|
||||||
Some(proto::generate_response::Response::Chunk(_)) => {
|
None => return Self::internal_error_static("Empty response from server"),
|
||||||
return fail_str("Unexpected chunk response for non-streaming request")
|
|
||||||
}
|
|
||||||
None => return fail_str("Empty response from server"),
|
|
||||||
},
|
},
|
||||||
Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e),
|
Err(e) => {
|
||||||
|
return Self::internal_error_message(format!(
|
||||||
|
"Failed to get GenerateResponse: {}",
|
||||||
|
e
|
||||||
|
))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if all_responses.is_empty() {
|
if all_responses.is_empty() {
|
||||||
return fail_str("No responses from server");
|
return Self::internal_error_static("No responses from server");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process each response into a ChatChoice
|
// Process each response into a ChatChoice
|
||||||
@@ -689,12 +832,10 @@ impl GrpcRouter {
|
|||||||
{
|
{
|
||||||
Ok(choice) => choices.push(choice),
|
Ok(choice) => choices.push(choice),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to process choice {}: {}", index, e);
|
return Self::internal_error_message(format!(
|
||||||
return (
|
"Failed to process choice {}: {}",
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
index, e
|
||||||
format!("Failed to process choice {}: {}", index, e),
|
));
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -730,6 +871,127 @@ impl GrpcRouter {
|
|||||||
Json(response).into_response()
|
Json(response).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Submit request and handle non-streaming response for the `/generate` endpoint
|
||||||
|
async fn handle_non_streaming_generate(
|
||||||
|
&self,
|
||||||
|
mut client: SglangSchedulerClient,
|
||||||
|
request: proto::GenerateRequest,
|
||||||
|
original_request: &GenerateRequest,
|
||||||
|
request_id: String,
|
||||||
|
weight_version: String,
|
||||||
|
) -> Response {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let mut stream = match client.generate(request).await {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(e) => {
|
||||||
|
return Self::internal_error_message(format!("Failed to start generation: {}", e))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut final_completion: Option<proto::GenerateComplete> = None;
|
||||||
|
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(gen_response) => match gen_response.response {
|
||||||
|
Some(Complete(complete)) => {
|
||||||
|
final_completion = Some(complete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Some(Error(err)) => {
|
||||||
|
return Self::internal_error_message(format!(
|
||||||
|
"Generation failed: {}",
|
||||||
|
err.message
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Some(Chunk(_)) | None => continue,
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
return Self::internal_error_message(format!(
|
||||||
|
"Failed to receive generate response: {}",
|
||||||
|
e
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut complete = match final_completion {
|
||||||
|
Some(c) => c,
|
||||||
|
None => {
|
||||||
|
return Self::internal_error_static("No completion received from scheduler");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create stop decoder from sampling params
|
||||||
|
let params = original_request.sampling_params.as_ref();
|
||||||
|
let mut stop_decoder = self.create_stop_decoder(
|
||||||
|
params.and_then(|p| p.stop.as_ref()),
|
||||||
|
params.and_then(|p| p.stop_token_ids.as_ref()),
|
||||||
|
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
|
||||||
|
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Process tokens through stop decoder
|
||||||
|
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||||
|
Ok(outputs) => outputs,
|
||||||
|
Err(e) => {
|
||||||
|
return Self::internal_error_message(format!("Failed to process tokens: {}", e))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Accumulate text with early breaks
|
||||||
|
let mut decoded_text = String::new();
|
||||||
|
for output in outputs {
|
||||||
|
match output {
|
||||||
|
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
|
||||||
|
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||||
|
decoded_text.push_str(&t);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
SequenceDecoderOutput::Stopped => break,
|
||||||
|
SequenceDecoderOutput::Held => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush remaining text
|
||||||
|
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||||
|
decoded_text.push_str(&t);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_ids = complete.output_ids.clone();
|
||||||
|
|
||||||
|
// Build base meta_info using json! macro
|
||||||
|
let mut meta_info = json!({
|
||||||
|
"finish_reason": complete.finish_reason.clone(),
|
||||||
|
"prompt_tokens": complete.prompt_tokens,
|
||||||
|
"completion_tokens": complete.completion_tokens,
|
||||||
|
"cached_tokens": complete.cached_tokens,
|
||||||
|
"id": request_id,
|
||||||
|
"weight_version": weight_version,
|
||||||
|
"e2e_latency": start_time.elapsed().as_secs_f64(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let meta_obj = meta_info.as_object_mut().unwrap();
|
||||||
|
|
||||||
|
// Add matched_stop if present
|
||||||
|
if let Some(matched) = complete.matched_stop.take() {
|
||||||
|
use proto::generate_complete::MatchedStop;
|
||||||
|
let matched_value = match matched {
|
||||||
|
MatchedStop::MatchedTokenId(id) => json!(id),
|
||||||
|
MatchedStop::MatchedStopStr(s) => json!(s),
|
||||||
|
};
|
||||||
|
meta_obj.insert("matched_stop".to_string(), matched_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response_body = json!({
|
||||||
|
"text": decoded_text,
|
||||||
|
"output_ids": output_ids,
|
||||||
|
"meta_info": meta_info,
|
||||||
|
});
|
||||||
|
|
||||||
|
Json(response_body).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert proto LogProbs to OpenAI ChatLogProbs format
|
/// Convert proto LogProbs to OpenAI ChatLogProbs format
|
||||||
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
|
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
|
||||||
fn convert_proto_to_openai_logprobs(
|
fn convert_proto_to_openai_logprobs(
|
||||||
@@ -803,7 +1065,7 @@ impl GrpcRouter {
|
|||||||
complete: &proto::GenerateComplete,
|
complete: &proto::GenerateComplete,
|
||||||
index: usize,
|
index: usize,
|
||||||
original_request: &ChatCompletionRequest,
|
original_request: &ChatCompletionRequest,
|
||||||
stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
|
stop_decoder: &mut StopSequenceDecoder,
|
||||||
) -> Result<ChatChoice, String> {
|
) -> Result<ChatChoice, String> {
|
||||||
stop_decoder.reset();
|
stop_decoder.reset();
|
||||||
// Decode tokens
|
// Decode tokens
|
||||||
@@ -1002,11 +1264,11 @@ impl RouterTrait for GrpcRouter {
|
|||||||
|
|
||||||
async fn route_generate(
|
async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
_headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
_body: &GenerateRequest,
|
body: &GenerateRequest,
|
||||||
_model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
self.route_generate_impl(headers, body, model_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
|
|||||||
Reference in New Issue
Block a user