Update benchmark scripts (#8)
This commit is contained in:
@@ -351,7 +351,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=False
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
@@ -93,7 +93,8 @@ class ServerArgs:
|
||||
type=str,
|
||||
default=[],
|
||||
nargs="+",
|
||||
help="Model mode: [flashinfer, no-cache, aggressive-new-fill]",
|
||||
choices=["flashinfer", "no-cache"],
|
||||
help="Model mode: [flashinfer, no-cache]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--schedule-heuristic",
|
||||
|
||||
@@ -99,7 +99,7 @@ def call_select_vllm(context, choices, url):
|
||||
}
|
||||
res = requests.post(url, json=data)
|
||||
assert res.status_code == 200
|
||||
scores.append(res.json()["prompt_score"])
|
||||
scores.append(res.json().get("prompt_score", 0))
|
||||
return np.argmax(scores)
|
||||
|
||||
"""
|
||||
@@ -112,7 +112,7 @@ def call_select_vllm(context, choices, url):
|
||||
|
||||
|
||||
def add_common_other_args_and_parse(parser):
|
||||
parser.add_argument("--parallel", type=int, default=96)
|
||||
parser.add_argument("--parallel", type=int, default=64)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=None)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user