remove cache configs in model definitions (#4031)
This commit is contained in:
@@ -359,7 +359,6 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
cache_config=None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|||||||
@@ -106,7 +106,6 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
|
|||||||
self,
|
self,
|
||||||
config: LlamaConfig,
|
config: LlamaConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
cache_config=None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|||||||
@@ -107,7 +107,6 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
|
|||||||
self,
|
self,
|
||||||
config: Qwen2Config,
|
config: Qwen2Config,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
cache_config=None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|||||||
@@ -159,45 +159,6 @@ def call_generate_guidance(
|
|||||||
return rets if n > 1 else rets[0]
|
return rets if n > 1 else rets[0]
|
||||||
|
|
||||||
|
|
||||||
async def call_generate_lmql(
|
|
||||||
prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
|
|
||||||
):
|
|
||||||
assert model is not None
|
|
||||||
import lmql
|
|
||||||
|
|
||||||
if stop != None:
|
|
||||||
|
|
||||||
@lmql.query(model=model)
|
|
||||||
async def program(question, max_tokens, stop):
|
|
||||||
'''lmql
|
|
||||||
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
|
|
||||||
return ANSWER
|
|
||||||
'''
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
@lmql.query(model=model)
|
|
||||||
async def program(question, max_tokens):
|
|
||||||
'''lmql
|
|
||||||
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
|
|
||||||
return ANSWER
|
|
||||||
'''
|
|
||||||
|
|
||||||
tasks = [
|
|
||||||
program(
|
|
||||||
question=prompt,
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
stop=stop,
|
|
||||||
max_len=max_len,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
for _ in range(n)
|
|
||||||
]
|
|
||||||
rets = await asyncio.gather(*tasks)
|
|
||||||
return rets if n > 1 else rets[0]
|
|
||||||
|
|
||||||
|
|
||||||
def call_select_lightllm(context, choices, url=None):
|
def call_select_lightllm(context, choices, url=None):
|
||||||
assert url is not None
|
assert url is not None
|
||||||
|
|
||||||
@@ -247,23 +208,6 @@ def call_select_guidance(context, choices, model=None):
|
|||||||
return choices.index(out["answer"])
|
return choices.index(out["answer"])
|
||||||
|
|
||||||
|
|
||||||
async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
|
|
||||||
assert model is not None
|
|
||||||
import lmql
|
|
||||||
|
|
||||||
@lmql.query(model=model)
|
|
||||||
async def program(ctx, choices):
|
|
||||||
'''lmql
|
|
||||||
"""{ctx}[ANSWER]""" where ANSWER in set(choices)
|
|
||||||
return ANSWER
|
|
||||||
'''
|
|
||||||
|
|
||||||
answer = await program(
|
|
||||||
ctx=context, choices=choices, temperature=temperature, max_len=max_len
|
|
||||||
)
|
|
||||||
return choices.index(answer)
|
|
||||||
|
|
||||||
|
|
||||||
def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
|
def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument("--parallel", type=int, default=64)
|
parser.add_argument("--parallel", type=int, default=64)
|
||||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
@@ -278,7 +222,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
|
|||||||
"lightllm",
|
"lightllm",
|
||||||
"gserver",
|
"gserver",
|
||||||
"guidance",
|
"guidance",
|
||||||
"lmql",
|
|
||||||
"srt-raw",
|
"srt-raw",
|
||||||
"llama.cpp",
|
"llama.cpp",
|
||||||
],
|
],
|
||||||
@@ -295,7 +238,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
|
|||||||
"vllm": 21000,
|
"vllm": 21000,
|
||||||
"outlines": 21000,
|
"outlines": 21000,
|
||||||
"lightllm": 22000,
|
"lightllm": 22000,
|
||||||
"lmql": 23000,
|
|
||||||
"srt-raw": 30000,
|
"srt-raw": 30000,
|
||||||
"gserver": 9988,
|
"gserver": 9988,
|
||||||
}
|
}
|
||||||
@@ -343,11 +285,6 @@ def _get_call_generate(args: argparse.Namespace):
|
|||||||
call_generate = partial(call_generate_guidance, model=model)
|
call_generate = partial(call_generate_guidance, model=model)
|
||||||
call_generate("Hello,", 1.0, 8, ".")
|
call_generate("Hello,", 1.0, 8, ".")
|
||||||
return call_generate
|
return call_generate
|
||||||
elif args.backend == "lmql":
|
|
||||||
import lmql
|
|
||||||
|
|
||||||
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
|
|
||||||
return partial(call_generate_lmql, model=model)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid backend: {args.backend}")
|
raise ValueError(f"Invalid backend: {args.backend}")
|
||||||
|
|
||||||
@@ -365,12 +302,6 @@ def _get_call_select(args: argparse.Namespace):
|
|||||||
|
|
||||||
call_select("Hello,", ["world", "earth"])
|
call_select("Hello,", ["world", "earth"])
|
||||||
return call_select
|
return call_select
|
||||||
|
|
||||||
elif args.backend == "lmql":
|
|
||||||
import lmql
|
|
||||||
|
|
||||||
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
|
|
||||||
return partial(call_select_lmql, model=model)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid backend: {args.backend}")
|
raise ValueError(f"Invalid backend: {args.backend}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user