Fix a bug in runtime backend: unwrap `backend.endpoint` in run_program and run_program_batch instead of in StreamExecutor

This commit is contained in:
Lianmin Zheng
2024-01-23 22:10:17 +00:00
parent 9a16fea012
commit 7358fa64f7

View File

@@ -48,7 +48,10 @@ def run_internal(state, program, func_args, func_kwargs, sync):
def run_program(
program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
):
if hasattr(backend, "endpoint"):
backend = backend.endpoint
assert backend is not None, "Please specify a backend"
func_kwargs.update(program.bind_arguments)
stream_executor = StreamExecutor(
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
@@ -74,6 +77,9 @@ def run_program_batch(
num_threads,
progress_bar,
):
if hasattr(backend, "endpoint"):
backend = backend.endpoint
# Extract prefix by tracing and cache it
if len(batch_arguments) > 1:
pin_program(program, backend)
@@ -157,9 +163,6 @@ class StreamExecutor:
self.default_sampling_para = default_sampling_para
self.stream = stream
if hasattr(backend, "endpoint"):
self.backend = backend.endpoint
self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event]
self.meta_info = {} # Dict[name: str -> info: str]