Fix test cases (#6)

This commit is contained in:
Lianmin Zheng
2024-01-15 01:15:53 -08:00
committed by GitHub
parent 08ab2a1655
commit 4bd8233f2c
12 changed files with 90 additions and 19 deletions

View File

@@ -6,10 +6,10 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SglSamplingParams,
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
@@ -137,7 +137,6 @@ class CompiledFunction:
):
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
kwargs.update(self.function.bind_arguments)
default_sampling_para = SglSamplingParams(
@@ -182,9 +181,6 @@ class CompiledFunction:
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
# Extract prefix by tracing and cache it
if len(batch_kwargs) > 1:

View File

@@ -12,7 +12,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
@@ -89,7 +88,7 @@ def run_program_batch(
for arguments in batch_arguments:
rets.append(
run_program(
program, backend, (), arguments, default_sampling_para, False, False
program, backend, (), arguments, default_sampling_para, False, True
)
)
else:
@@ -108,7 +107,7 @@ def run_program_batch(
arguments,
default_sampling_para,
False,
False,
True,
)
)
if progress_bar:
@@ -478,7 +477,7 @@ class StreamExecutor:
"top_k",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
"ignore_eos",
"dtype",
"regex",
]: