Rename api_num_spec_tokens -> num_api_spec_tokens (#458)

This commit is contained in:
Lianmin Zheng
2024-05-20 18:44:23 -07:00
committed by GitHub
parent 8dbdc018a3
commit ced77c6626
7 changed files with 27 additions and 27 deletions

View File

@@ -66,7 +66,7 @@ def run_program(
default_sampling_para,
chat_template=None,
stream=stream,
api_num_spec_tokens=program.api_num_spec_tokens,
num_api_spec_tokens=program.num_api_spec_tokens,
)
state = ProgramState(stream_executor)
@@ -178,7 +178,7 @@ class StreamExecutor:
default_sampling_para,
chat_template,
stream,
api_num_spec_tokens=None,
num_api_spec_tokens=None,
use_thread=True,
):
self.sid = uuid.uuid4().hex
@@ -210,7 +210,7 @@ class StreamExecutor:
self.fork_start_text_pos = None
# For speculative execution
self.api_num_spec_tokens = api_num_spec_tokens
self.num_api_spec_tokens = num_api_spec_tokens
self.speculated_text = ""
# Worker thread
@@ -399,7 +399,7 @@ class StreamExecutor:
if (
self.cur_role == "assistant"
and self.api_num_spec_tokens is not None
and self.num_api_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
@@ -444,7 +444,7 @@ class StreamExecutor:
nonlocal meta_info
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.api_num_spec_tokens
sampling_params.max_new_tokens, self.num_api_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
@@ -490,7 +490,7 @@ class StreamExecutor:
name = expr.name
if not self.stream:
if self.api_num_spec_tokens is None:
if self.num_api_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
@@ -517,7 +517,7 @@ class StreamExecutor:
self.variable_event[name].set()
else:
assert (
self.api_num_spec_tokens is None
self.num_api_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
@@ -580,7 +580,7 @@ class StreamExecutor:
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.api_num_spec_tokens is not None
and self.num_api_spec_tokens is not None
and self.backend.is_chat_model
):
# Execute the stored lazy generation calls

View File

@@ -97,9 +97,9 @@ class SglSamplingParams:
class SglFunction:
def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
self.func = func
self.api_num_spec_tokens = api_num_spec_tokens
self.num_api_spec_tokens = num_api_spec_tokens
self.bind_arguments = bind_arguments or {}
self.pin_prefix_rid = None