Handle truncation errors (#436)

This commit is contained in:
Lianmin Zheng
2024-05-13 15:56:00 -07:00
committed by GitHub
parent 4231a42fa8
commit 5dc55a5f02
10 changed files with 44 additions and 41 deletions

View File

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 132 KiB

View File

@@ -369,7 +369,7 @@
"\n", "\n",
"The evaluation scores appear in the bottom right of the logs (screenshot below). Note that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` as these evals are automatically skipped if the target answer is not provided.\n", "The evaluation scores appear in the bottom right of the logs (screenshot below). Note that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` as these evals are automatically skipped if the target answer is not provided.\n",
"\n", "\n",
"![Fixed Max. Tokens](./images/rag/max-tokens-fixed-rag-trace.png)" "![Fixed Max. Tokens](max-tokens-fixed-rag-trace.png)"
], ],
"metadata": { "metadata": {
"collapsed": false "collapsed": false

View File

@@ -16,7 +16,7 @@ class GlobalConfig:
# Optimization configs # Optimization configs
self.eager_fill_image = False self.eager_fill_image = False
self.enable_prefix_sharing = True self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True self.enable_parallel_encoding = True
self.enable_parallel_decoding = True self.enable_parallel_decoding = True

View File

@@ -86,9 +86,9 @@ def run_program_batch(
if hasattr(backend, "endpoint"): if hasattr(backend, "endpoint"):
backend = backend.endpoint backend = backend.endpoint
# Extract prefix by tracing and cache it # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
if len(batch_arguments) > 1: if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
pin_program(program, backend) cache_program(program, backend)
# Run all programs # Run all programs
if num_threads == "auto": if num_threads == "auto":
@@ -154,21 +154,12 @@ def run_program_batch(
return rets return rets
def pin_program(program, backend): def cache_program(program, backend):
if global_config.enable_prefix_sharing and program.pin_prefix_rid is None: from sglang.lang.tracer import extract_prefix_by_tracing
# TODO: handle multiple backends
from sglang.lang.tracer import extract_prefix_by_tracing
prefix = extract_prefix_by_tracing(program, backend) prefix = extract_prefix_by_tracing(program, backend)
if prefix and len(prefix) > 64: if prefix and len(prefix) > 64:
prefix_rid = backend.cache_prefix(prefix) backend.cache_prefix(prefix)
program.pin_prefix_rid = prefix_rid
return prefix_rid
return None
def unpin_program(program, backend):
pass
class StreamExecutor: class StreamExecutor:
@@ -322,7 +313,7 @@ class StreamExecutor:
try: try:
self._execute(expr) self._execute(expr)
except Exception as e: except Exception as e:
print(f"Error in stream_executor: {get_exception_traceback()}") # print(f"Error in stream_executor: {get_exception_traceback()}")
error = e error = e
break break
self.queue.task_done() self.queue.task_done()
@@ -702,9 +693,10 @@ class ProgramState:
return self.stream_executor.messages() return self.stream_executor.messages()
def sync(self): def sync(self):
ret = self.stream_executor.sync() return self.stream_executor.sync()
self.error = self.stream_executor.error
return ret def error(self):
return self.stream_executor.error
def text_iter(self, var_name: Optional[str] = None): def text_iter(self, var_name: Optional[str] = None):
if self.stream_executor.stream: if self.stream_executor.stream:

View File

@@ -193,17 +193,11 @@ class SglFunction:
backend = backend or global_config.default_backend backend = backend or global_config.default_backend
return trace_program(self, kwargs, backend) return trace_program(self, kwargs, backend)
def pin(self, backend=None): def cache(self, backend=None):
from sglang.lang.interpreter import pin_program from sglang.lang.interpreter import cache_program
backend = backend or global_config.default_backend backend = backend or global_config.default_backend
return pin_program(self, backend) return cache_program(self, backend)
def unpin(self, backend=None):
from sglang.lang.interpreter import unpin_program
backend = backend or global_config.default_backend
return unpin_program(self, backend)
def compile(self, *, backend=None): def compile(self, *, backend=None):
from sglang.lang.compiler import compile_func from sglang.lang.compiler import compile_func

View File

@@ -20,6 +20,16 @@ class FinishReason(IntEnum):
LENGTH = auto() LENGTH = auto()
STOP_STR = auto() STOP_STR = auto()
def to_str(self):
    """Return the string form of this finish reason.

    EOS_TOKEN maps to None, LENGTH to "length" and STOP_STR to "stop".

    Raises:
        ValueError: if this value is none of the three members above.
    """
    # Lookup table instead of an if/elif chain; the keys are compared by
    # equality, just like the explicit comparisons would be.
    reason_names = {
        FinishReason.EOS_TOKEN: None,
        FinishReason.LENGTH: "length",
        FinishReason.STOP_STR: "stop",
    }
    if self not in reason_names:
        raise ValueError(f"Invalid finish reason: {self}")
    return reason_names[self]
class Req: class Req:
def __init__(self, rid, input_text, input_ids): def __init__(self, rid, input_text, input_ids):

View File

@@ -612,7 +612,7 @@ class ModelRpcServer:
+ len(req.output_ids) + len(req.output_ids)
- req.prompt_tokens, - req.prompt_tokens,
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finish_reason), # FIXME: convert to the correct string "finish_reason": req.finish_reason.to_str(),
"hit_stop_str": req.hit_stop_str, "hit_stop_str": req.hit_stop_str,
} }
if req.return_logprob: if req.return_logprob:

View File

@@ -98,7 +98,6 @@ class TokenizerManager:
self.hf_config = get_config( self.hf_config = get_config(
self.model_path, trust_remote_code=server_args.trust_remote_code self.model_path, trust_remote_code=server_args.trust_remote_code
) )
self.context_len = get_context_length(self.hf_config) self.context_len = get_context_length(self.hf_config)
if is_multimodal_model(self.model_path): if is_multimodal_model(self.model_path):
@@ -156,6 +155,12 @@ class TokenizerManager:
else: else:
input_ids = obj.input_ids input_ids = obj.input_ids
if len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)"
)
sampling_params = SamplingParams(**obj.sampling_params) sampling_params = SamplingParams(**obj.sampling_params)
if sampling_params.max_new_tokens != 0: if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer) sampling_params.normalize(self.tokenizer)

View File

@@ -20,7 +20,7 @@ import requests
import uvicorn import uvicorn
import uvloop import uvloop
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache from sglang.srt.constrained import disable_cache
@@ -90,8 +90,11 @@ async def generate_request(obj: GenerateReqInput):
return StreamingResponse(stream_results(), media_type="text/event-stream") return StreamingResponse(stream_results(), media_type="text/event-stream")
ret = await tokenizer_manager.generate_request(obj).__anext__() try:
return ret ret = await tokenizer_manager.generate_request(obj).__anext__()
return ret
except ValueError as e:
return JSONResponse({"error": str(e)}, status_code=400)
@app.post("/v1/completions") @app.post("/v1/completions")

View File

@@ -29,7 +29,7 @@ class TestBind(unittest.TestCase):
tracer = few_shot_qa_2.trace() tracer = few_shot_qa_2.trace()
print(tracer.last_node.print_graph_dfs() + "\n") print(tracer.last_node.print_graph_dfs() + "\n")
def test_pin(self): def test_cache(self):
@sgl.function @sgl.function
def few_shot_qa(s, prompt, question): def few_shot_qa(s, prompt, question):
s += prompt s += prompt
@@ -41,8 +41,7 @@ class TestBind(unittest.TestCase):
few_shot_qa_2 = few_shot_qa.bind( few_shot_qa_2 = few_shot_qa.bind(
prompt="Answer the following questions as if you were a 5-year-old kid.\n\n" prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
) )
few_shot_qa_2.pin(self.backend) few_shot_qa_2.cache(self.backend)
few_shot_qa_2.unpin(self.backend)
if __name__ == "__main__": if __name__ == "__main__":
@@ -50,4 +49,4 @@ if __name__ == "__main__":
# t = TestBind() # t = TestBind()
# t.setUp() # t.setUp()
# t.test_pin() # t.test_cache()