Fix a bug in runtime backend
This commit is contained in:
@@ -48,7 +48,10 @@ def run_internal(state, program, func_args, func_kwargs, sync):
|
|||||||
def run_program(
|
def run_program(
|
||||||
program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
|
program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
|
||||||
):
|
):
|
||||||
|
if hasattr(backend, "endpoint"):
|
||||||
|
backend = backend.endpoint
|
||||||
assert backend is not None, "Please specify a backend"
|
assert backend is not None, "Please specify a backend"
|
||||||
|
|
||||||
func_kwargs.update(program.bind_arguments)
|
func_kwargs.update(program.bind_arguments)
|
||||||
stream_executor = StreamExecutor(
|
stream_executor = StreamExecutor(
|
||||||
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
|
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
|
||||||
@@ -74,6 +77,9 @@ def run_program_batch(
|
|||||||
num_threads,
|
num_threads,
|
||||||
progress_bar,
|
progress_bar,
|
||||||
):
|
):
|
||||||
|
if hasattr(backend, "endpoint"):
|
||||||
|
backend = backend.endpoint
|
||||||
|
|
||||||
# Extract prefix by tracing and cache it
|
# Extract prefix by tracing and cache it
|
||||||
if len(batch_arguments) > 1:
|
if len(batch_arguments) > 1:
|
||||||
pin_program(program, backend)
|
pin_program(program, backend)
|
||||||
@@ -157,9 +163,6 @@ class StreamExecutor:
|
|||||||
self.default_sampling_para = default_sampling_para
|
self.default_sampling_para = default_sampling_para
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
|
||||||
if hasattr(backend, "endpoint"):
|
|
||||||
self.backend = backend.endpoint
|
|
||||||
|
|
||||||
self.variables = {} # Dict[name: str -> value: str]
|
self.variables = {} # Dict[name: str -> value: str]
|
||||||
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
||||||
self.meta_info = {} # Dict[name: str -> info: str]
|
self.meta_info = {} # Dict[name: str -> info: str]
|
||||||
|
|||||||
Reference in New Issue
Block a user