Handle truncation errors (#436)
(Binary screenshot file updated; 132 KiB before and after.)
@@ -369,7 +369,7 @@
     "\n",
     "The evaluation scores appear in the bottom right of the logs (screenshot below). Note that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context`, as these evals are automatically skipped if the target answer is not provided.\n",
     "\n",
     ""
    ],
    "metadata": {
     "collapsed": false
@@ -16,7 +16,7 @@ class GlobalConfig:
 
         # Optimization configs
         self.eager_fill_image = False
-        self.enable_prefix_sharing = True
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
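The renamed flag keeps prefix pre-caching on by default. As a minimal sketch (assuming the `global_config` singleton is importable from `sglang.global_config`, as in the current package layout), a user could turn the behavior off before running a batch:

```python
# Sketch: opt out of tracing-based prefix pre-caching for batch runs.
# Assumes sglang.global_config exposes the singleton configured above.
from sglang.global_config import global_config

global_config.enable_precache_with_tracing = False
```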
@@ -86,9 +86,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
 
-    # Extract prefix by tracing and cache it
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)
 
     # Run all programs
     if num_threads == "auto":
@@ -154,21 +154,12 @@ def run_program_batch(
     return rets
 
 
-def pin_program(program, backend):
-    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-        # TODO: handle multiple backends
-        from sglang.lang.tracer import extract_prefix_by_tracing
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing
 
-        prefix = extract_prefix_by_tracing(program, backend)
-        if prefix and len(prefix) > 64:
-            prefix_rid = backend.cache_prefix(prefix)
-            program.pin_prefix_rid = prefix_rid
-            return prefix_rid
-    return None
-
-
-def unpin_program(program, backend):
-    pass
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)
 
 
 class StreamExecutor:
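For orientation, a hedged end-to-end sketch of the path these two hunks change: a batch run with more than one argument set now calls `cache_program`, which traces the program and caches the shared prefix when it is longer than 64 characters. The endpoint URL and the `run_batch` call are assumptions based on the existing sglang API and are not part of this diff:

```python
# Sketch: a batch run that exercises the new pre-caching path.
import sglang as sgl

@sgl.function
def qa(s, question):
    s += "You are a concise assistant. Answer each question in one sentence.\n"
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", max_tokens=32)

backend = sgl.RuntimeEndpoint("http://localhost:30000")  # assumed local server

# With global_config.enable_precache_with_tracing left at its default (True) and
# more than one argument set, run_program_batch calls cache_program() to cache
# the traced prompt prefix, provided that prefix exceeds 64 characters.
states = qa.run_batch(
    [{"question": "What is 2 + 2?"}, {"question": "Name a prime number."}],
    backend=backend,
)
```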
@@ -322,7 +313,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                print(f"Error in stream_executor: {get_exception_traceback()}")
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -702,9 +693,10 @@ class ProgramState:
         return self.stream_executor.messages()
 
     def sync(self):
-        ret = self.stream_executor.sync()
-        self.error = self.stream_executor.error
-        return ret
+        return self.stream_executor.sync()
+
+    def error(self):
+        return self.stream_executor.error
 
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
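Since errors are no longer printed from the worker thread (the `print` above is commented out) and `sync()` no longer copies them onto the state, callers read them through the new `error()` accessor. A minimal sketch, assuming `qa` is an `@sgl.function` and a default backend has been set:

```python
# Sketch: surfacing an execution error (e.g. a truncation ValueError) after a run.
state = qa.run(question="word " * 100000)  # deliberately oversized input
state.sync()                               # wait for the stream executor to drain
if state.error() is not None:              # accessor added in this hunk
    print("request failed:", state.error())
```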
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)
 
-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program
 
         backend = backend or global_config.default_backend
-        return pin_program(self, backend)
+        return cache_program(self, backend)
 
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
-
     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
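Usage-wise, the new `cache()` method replaces the old `pin()`/`unpin()` pair, and there is no explicit eviction call anymore. A short sketch mirroring the updated test further below (the bound function and backend names are illustrative):

```python
# Sketch: pre-cache the shared few-shot prefix once, then reuse it across runs.
few_shot_qa_2 = few_shot_qa.bind(
    prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
)
few_shot_qa_2.cache(backend)  # replaces the old pin()/unpin() pair
```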
@@ -20,6 +20,16 @@ class FinishReason(IntEnum):
     LENGTH = auto()
     STOP_STR = auto()
 
+    def to_str(self):
+        if self == FinishReason.EOS_TOKEN:
+            return None
+        elif self == FinishReason.LENGTH:
+            return "length"
+        elif self == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            raise ValueError(f"Invalid finish reason: {self}")
+
 
 class Req:
     def __init__(self, rid, input_text, input_ids):
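The new helper maps each finish reason to the string reported in a request's meta info, with `None` for a natural stop at the EOS token. A small sketch of the mapping, matching the method body above:

```python
# Sketch: string values produced by FinishReason.to_str().
assert FinishReason.EOS_TOKEN.to_str() is None    # natural end of generation
assert FinishReason.LENGTH.to_str() == "length"   # hit the max_new_tokens limit
assert FinishReason.STOP_STR.to_str() == "stop"   # matched a stop string
```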
@@ -612,7 +612,7 @@ class ModelRpcServer:
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
+                    "finish_reason": req.finish_reason.to_str(),
                     "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
@@ -98,7 +98,6 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path, trust_remote_code=server_args.trust_remote_code
         )
-
         self.context_len = get_context_length(self.hf_config)
 
         if is_multimodal_model(self.model_path):
@@ -156,6 +155,12 @@ class TokenizerManager:
         else:
             input_ids = obj.input_ids
 
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)"
+            )
+
         sampling_params = SamplingParams(**obj.sampling_params)
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
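This is the core of the truncation handling: a request whose tokenized prompt does not fit the model's context window is rejected up front with a `ValueError` instead of being forwarded to the model runner. A standalone sketch of the same guard with made-up values:

```python
# Sketch: the guard added above, shown in isolation with illustrative numbers.
context_len = 2048              # e.g. what get_context_length(hf_config) returns
input_ids = list(range(3000))   # a 3000-token prompt

if len(input_ids) >= context_len:
    raise ValueError(
        f"The input ({len(input_ids)} tokens) is longer than the "
        f"model's context length ({context_len} tokens)"
    )
```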
@@ -20,7 +20,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
@@ -90,8 +90,11 @@ async def generate_request(obj: GenerateReqInput):
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    return ret
+    try:
+        ret = await tokenizer_manager.generate_request(obj).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse({"error": str(e)}, status_code=400)
 
 
 @app.post("/v1/completions")
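With the tokenizer manager's `ValueError` caught here, an over-long prompt now comes back as an HTTP 400 with a JSON error body instead of an unhandled server-side exception. A hedged client-side sketch; the `/generate` route and port are assumptions based on the default sglang runtime server and are not part of this diff:

```python
# Sketch: what a client sees when the prompt exceeds the context length.
import requests

resp = requests.post(
    "http://localhost:30000/generate",   # assumed default route and port
    json={"text": "word " * 100000, "sampling_params": {"max_new_tokens": 8}},
)
print(resp.status_code)  # 400
print(resp.json())       # {"error": "The input (...) tokens is longer than ..."}
```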
@@ -29,7 +29,7 @@ class TestBind(unittest.TestCase):
         tracer = few_shot_qa_2.trace()
         print(tracer.last_node.print_graph_dfs() + "\n")
 
-    def test_pin(self):
+    def test_cache(self):
         @sgl.function
         def few_shot_qa(s, prompt, question):
             s += prompt
@@ -41,8 +41,7 @@ class TestBind(unittest.TestCase):
         few_shot_qa_2 = few_shot_qa.bind(
             prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
         )
-        few_shot_qa_2.pin(self.backend)
-        few_shot_qa_2.unpin(self.backend)
+        few_shot_qa_2.cache(self.backend)
 
 
 if __name__ == "__main__":
@@ -50,4 +49,4 @@ if __name__ == "__main__":
 
     # t = TestBind()
     # t.setUp()
-    # t.test_pin()
+    # t.test_cache()