diff --git a/examples/usage/images/rag/max-tokens-fixed-rag-trace.png b/examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png
similarity index 100%
rename from examples/usage/images/rag/max-tokens-fixed-rag-trace.png
rename to examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png
diff --git a/examples/usage/trace_and_evaluate_rag_using_parea.ipynb b/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb
similarity index 99%
rename from examples/usage/trace_and_evaluate_rag_using_parea.ipynb
rename to examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb
index b25f2fb23..e7a574cbd 100644
--- a/examples/usage/trace_and_evaluate_rag_using_parea.ipynb
+++ b/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb
@@ -369,7 +369,7 @@
     "\n",
     "The evaluation scores appear in the bottom right of the logs (screenshot below). Note, that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` as these evals are automatically skipped if the target answer is not provided.\n",
     "\n",
-    "![Fixed Max. Tokens](./images/rag/max-tokens-fixed-rag-trace.png)"
+    "![Fixed Max. Tokens](max-tokens-fixed-rag-trace.png)"
    ],
    "metadata": {
     "collapsed": false
diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index e746d7f1d..062628bd3 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -16,7 +16,7 @@ class GlobalConfig:
 
         # Optimization configs
         self.eager_fill_image = False
-        self.enable_prefix_sharing = True
+        self.enable_precache_with_tracing = True
        self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
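The renamed flag above gates the tracing-based prefix pre-caching used for batch runs (see the interpreter change below). A minimal sketch of toggling it, assuming the usual `global_config` singleton exported by this module:

    from sglang.global_config import global_config

    # Batch prefix pre-caching is on by default; opt out if tracing the
    # program ahead of execution is undesirable.
    global_config.enable_precache_with_tracing = False
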
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 03406eb74..5bc51928c 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -86,9 +86,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
 
-    # Extract prefix by tracing and cache it
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)
 
     # Run all programs
     if num_threads == "auto":
@@ -154,21 +154,12 @@ def run_program_batch(
     return rets
 
 
-def pin_program(program, backend):
-    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-        # TODO: handle multiple backends
-        from sglang.lang.tracer import extract_prefix_by_tracing
-
-        prefix = extract_prefix_by_tracing(program, backend)
-        if prefix and len(prefix) > 64:
-            prefix_rid = backend.cache_prefix(prefix)
-            program.pin_prefix_rid = prefix_rid
-            return prefix_rid
-    return None
-
-
-def unpin_program(program, backend):
-    pass
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing
+
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)
 
 
 class StreamExecutor:
@@ -322,7 +313,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                print(f"Error in stream_executor: {get_exception_traceback()}")
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -702,9 +693,10 @@ class ProgramState:
         return self.stream_executor.messages()
 
     def sync(self):
-        ret = self.stream_executor.sync()
-        self.error = self.stream_executor.error
-        return ret
+        return self.stream_executor.sync()
+
+    def error(self):
+        return self.stream_executor.error
 
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index 9895786dc..eaf92070c 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)
 
-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program
 
         backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)
 
     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
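With the change above, the `pin`/`unpin` pair on `SglFunction` collapses into a single `cache` call that traces the program and caches its shared prefix on the backend (prefixes of 64 characters or fewer are skipped). A minimal sketch of the new call site, mirroring the updated test further down; the endpoint URL is illustrative and assumes a runtime is already serving a model:

    import sglang as sgl
    from sglang import RuntimeEndpoint

    @sgl.function
    def few_shot_qa(s, prompt, question):
        s += prompt
        s += "Q: " + question + "\n"
        s += "A:" + sgl.gen("answer", stop="\n")

    backend = RuntimeEndpoint("http://localhost:30000")
    bound = few_shot_qa.bind(prompt="Answer the following questions.\n\n")
    # One call replaces pin()/unpin(); there is nothing to release afterwards.
    bound.cache(backend)
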
diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index a2420fa93..88e7a1973 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -20,6 +20,16 @@ class FinishReason(IntEnum):
     LENGTH = auto()
     STOP_STR = auto()
 
+    def to_str(self):
+        if self == FinishReason.EOS_TOKEN:
+            return None
+        elif self == FinishReason.LENGTH:
+            return "length"
+        elif self == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            raise ValueError(f"Invalid finish reason: {self}")
+
 
 class Req:
     def __init__(self, rid, input_text, input_ids):
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index f9e7153a8..46b2c0c61 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -612,7 +612,7 @@ class ModelRpcServer:
                         + len(req.output_ids)
                         - req.prompt_tokens,
                         "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                        "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
+                        "finish_reason": req.finish_reason.to_str(),
                         "hit_stop_str": req.hit_stop_str,
                     }
                     if req.return_logprob:
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index e8aa2d77c..81f009ce1 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -98,7 +98,6 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path, trust_remote_code=server_args.trust_remote_code
         )
-
         self.context_len = get_context_length(self.hf_config)
 
         if is_multimodal_model(self.model_path):
@@ -156,6 +155,12 @@ class TokenizerManager:
         else:
             input_ids = obj.input_ids
 
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)"
+            )
+
         sampling_params = SamplingParams(**obj.sampling_params)
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 74bbd14c3..dae23f7aa 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -20,7 +20,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
@@ -90,8 +90,11 @@ async def generate_request(obj: GenerateReqInput):
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    return ret
+    try:
+        ret = await tokenizer_manager.generate_request(obj).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse({"error": str(e)}, status_code=400)
 
 
 @app.post("/v1/completions")
diff --git a/test/lang/test_bind_pin.py b/test/lang/test_bind_cache.py
similarity index 91%
rename from test/lang/test_bind_pin.py
rename to test/lang/test_bind_cache.py
index 626d6ff05..9cba14ce4 100644
--- a/test/lang/test_bind_pin.py
+++ b/test/lang/test_bind_cache.py
@@ -29,7 +29,7 @@ class TestBind(unittest.TestCase):
         tracer = few_shot_qa_2.trace()
         print(tracer.last_node.print_graph_dfs() + "\n")
 
-    def test_pin(self):
+    def test_cache(self):
         @sgl.function
         def few_shot_qa(s, prompt, question):
             s += prompt
@@ -41,8 +41,7 @@ class TestBind(unittest.TestCase):
         few_shot_qa_2 = few_shot_qa.bind(
             prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
         )
-        few_shot_qa_2.pin(self.backend)
-        few_shot_qa_2.unpin(self.backend)
+        few_shot_qa_2.cache(self.backend)
 
 
 if __name__ == "__main__":
@@ -50,4 +49,4 @@
 
     # t = TestBind()
     # t.setUp()
-    # t.test_pin()
+    # t.test_cache()
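Taken together, the tokenizer-manager length check and the server-side `try/except` turn an over-long prompt into a clean HTTP 400 JSON error instead of a failure deep in the scheduler. A hypothetical client-side check, assuming a server is running on the default local port:

    import requests

    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "word " * 100000,  # far beyond the model's context length
            "sampling_params": {"max_new_tokens": 8},
        },
    )
    print(resp.status_code)  # 400
    print(resp.json())       # {"error": "The input (... tokens) is longer than ..."}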