Handle truncation errors (#436)

This commit is contained in:
Lianmin Zheng
2024-05-13 15:56:00 -07:00
committed by GitHub
parent 4231a42fa8
commit 5dc55a5f02
10 changed files with 44 additions and 41 deletions

View File

@@ -86,9 +86,9 @@ def run_program_batch(
if hasattr(backend, "endpoint"):
backend = backend.endpoint
# Extract prefix by tracing and cache it
if len(batch_arguments) > 1:
pin_program(program, backend)
# Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
cache_program(program, backend)
# Run all programs
if num_threads == "auto":
@@ -154,21 +154,12 @@ def run_program_batch(
return rets
def pin_program(program, backend):
if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
# TODO: handle multiple backends
from sglang.lang.tracer import extract_prefix_by_tracing
def cache_program(program, backend):
from sglang.lang.tracer import extract_prefix_by_tracing
prefix = extract_prefix_by_tracing(program, backend)
if prefix and len(prefix) > 64:
prefix_rid = backend.cache_prefix(prefix)
program.pin_prefix_rid = prefix_rid
return prefix_rid
return None
def unpin_program(program, backend):
pass
prefix = extract_prefix_by_tracing(program, backend)
if prefix and len(prefix) > 64:
backend.cache_prefix(prefix)
class StreamExecutor:
@@ -322,7 +313,7 @@ class StreamExecutor:
try:
self._execute(expr)
except Exception as e:
print(f"Error in stream_executor: {get_exception_traceback()}")
# print(f"Error in stream_executor: {get_exception_traceback()}")
error = e
break
self.queue.task_done()
@@ -702,9 +693,10 @@ class ProgramState:
return self.stream_executor.messages()
def sync(self):
ret = self.stream_executor.sync()
self.error = self.stream_executor.error
return ret
return self.stream_executor.sync()
def error(self):
return self.stream_executor.error
def text_iter(self, var_name: Optional[str] = None):
if self.stream_executor.stream:

View File

@@ -193,17 +193,11 @@ class SglFunction:
backend = backend or global_config.default_backend
return trace_program(self, kwargs, backend)
def pin(self, backend=None):
from sglang.lang.interpreter import pin_program
def cache(self, backend=None):
from sglang.lang.interpreter import cache_program
backend = backend or global_config.default_backend
return pin_program(self, backend)
def unpin(self, backend=None):
from sglang.lang.interpreter import unpin_program
backend = backend or global_config.default_backend
return unpin_program(self, backend)
return cache_program(self, backend)
def compile(self, *, backend=None):
from sglang.lang.compiler import compile_func