Handle truncation errors (#436)
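
Summary of the change, as reflected in the hunks below: the pin/unpin prefix-sharing API is replaced by a single cache_program path gated by a new enable_precache_with_tracing flag; stream-executor errors are surfaced through a new ProgramState.error() accessor; Req.finish_reason is converted to a proper string via a new FinishReason.to_str() method; and the tokenizer manager now rejects inputs that reach the model's context length with a ValueError, which the HTTP server returns as a 400 JSON response.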
@@ -16,7 +16,7 @@ class GlobalConfig:
         # Optimization configs
         self.eager_fill_image = False
-        self.enable_prefix_sharing = True
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
 
@@ -86,9 +86,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
 
-    # Extract prefix by tracing and cache it
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)
 
     # Run all programs
     if num_threads == "auto":
@@ -154,21 +154,12 @@ def run_program_batch(
     return rets
 
 
-def pin_program(program, backend):
-    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-        # TODO: handle multiple backends
-        from sglang.lang.tracer import extract_prefix_by_tracing
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing
 
-        prefix = extract_prefix_by_tracing(program, backend)
-        if prefix and len(prefix) > 64:
-            prefix_rid = backend.cache_prefix(prefix)
-            program.pin_prefix_rid = prefix_rid
-            return prefix_rid
-    return None
-
-
-def unpin_program(program, backend):
-    pass
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)
 
 
 class StreamExecutor:
@@ -322,7 +313,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                print(f"Error in stream_executor: {get_exception_traceback()}")
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -702,9 +693,10 @@ class ProgramState:
         return self.stream_executor.messages()
 
     def sync(self):
-        ret = self.stream_executor.sync()
-        self.error = self.stream_executor.error
-        return ret
+        return self.stream_executor.sync()
 
+    def error(self):
+        return self.stream_executor.error
+
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
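
With this hunk, failures inside the stream executor are read through a ProgramState.error() accessor instead of an attribute stashed during sync(). A minimal sketch of checking for a failure after synchronization; the endpoint address, program body, and question are illustrative, not from this commit:

import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint

sgl.set_default_backend(RuntimeEndpoint("http://localhost:30000"))  # illustrative address

@sgl.function
def qa(s, question):
    s += question
    s += sgl.gen("answer", max_tokens=16)

state = qa.run(question="What is 2 + 2?")
state.sync()                    # block until the stream executor finishes
if state.error() is not None:   # accessor added in this commit
    print(f"generation failed: {state.error()}")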
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)
 
-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program
 
         backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)
 
     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
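
The pin()/unpin() pair collapses into a single cache() method that traces the program and pre-caches its common prefix. A usage sketch, assuming a RuntimeEndpoint is already serving; the address and program are illustrative:

import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint("http://localhost:30000")  # illustrative address

@sgl.function
def assistant(s, question):
    # A long shared system prompt; cache_program only caches prefixes longer than 64 characters.
    s += "You are a helpful assistant. Answer concisely and cite sources when possible.\n"
    s += question
    s += sgl.gen("answer")

assistant.cache(backend=backend)  # pre-caches the traced common prefix on the backend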
@@ -20,6 +20,16 @@ class FinishReason(IntEnum):
     LENGTH = auto()
     STOP_STR = auto()
 
+    def to_str(self):
+        if self == FinishReason.EOS_TOKEN:
+            return None
+        elif self == FinishReason.LENGTH:
+            return "length"
+        elif self == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            raise ValueError(f"Invalid finish reason: {self}")
+
 
 class Req:
     def __init__(self, rid, input_text, input_ids):
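
A self-contained copy of the new method with quick checks of the mapping; note that EOS_TOKEN intentionally maps to None rather than a string:

from enum import IntEnum, auto

class FinishReason(IntEnum):  # mirrors the enum defined in this hunk
    EOS_TOKEN = auto()
    LENGTH = auto()
    STOP_STR = auto()

    def to_str(self):
        if self == FinishReason.EOS_TOKEN:
            return None
        elif self == FinishReason.LENGTH:
            return "length"
        elif self == FinishReason.STOP_STR:
            return "stop"
        else:
            raise ValueError(f"Invalid finish reason: {self}")

assert FinishReason.LENGTH.to_str() == "length"
assert FinishReason.STOP_STR.to_str() == "stop"
assert FinishReason.EOS_TOKEN.to_str() is None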
@@ -612,7 +612,7 @@ class ModelRpcServer:
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
+                    "finish_reason": req.finish_reason.to_str(),
                     "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
@@ -98,7 +98,6 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path, trust_remote_code=server_args.trust_remote_code
         )
-
         self.context_len = get_context_length(self.hf_config)
 
         if is_multimodal_model(self.model_path):
@@ -156,6 +155,12 @@ class TokenizerManager:
         else:
             input_ids = obj.input_ids
 
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)"
+            )
+
         sampling_params = SamplingParams(**obj.sampling_params)
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
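
The tokenizer manager now rejects over-long prompts before they reach the scheduler. A standalone sketch of the same guard; the context length value here is illustrative, while the real one comes from get_context_length(self.hf_config):

context_len = 2048  # illustrative value

def check_input_length(input_ids):
    if len(input_ids) >= context_len:
        raise ValueError(
            f"The input ({len(input_ids)} tokens) is longer than the "
            f"model's context length ({context_len} tokens)"
        )

check_input_length(list(range(100)))  # passes silently
try:
    check_input_length(list(range(4096)))
except ValueError as e:
    print(e)  # The input (4096 tokens) is longer than the model's context length (2048 tokens)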
@@ -20,7 +20,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
@@ -90,8 +90,11 @@ async def generate_request(obj: GenerateReqInput):
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    return ret
+    try:
+        ret = await tokenizer_manager.generate_request(obj).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse({"error": str(e)}, status_code=400)
 
 
 @app.post("/v1/completions")
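
With the try/except in place, an over-long prompt now yields a 400 with a JSON error body instead of an unhandled server exception. A client-side sketch using requests; the server address is illustrative, and the payload field names are assumed to follow GenerateReqInput:

import requests

resp = requests.post(
    "http://localhost:30000/generate",  # illustrative server address
    json={
        "text": "word " * 100000,  # long enough to exceed the context length
        "sampling_params": {"max_new_tokens": 8},
    },
)
if resp.status_code == 400:
    print(resp.json()["error"])  # the ValueError message raised by TokenizerManager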