From 7358fa64f7da3f18ce7512148d330755b0c1f1fe Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 23 Jan 2024 22:10:17 +0000
Subject: [PATCH] Fix a bug in runtime backend

---
 python/sglang/lang/interpreter.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 4e70f942b..5949134a6 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -48,7 +48,10 @@ def run_internal(state, program, func_args, func_kwargs, sync):
 def run_program(
     program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
 ):
+    if hasattr(backend, "endpoint"):
+        backend = backend.endpoint
     assert backend is not None, "Please specify a backend"
+
     func_kwargs.update(program.bind_arguments)
     stream_executor = StreamExecutor(
         backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
@@ -74,6 +77,9 @@ def run_program_batch(
     num_threads,
     progress_bar,
 ):
+    if hasattr(backend, "endpoint"):
+        backend = backend.endpoint
+
     # Extract prefix by tracing and cache it
     if len(batch_arguments) > 1:
         pin_program(program, backend)
@@ -157,9 +163,6 @@ class StreamExecutor:
         self.default_sampling_para = default_sampling_para
         self.stream = stream
 
-        if hasattr(backend, "endpoint"):
-            self.backend = backend.endpoint
-
         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]