Fix test cases (#6)

This commit is contained in:
Lianmin Zheng
2024-01-15 01:15:53 -08:00
committed by GitHub
parent 08ab2a1655
commit 4bd8233f2c
12 changed files with 90 additions and 19 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.1.0"
version = "0.1.2"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"
@@ -24,6 +24,10 @@ openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]

View File

@@ -1,4 +1,4 @@
__version__ = "0.1.0"
__version__ = "0.1.2"
from sglang.api import *
from sglang.global_config import global_config

View File

@@ -17,13 +17,19 @@ from sglang.lang.ir import (
SglRoleEnd,
SglSelect,
)
from sglang.srt.server import Runtime
def function(func: Callable):
    """Wrap ``func`` in an ``SglFunction`` and return the wrapper."""
    wrapped = SglFunction(func)
    return wrapped
def Runtime(*args, **kwargs):
    """Construct ``sglang.srt.server.Runtime`` with the given arguments.

    The import happens inside the function, at call time, to avoid
    importing an unnecessary dependency when this module is loaded.
    All positional and keyword arguments are forwarded unchanged.
    """
    from sglang.srt.server import Runtime as _ServerRuntime

    runtime = _ServerRuntime(*args, **kwargs)
    return runtime
def set_default_backend(backend: BaseBackend):
    """Store ``backend`` as ``global_config.default_backend``."""
    # Plain attribute assignment on the shared config object; returns None.
    global_config.default_backend = backend

View File

@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams, SglArgument
from sglang.lang.ir import SglArgument, SglSamplingParams
from sglang.utils import encode_image_base64, find_printable_text, http_request

View File

@@ -6,10 +6,10 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SglSamplingParams,
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
@@ -137,7 +137,6 @@ class CompiledFunction:
):
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
kwargs.update(self.function.bind_arguments)
default_sampling_para = SglSamplingParams(
@@ -182,9 +181,6 @@ class CompiledFunction:
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
# Extract prefix by tracing and cache it
if len(batch_kwargs) > 1:

View File

@@ -12,7 +12,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
@@ -89,7 +88,7 @@ def run_program_batch(
for arguments in batch_arguments:
rets.append(
run_program(
program, backend, (), arguments, default_sampling_para, False, False
program, backend, (), arguments, default_sampling_para, False, True
)
)
else:
@@ -108,7 +107,7 @@ def run_program_batch(
arguments,
default_sampling_para,
False,
False,
True,
)
)
if progress_bar:
@@ -478,7 +477,7 @@ class StreamExecutor:
"top_k",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
"ignore_eos",
"dtype",
"regex",
]:

View File

@@ -4,10 +4,10 @@ import logging
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

View File

@@ -12,7 +12,7 @@ class ServerArgs:
load_format: str = "auto"
tokenizer_mode: str = "auto"
trust_remote_code: bool = True
mem_fraction_static: float = 0.91
mem_fraction_static: Optional[float] = None
tp_size: int = 1
model_mode: List[str] = ()
schedule_heuristic: str = "lpm"
@@ -24,8 +24,11 @@ class ServerArgs:
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.tp_size > 1:
self.mem_fraction_static = 0.8
if self.mem_fraction_static is None:
if self.tp_size > 1:
self.mem_fraction_static = 0.8
else:
self.mem_fraction_static = 0.9
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):

View File

@@ -174,7 +174,7 @@ def test_tool_use():
def tool_use(s, lhs, rhs):
s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
s += "Question: What is the product of " + lhs + " and " + rhs + "?\n"
s += "Question: What is the product of " + str(lhs) + " and " + str(rhs) + "?\n"
s += (
"Answer: The answer is calculate("
+ sgl.gen("expression", stop=")")

View File

@@ -1,3 +1,6 @@
# Build the package and upload the result with twine.
# Copy README.md and LICENSE from the parent directory into the current one.
cp ../README.md ../LICENSE .
# Remove any previous build output so only fresh artifacts end up in dist/.
rm -rf dist
# Build the distributions (python3 -m build).
python3 -m build
# Upload every file in dist/ with twine.
python3 -m twine upload dist/*
# Remove the copies of README.md and LICENSE made by the first step.
rm -rf README.md LICENSE