router: Support parallel sampling num > 1 in grpc_server and non-stream handling (#10929)
This commit is contained in:
@@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -11,6 +12,7 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
@@ -79,11 +81,9 @@ class GrpcReqState:
|
|||||||
last_completion_tokens: int = 1
|
last_completion_tokens: int = 1
|
||||||
|
|
||||||
# Streaming state
|
# Streaming state
|
||||||
last_output_offset: int = 0
|
|
||||||
stream_finished: bool = False
|
stream_finished: bool = False
|
||||||
|
|
||||||
# Output accumulation
|
# Token accumulation (for non-streaming)
|
||||||
text: str = ""
|
|
||||||
output_ids: List[int] = dataclasses.field(default_factory=list)
|
output_ids: List[int] = dataclasses.field(default_factory=list)
|
||||||
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
||||||
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
||||||
@@ -139,8 +139,6 @@ class GrpcRequestManager:
|
|||||||
self.is_pause_cond = asyncio.Condition()
|
self.is_pause_cond = asyncio.Condition()
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
self.request_counter = 0
|
|
||||||
self.request_counter_lock = asyncio.Lock()
|
|
||||||
self.last_receive_tstamp = time.time()
|
self.last_receive_tstamp = time.time()
|
||||||
|
|
||||||
# Crash dump for debugging
|
# Crash dump for debugging
|
||||||
@@ -158,22 +156,133 @@ class GrpcRequestManager:
|
|||||||
obj: TokenizedGenerateReqInput,
|
obj: TokenizedGenerateReqInput,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
||||||
) -> asyncio.Queue:
|
):
|
||||||
"""
|
"""
|
||||||
Submit a generation request to the scheduler.
|
Submit a generation request to the scheduler with n>1 parallel sampling support.
|
||||||
Returns a queue for streaming outputs.
|
|
||||||
|
This method implements the same two-phase approach as tokenizer_manager.py:
|
||||||
|
1. Phase 1: Send prefix caching request (max_new_tokens=0)
|
||||||
|
2. Phase 2: Send n generation requests that reuse the cached prefix
|
||||||
|
|
||||||
|
Yields individual responses for streaming, or aggregated responses for non-streaming.
|
||||||
"""
|
"""
|
||||||
|
n = getattr(obj.sampling_params, "n", 1)
|
||||||
|
|
||||||
|
if n <= 1:
|
||||||
|
async for response in self._handle_single_request(
|
||||||
|
obj, request_id, grpc_context
|
||||||
|
):
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
|
# N>1 handling - two-phase approach
|
||||||
|
logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
|
||||||
|
|
||||||
|
# Generate base request ID if not provided
|
||||||
|
if request_id is None:
|
||||||
|
base_request_id = f"grpc-{uuid.uuid4().hex}"
|
||||||
|
else:
|
||||||
|
base_request_id = request_id
|
||||||
|
|
||||||
|
# Phase 1: Cache the common prefix
|
||||||
|
logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
|
||||||
|
prefix_obj = copy.copy(obj)
|
||||||
|
prefix_obj.sampling_params = copy.copy(obj.sampling_params)
|
||||||
|
prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
|
||||||
|
prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
|
||||||
|
|
||||||
|
# Send prefix caching request and consume response
|
||||||
|
async for _ in self._handle_single_request(
|
||||||
|
prefix_obj, f"{base_request_id}-prefix", grpc_context
|
||||||
|
):
|
||||||
|
# Consume prefix response (usually just one chunk with finish_reason)
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
|
||||||
|
|
||||||
|
# Phase 2: Generate n parallel requests
|
||||||
|
logger.debug(f"Phase 2: Generating {n} parallel requests")
|
||||||
|
generators = []
|
||||||
|
request_ids = []
|
||||||
|
|
||||||
|
for i in range(n):
|
||||||
|
# Create individual generation request
|
||||||
|
gen_obj = copy.copy(obj)
|
||||||
|
gen_obj.sampling_params = copy.copy(obj.sampling_params)
|
||||||
|
gen_obj.sampling_params.n = 1 # Each request generates 1 response
|
||||||
|
|
||||||
|
gen_request_id = f"{base_request_id}-{i}"
|
||||||
|
request_ids.append(gen_request_id)
|
||||||
|
|
||||||
|
# Start generation request
|
||||||
|
generators.append(
|
||||||
|
self._handle_single_request(gen_obj, gen_request_id, grpc_context)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle response aggregation
|
||||||
|
is_stream = getattr(obj, "stream", False)
|
||||||
|
|
||||||
|
if not is_stream:
|
||||||
|
# Non-streaming: collect all responses and return as batch
|
||||||
|
logger.debug(f"Non-streaming mode: collecting {n} responses")
|
||||||
|
responses = []
|
||||||
|
for generator in generators:
|
||||||
|
async for response in generator:
|
||||||
|
responses.append(response)
|
||||||
|
yield responses # Return all responses as a batch
|
||||||
|
else:
|
||||||
|
# Streaming mode: multiplex responses with index for ordering
|
||||||
|
logger.debug(f"Streaming mode: multiplexing {n} streams")
|
||||||
|
rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
|
||||||
|
|
||||||
|
# Create async tasks for all generators
|
||||||
|
task_map = {}
|
||||||
|
for generator in generators:
|
||||||
|
task = asyncio.create_task(generator.__anext__())
|
||||||
|
task_map[task] = generator
|
||||||
|
|
||||||
|
# Process responses as they arrive
|
||||||
|
while task_map:
|
||||||
|
done, _ = await asyncio.wait(
|
||||||
|
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
|
||||||
|
for task in done:
|
||||||
|
generator = task_map.pop(task)
|
||||||
|
try:
|
||||||
|
response = await task
|
||||||
|
|
||||||
|
# Add index for client-side ordering
|
||||||
|
if isinstance(response, dict) and "meta_info" in response:
|
||||||
|
response_rid = response["meta_info"].get("id", "")
|
||||||
|
if response_rid in rid_to_index:
|
||||||
|
response["index"] = rid_to_index[response_rid]
|
||||||
|
|
||||||
|
yield response
|
||||||
|
|
||||||
|
# Create next task for this generator
|
||||||
|
next_task = asyncio.create_task(generator.__anext__())
|
||||||
|
task_map[next_task] = generator
|
||||||
|
|
||||||
|
except StopAsyncIteration:
|
||||||
|
# This generator is finished
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _handle_single_request(
|
||||||
|
self,
|
||||||
|
obj: TokenizedGenerateReqInput,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
||||||
|
):
|
||||||
|
"""Handle a single request - core implementation without n>1 logic."""
|
||||||
# Generate request ID if not provided
|
# Generate request ID if not provided
|
||||||
if request_id is None:
|
if request_id is None:
|
||||||
async with self.request_counter_lock:
|
request_id = f"grpc-{uuid.uuid4().hex}"
|
||||||
request_id = f"grpc-{self.request_counter}"
|
|
||||||
self.request_counter += 1
|
|
||||||
|
|
||||||
obj.rid = request_id
|
obj.rid = request_id
|
||||||
|
|
||||||
|
# Create and register request state
|
||||||
# TODO: support log_request
|
# TODO: support log_request
|
||||||
|
|
||||||
# Create request state
|
|
||||||
state = GrpcReqState(
|
state = GrpcReqState(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
grpc_context=grpc_context,
|
grpc_context=grpc_context,
|
||||||
@@ -189,19 +298,51 @@ class GrpcRequestManager:
|
|||||||
state.session_id = obj.session_params.session_id
|
state.session_id = obj.session_params.session_id
|
||||||
state.is_session_request = True
|
state.is_session_request = True
|
||||||
|
|
||||||
# Register state
|
|
||||||
self.rid_to_state[request_id] = state
|
self.rid_to_state[request_id] = state
|
||||||
self.record_request_for_crash_dump(obj)
|
self.record_request_for_crash_dump(obj)
|
||||||
|
|
||||||
# Send to scheduler via ZMQ
|
|
||||||
try:
|
try:
|
||||||
|
# Send to scheduler - let exceptions bubble up to grpc_server.py
|
||||||
await self._send_to_scheduler(obj)
|
await self._send_to_scheduler(obj)
|
||||||
except Exception as e:
|
|
||||||
# Clean up on failure
|
|
||||||
del self.rid_to_state[request_id]
|
|
||||||
raise RuntimeError(f"Failed to send request to scheduler: {e}")
|
|
||||||
|
|
||||||
return state.out_queue
|
is_stream = getattr(obj, "stream", False)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Client cancelled - notify scheduler and exit
|
||||||
|
if grpc_context and grpc_context.cancelled():
|
||||||
|
await self.abort_request(request_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
|
||||||
|
|
||||||
|
if is_stream:
|
||||||
|
yield response
|
||||||
|
|
||||||
|
# Non-streaming: yield final response with accumulated tokens from state
|
||||||
|
if isinstance(response, dict) and response.get("finished", False):
|
||||||
|
if not is_stream:
|
||||||
|
final_response = response.copy()
|
||||||
|
final_response["token_ids"] = state.output_ids
|
||||||
|
yield final_response
|
||||||
|
break
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Timeout waiting for response - abort and cleanup
|
||||||
|
logger.warning(
|
||||||
|
f"Timeout waiting for response for request {request_id}"
|
||||||
|
)
|
||||||
|
await self.abort_request(request_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Always clean up request state when exiting
|
||||||
|
self._cleanup_request_state(request_id)
|
||||||
|
|
||||||
|
def _cleanup_request_state(self, request_id: str):
|
||||||
|
"""Clean up local request state (does not notify scheduler)."""
|
||||||
|
if request_id in self.rid_to_state:
|
||||||
|
del self.rid_to_state[request_id]
|
||||||
|
|
||||||
async def embedding_request(
|
async def embedding_request(
|
||||||
self,
|
self,
|
||||||
@@ -214,9 +355,7 @@ class GrpcRequestManager:
|
|||||||
"""
|
"""
|
||||||
# Generate request ID if not provided
|
# Generate request ID if not provided
|
||||||
if request_id is None:
|
if request_id is None:
|
||||||
async with self.request_counter_lock:
|
request_id = f"grpc-embed-{uuid.uuid4().hex}"
|
||||||
request_id = f"grpc-embed-{self.request_counter}"
|
|
||||||
self.request_counter += 1
|
|
||||||
|
|
||||||
obj.rid = request_id
|
obj.rid = request_id
|
||||||
|
|
||||||
@@ -355,7 +494,6 @@ class GrpcRequestManager:
|
|||||||
# Extract output for this request
|
# Extract output for this request
|
||||||
output_data = {
|
output_data = {
|
||||||
"request_id": rid,
|
"request_id": rid,
|
||||||
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
|
|
||||||
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
|
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
|
||||||
"finished": batch_out.finished_reasons[i] is not None,
|
"finished": batch_out.finished_reasons[i] is not None,
|
||||||
"meta_info": {
|
"meta_info": {
|
||||||
@@ -367,6 +505,9 @@ class GrpcRequestManager:
|
|||||||
if batch_out.completion_tokens
|
if batch_out.completion_tokens
|
||||||
else 0
|
else 0
|
||||||
),
|
),
|
||||||
|
"cached_tokens": (
|
||||||
|
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
|
||||||
|
),
|
||||||
"finish_reason": (
|
"finish_reason": (
|
||||||
str(batch_out.finished_reasons[i])
|
str(batch_out.finished_reasons[i])
|
||||||
if batch_out.finished_reasons[i]
|
if batch_out.finished_reasons[i]
|
||||||
@@ -389,15 +530,10 @@ class GrpcRequestManager:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update state
|
# Update state for accumulation
|
||||||
if output_data["text"]:
|
|
||||||
state.text += output_data["text"][state.last_output_offset :]
|
|
||||||
state.last_output_offset = len(output_data["text"])
|
|
||||||
|
|
||||||
if output_data["token_ids"]:
|
if output_data["token_ids"]:
|
||||||
state.output_ids.extend(output_data["token_ids"])
|
state.output_ids.extend(output_data["token_ids"])
|
||||||
|
|
||||||
# Send to output queue
|
|
||||||
await state.out_queue.put(output_data)
|
await state.out_queue.put(output_data)
|
||||||
|
|
||||||
# Handle completion
|
# Handle completion
|
||||||
|
|||||||
@@ -181,20 +181,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
# Convert gRPC request to internal format
|
# Convert gRPC request to internal format
|
||||||
tokenized_req = self._convert_generate_request(request)
|
tokenized_req = self._convert_generate_request(request)
|
||||||
|
|
||||||
# Submit to request manager
|
# Submit to request manager (automatically handles n>1)
|
||||||
output_queue = await self.request_manager.generate_request(
|
response_generator = self.request_manager.generate_request(
|
||||||
obj=tokenized_req,
|
obj=tokenized_req,
|
||||||
request_id=request.request_id,
|
request_id=request.request_id,
|
||||||
grpc_context=context,
|
grpc_context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream outputs
|
async for output in response_generator:
|
||||||
while True:
|
# Handle batch responses (for n>1 non-streaming)
|
||||||
try:
|
if isinstance(output, list):
|
||||||
# Get output with timeout
|
for batch_output in output:
|
||||||
output = await asyncio.wait_for(output_queue.get(), timeout=4)
|
if "error" in batch_output:
|
||||||
|
yield sglang_scheduler_pb2.GenerateResponse(
|
||||||
# Check for errors
|
request_id=request.request_id,
|
||||||
|
error=sglang_scheduler_pb2.GenerateError(
|
||||||
|
message=batch_output["error"],
|
||||||
|
http_status_code=(
|
||||||
|
"500" if "abort" not in batch_output else "499"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# All non-error batch outputs are final responses
|
||||||
|
yield self._create_completion_response(
|
||||||
|
request.request_id, batch_output
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Handle single response (for streaming or n=1 non-streaming)
|
||||||
if "error" in output:
|
if "error" in output:
|
||||||
yield sglang_scheduler_pb2.GenerateResponse(
|
yield sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request.request_id,
|
request_id=request.request_id,
|
||||||
@@ -205,27 +219,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
break
|
elif output.get("finished", False):
|
||||||
|
|
||||||
# Check if finished
|
|
||||||
if output.get("finished", False):
|
|
||||||
# Send completion
|
|
||||||
yield self._create_completion_response(
|
yield self._create_completion_response(
|
||||||
request.request_id, output
|
request.request_id, output
|
||||||
)
|
)
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
# Send chunk
|
|
||||||
yield self._create_chunk_response(request.request_id, output)
|
yield self._create_chunk_response(request.request_id, output)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# Check if context is still active
|
|
||||||
if context.cancelled():
|
|
||||||
# Abort the request
|
|
||||||
await self.request_manager.abort_request(request.request_id)
|
|
||||||
break
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
|
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
|
||||||
yield sglang_scheduler_pb2.GenerateResponse(
|
yield sglang_scheduler_pb2.GenerateResponse(
|
||||||
@@ -403,7 +403,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
return_logprob=grpc_req.return_logprob,
|
return_logprob=grpc_req.return_logprob,
|
||||||
logprob_start_len=grpc_req.logprob_start_len or -1,
|
logprob_start_len=grpc_req.logprob_start_len or -1,
|
||||||
top_logprobs_num=grpc_req.top_logprobs_num or 0,
|
top_logprobs_num=grpc_req.top_logprobs_num or 0,
|
||||||
stream=True, # Always stream for gRPC
|
stream=grpc_req.stream or False,
|
||||||
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
|
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
|
||||||
token_ids_logprob=(
|
token_ids_logprob=(
|
||||||
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
||||||
@@ -480,10 +480,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
return sglang_scheduler_pb2.GenerateResponse(
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
||||||
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
|
token_ids=output.get("token_ids", []),
|
||||||
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
||||||
completion_tokens=meta_info.get("completion_tokens", 0),
|
completion_tokens=meta_info.get("completion_tokens", 0),
|
||||||
cached_tokens=0,
|
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -122,6 +122,9 @@ message GenerateRequest {
|
|||||||
|
|
||||||
// For load balancing
|
// For load balancing
|
||||||
int32 dp_balance_id = 17;
|
int32 dp_balance_id = 17;
|
||||||
|
|
||||||
|
// Whether client wants streaming response
|
||||||
|
bool stream = 18;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TokenizedInput {
|
message TokenizedInput {
|
||||||
@@ -163,8 +166,8 @@ message GenerateResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message GenerateStreamChunk {
|
message GenerateStreamChunk {
|
||||||
// Generated token
|
// Generated tokens (incremental chunk)
|
||||||
int32 token_id = 1;
|
repeated int32 token_ids = 1;
|
||||||
|
|
||||||
// Cumulative counts
|
// Cumulative counts
|
||||||
int32 prompt_tokens = 2;
|
int32 prompt_tokens = 2;
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -83,7 +83,7 @@ class DisaggregatedParams(_message.Message):
|
|||||||
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
|
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateRequest(_message.Message):
|
class GenerateRequest(_message.Message):
|
||||||
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
|
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
|
||||||
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
|
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
|
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
|
||||||
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
|
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
|
||||||
@@ -101,6 +101,7 @@ class GenerateRequest(_message.Message):
|
|||||||
LORA_ID_FIELD_NUMBER: _ClassVar[int]
|
LORA_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
|
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
|
||||||
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
|
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
STREAM_FIELD_NUMBER: _ClassVar[int]
|
||||||
request_id: str
|
request_id: str
|
||||||
tokenized: TokenizedInput
|
tokenized: TokenizedInput
|
||||||
mm_inputs: MultimodalInputs
|
mm_inputs: MultimodalInputs
|
||||||
@@ -118,7 +119,8 @@ class GenerateRequest(_message.Message):
|
|||||||
lora_id: str
|
lora_id: str
|
||||||
data_parallel_rank: int
|
data_parallel_rank: int
|
||||||
dp_balance_id: int
|
dp_balance_id: int
|
||||||
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
|
stream: bool
|
||||||
|
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
|
||||||
|
|
||||||
class TokenizedInput(_message.Message):
|
class TokenizedInput(_message.Message):
|
||||||
__slots__ = ("original_text", "input_ids")
|
__slots__ = ("original_text", "input_ids")
|
||||||
@@ -161,20 +163,20 @@ class GenerateResponse(_message.Message):
|
|||||||
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateStreamChunk(_message.Message):
|
class GenerateStreamChunk(_message.Message):
|
||||||
__slots__ = ("token_id", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
|
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states")
|
||||||
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||||
token_id: int
|
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
cached_tokens: int
|
cached_tokens: int
|
||||||
logprobs: LogProbs
|
logprobs: LogProbs
|
||||||
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
||||||
def __init__(self, token_id: _Optional[int] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
|
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateComplete(_message.Message):
|
class GenerateComplete(_message.Message):
|
||||||
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
|
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ impl SglangSchedulerClient {
|
|||||||
logprob_start_len: -1,
|
logprob_start_len: -1,
|
||||||
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
|
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
|
||||||
return_hidden_states: body.return_hidden_states,
|
return_hidden_states: body.return_hidden_states,
|
||||||
|
stream: body.stream,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -367,14 +368,14 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_generate_stream_chunk() {
|
fn test_generate_stream_chunk() {
|
||||||
let chunk = proto::GenerateStreamChunk {
|
let chunk = proto::GenerateStreamChunk {
|
||||||
token_id: 1234,
|
token_ids: vec![1234, 5678],
|
||||||
prompt_tokens: 5,
|
prompt_tokens: 5,
|
||||||
completion_tokens: 2,
|
completion_tokens: 2,
|
||||||
cached_tokens: 3,
|
cached_tokens: 3,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(chunk.token_id, 1234);
|
assert_eq!(chunk.token_ids, vec![1234, 5678]);
|
||||||
assert_eq!(chunk.prompt_tokens, 5);
|
assert_eq!(chunk.prompt_tokens, 5);
|
||||||
assert_eq!(chunk.completion_tokens, 2);
|
assert_eq!(chunk.completion_tokens, 2);
|
||||||
assert_eq!(chunk.cached_tokens, 3);
|
assert_eq!(chunk.cached_tokens, 3);
|
||||||
|
|||||||
@@ -122,6 +122,9 @@ message GenerateRequest {
|
|||||||
|
|
||||||
// For load balancing
|
// For load balancing
|
||||||
int32 dp_balance_id = 17;
|
int32 dp_balance_id = 17;
|
||||||
|
|
||||||
|
// Whether client wants streaming response
|
||||||
|
bool stream = 18;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TokenizedInput {
|
message TokenizedInput {
|
||||||
@@ -163,8 +166,8 @@ message GenerateResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message GenerateStreamChunk {
|
message GenerateStreamChunk {
|
||||||
// Generated token
|
// Generated tokens (incremental chunk)
|
||||||
int32 token_id = 1;
|
repeated int32 token_ids = 1;
|
||||||
|
|
||||||
// Cumulative counts
|
// Cumulative counts
|
||||||
int32 prompt_tokens = 2;
|
int32 prompt_tokens = 2;
|
||||||
|
|||||||
@@ -203,6 +203,7 @@ impl GrpcRouter {
|
|||||||
debug!("Selected worker: {}", worker.url());
|
debug!("Selected worker: {}", worker.url());
|
||||||
|
|
||||||
// Step 2: Get gRPC client for worker (fail fast if can't connect)
|
// Step 2: Get gRPC client for worker (fail fast if can't connect)
|
||||||
|
// TODO(CahterineSue): manage grpc connection in worker. (it should be simpler here)
|
||||||
let client = match self.get_or_create_grpc_client(worker.url()).await {
|
let client = match self.get_or_create_grpc_client(worker.url()).await {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -249,7 +250,7 @@ impl GrpcRouter {
|
|||||||
|
|
||||||
// Step 6: Build the base gRPC request
|
// Step 6: Build the base gRPC request
|
||||||
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||||
let base_request = match client.build_generate_request(
|
let request = match client.build_generate_request(
|
||||||
request_id,
|
request_id,
|
||||||
body,
|
body,
|
||||||
processed_messages.text.clone(),
|
processed_messages.text.clone(),
|
||||||
@@ -268,11 +269,11 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Step 7: Handle streaming vs non-streaming
|
||||||
if body.stream {
|
if body.stream {
|
||||||
self.handle_streaming_chat(client, base_request, body).await
|
self.handle_streaming_chat(client, request, body).await
|
||||||
} else {
|
} else {
|
||||||
self.handle_non_streaming_chat(client, base_request, body)
|
self.handle_non_streaming_chat(client, request, body).await
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user