From 3c1f5a92200e112a07d467771af879942d2dd440 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 17 Aug 2024 18:03:00 -0700
Subject: [PATCH] Fix duplicated imports in hf_transformers_utils.py (#1141)

---
 python/sglang/bench_serving.py             |  8 ++++----
 python/sglang/srt/hf_transformers_utils.py |  5 -----
 python/sglang/test/test_utils.py           | 10 +++++-----
 3 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 30a079e87..e2a99f9fd 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -222,7 +222,7 @@ async def async_request_openai_completions(
     return output
 
 
-async def async_request_ginfer(
+async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
@@ -268,7 +268,7 @@ ASYNC_REQUEST_FUNCS = {
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
-    "ginfer": async_request_ginfer,
+    "gserver": async_request_gserver,
 }
 
 
@@ -790,7 +790,7 @@ def run_benchmark(args_: argparse.Namespace):
         "lmdeploy": 23333,
         "vllm": 8000,
         "trt": 8000,
-        "ginfer": 9988,
+        "gserver": 9988,
     }.get(args.backend, 30000)
 
     api_url = (
@@ -813,7 +813,7 @@ def run_benchmark(args_: argparse.Namespace):
         if args.model is None:
             print("Please provide a model using `--model` when using `trt` backend.")
             sys.exit(1)
-    elif args.backend == "ginfer":
+    elif args.backend == "gserver":
         api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
         args.model = args.model or "default"
 
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 76a8c9043..fb198fd73 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -44,11 +44,6 @@ except ImportError:
 
 from sglang.srt.utils import is_multimodal_model
 
-_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-    ChatGLMConfig.model_type: ChatGLMConfig,
-    DbrxConfig.model_type: DbrxConfig,
-}
-
 
 def download_from_hf(model_path: str):
     if os.path.exists(model_path):
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 72fd54efe..9f6aa68ab 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -112,7 +112,7 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred
 
 
-def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
     raise NotImplementedError()
 
 
@@ -256,7 +256,7 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
             "vllm",
             "outlines",
             "lightllm",
-            "ginfer",
+            "gserver",
             "guidance",
             "lmql",
             "srt-raw",
@@ -277,7 +277,7 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
-        "ginfer": 9988,
+        "gserver": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args
@@ -313,8 +313,8 @@ def _get_call_generate(args: argparse.Namespace):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
-    elif args.backend == "ginfer":
-        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
+    elif args.backend == "gserver":
+        return partial(call_generate_gserver, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
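
Note for reviewers: every renamed "gserver" entry above feeds the same dispatch
pattern, in which a backend name selects a call_generate_* handler and
_get_call_generate binds the backend-specific URL once with functools.partial.
Below is a minimal, self-contained sketch of that pattern, not code from the
repository: the stub body of call_generate_vllm and the host/port values are
illustrative assumptions, while the NotImplementedError stub mirrors the
patched call_generate_gserver.

    # Sketch of the backend-dispatch pattern used in test_utils.py.
    # Stub bodies and host/port values are illustrative only.
    from functools import partial


    def call_generate_vllm(prompt, temperature, max_tokens, stop=None, url=None):
        # The real helper POSTs to `url`; stubbed here for illustration.
        return f"[vllm@{url}] {prompt!r} (T={temperature}, max={max_tokens})"


    def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
        # Mirrors the patch: the gserver client is not implemented yet.
        raise NotImplementedError()


    def _get_call_generate(backend: str, host: str, port: int):
        # Bind the backend-specific URL once so callers pass only
        # generation arguments, matching the partial() calls in the diff.
        if backend == "vllm":
            return partial(call_generate_vllm, url=f"{host}:{port}/generate")
        elif backend == "gserver":
            return partial(call_generate_gserver, url=f"{host}:{port}")
        raise ValueError(f"Unknown backend: {backend}")


    if __name__ == "__main__":
        call_generate = _get_call_generate("vllm", "http://127.0.0.1", 8000)
        print(call_generate("Hello", temperature=0.0, max_tokens=16))

Binding the URL at lookup time keeps every caller backend-agnostic, which is
why the rename only had to touch the registry entries and default-port maps.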