diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index 0471e37d9..269d32eaa 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -359,7 +359,6 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py
index 4f34e625e..a9dbe8275 100644
--- a/python/sglang/srt/models/llama_eagle.py
+++ b/python/sglang/srt/models/llama_eagle.py
@@ -106,7 +106,6 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
diff --git a/python/sglang/srt/models/qwen2_eagle.py b/python/sglang/srt/models/qwen2_eagle.py
index 01069ef48..12a4e6b3f 100644
--- a/python/sglang/srt/models/qwen2_eagle.py
+++ b/python/sglang/srt/models/qwen2_eagle.py
@@ -107,7 +107,6 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index ab472cc7a..94803c8e3 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -159,45 +159,6 @@ def call_generate_guidance(
     return rets if n > 1 else rets[0]
 
 
-async def call_generate_lmql(
-    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
-):
-    assert model is not None
-    import lmql
-
-    if stop != None:
-
-        @lmql.query(model=model)
-        async def program(question, max_tokens, stop):
-            '''lmql
-            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
-            return ANSWER
-            '''
-
-    else:
-
-        @lmql.query(model=model)
-        async def program(question, max_tokens):
-            '''lmql
-            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
-            return ANSWER
-            '''
-
-    tasks = [
-        program(
-            question=prompt,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            stop=stop,
-            max_len=max_len,
-            **kwargs,
-        )
-        for _ in range(n)
-    ]
-    rets = await asyncio.gather(*tasks)
-    return rets if n > 1 else rets[0]
-
-
 def call_select_lightllm(context, choices, url=None):
     assert url is not None
 
@@ -247,23 +208,6 @@ def call_select_guidance(context, choices, model=None):
     return choices.index(out["answer"])
 
 
-async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
-    assert model is not None
-    import lmql
-
-    @lmql.query(model=model)
-    async def program(ctx, choices):
-        '''lmql
-        """{ctx}[ANSWER]""" where ANSWER in set(choices)
-        return ANSWER
-        '''
-
-    answer = await program(
-        ctx=context, choices=choices, temperature=temperature, max_len=max_len
-    )
-    return choices.index(answer)
-
-
 def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
@@ -278,7 +222,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
             "lightllm",
             "gserver",
             "guidance",
-            "lmql",
             "srt-raw",
             "llama.cpp",
         ],
@@ -295,7 +238,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
         "vllm": 21000,
         "outlines": 21000,
         "lightllm": 22000,
-        "lmql": 23000,
         "srt-raw": 30000,
         "gserver": 9988,
     }
@@ -343,11 +285,6 @@ def _get_call_generate(args: argparse.Namespace):
         call_generate = partial(call_generate_guidance, model=model)
         call_generate("Hello,", 1.0, 8, ".")
         return call_generate
-    elif args.backend == "lmql":
-        import lmql
-
-        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
-        return partial(call_generate_lmql, model=model)
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
 
@@ -365,12 +302,6 @@ def _get_call_select(args: argparse.Namespace):
 
         call_select("Hello,", ["world", "earth"])
         return call_select
-
-    elif args.backend == "lmql":
-        import lmql
-
-        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
-        return partial(call_select_lmql, model=model)
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
 