Handle truncation errors (#436)
This commit is contained in:
@@ -86,9 +86,9 @@ def run_program_batch(
|
||||
if hasattr(backend, "endpoint"):
|
||||
backend = backend.endpoint
|
||||
|
||||
# Extract prefix by tracing and cache it
|
||||
if len(batch_arguments) > 1:
|
||||
pin_program(program, backend)
|
||||
# Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
|
||||
if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
|
||||
cache_program(program, backend)
|
||||
|
||||
# Run all programs
|
||||
if num_threads == "auto":
|
||||
@@ -154,21 +154,12 @@ def run_program_batch(
|
||||
return rets
|
||||
|
||||
|
||||
def pin_program(program, backend):
|
||||
if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
|
||||
# TODO: handle multiple backends
|
||||
from sglang.lang.tracer import extract_prefix_by_tracing
|
||||
def cache_program(program, backend):
|
||||
from sglang.lang.tracer import extract_prefix_by_tracing
|
||||
|
||||
prefix = extract_prefix_by_tracing(program, backend)
|
||||
if prefix and len(prefix) > 64:
|
||||
prefix_rid = backend.cache_prefix(prefix)
|
||||
program.pin_prefix_rid = prefix_rid
|
||||
return prefix_rid
|
||||
return None
|
||||
|
||||
|
||||
def unpin_program(program, backend):
|
||||
pass
|
||||
prefix = extract_prefix_by_tracing(program, backend)
|
||||
if prefix and len(prefix) > 64:
|
||||
backend.cache_prefix(prefix)
|
||||
|
||||
|
||||
class StreamExecutor:
|
||||
@@ -322,7 +313,7 @@ class StreamExecutor:
|
||||
try:
|
||||
self._execute(expr)
|
||||
except Exception as e:
|
||||
print(f"Error in stream_executor: {get_exception_traceback()}")
|
||||
# print(f"Error in stream_executor: {get_exception_traceback()}")
|
||||
error = e
|
||||
break
|
||||
self.queue.task_done()
|
||||
@@ -702,9 +693,10 @@ class ProgramState:
|
||||
return self.stream_executor.messages()
|
||||
|
||||
def sync(self):
|
||||
ret = self.stream_executor.sync()
|
||||
self.error = self.stream_executor.error
|
||||
return ret
|
||||
return self.stream_executor.sync()
|
||||
|
||||
def error(self):
|
||||
return self.stream_executor.error
|
||||
|
||||
def text_iter(self, var_name: Optional[str] = None):
|
||||
if self.stream_executor.stream:
|
||||
|
||||
@@ -193,17 +193,11 @@ class SglFunction:
|
||||
backend = backend or global_config.default_backend
|
||||
return trace_program(self, kwargs, backend)
|
||||
|
||||
def pin(self, backend=None):
|
||||
from sglang.lang.interpreter import pin_program
|
||||
def cache(self, backend=None):
|
||||
from sglang.lang.interpreter import cache_program
|
||||
|
||||
backend = backend or global_config.default_backend
|
||||
return pin_program(self, backend)
|
||||
|
||||
def unpin(self, backend=None):
|
||||
from sglang.lang.interpreter import unpin_program
|
||||
|
||||
backend = backend or global_config.default_backend
|
||||
return unpin_program(self, backend)
|
||||
return cache_program(self, backend)
|
||||
|
||||
def compile(self, *, backend=None):
|
||||
from sglang.lang.compiler import compile_func
|
||||
|
||||
Reference in New Issue
Block a user