From 66156a20fa0dd0c31504f79dc408b3fd11823e99 Mon Sep 17 00:00:00 2001 From: ZHANG Hao Date: Wed, 3 Sep 2025 11:09:51 +0800 Subject: [PATCH] change arg name --- main.py | 15 ++++++++------- test.sh | 3 +-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 40e0eab..b0cc860 100644 --- a/main.py +++ b/main.py @@ -12,14 +12,15 @@ def parse_args(): p = argparse.ArgumentParser( description="Encode sentences with SentenceTransformer and output embeddings & pairwise cosine similarity." ) - p.add_argument("input_json", help="输入文件路径(JSON,形如 ['句子1','句子2', ...])") - p.add_argument("output_json", help="输出文件路径(JSON)") - p.add_argument("model_path", help="模型路径或模型名,如 BAAI/bge-large-zh-v1.5 或本地目录") + p.add_argument("--json", help="输入文件路径(JSON,形如 ['句子1','句子2', ...])") + p.add_argument("--results", help="输出文件路径(JSON)") + p.add_argument("--model", help="模型路径或模型名,如 BAAI/bge-large-zh-v1.5 或本地目录") p.add_argument("--device", default=None, help="设备:cuda / cpu / npu;默认自动检测(优先 cuda,其次 cpu;也可显式传 npu)") p.add_argument("--batch-size", type=int, default=32, help="encode 批大小,默认 32") p.add_argument("--no-normalize", action="store_true", help="不做 L2 归一化(默认会归一化)") - return p.parse_args() + args, _ = p.parse_known_args() + return args def auto_device(user_device: str | None) -> str: if user_device: @@ -34,9 +35,9 @@ def auto_device(user_device: str | None) -> str: def main(): args = parse_args() - inp_path = Path(args.input_json) - out_path = Path(args.output_json) - model_path = args.model_path + inp_path = Path(args.json) + out_path = Path(args.results) + model_path = args.model device = auto_device(args.device) normalize = not args.no_normalize diff --git a/test.sh b/test.sh index 772b756..76be5d0 100755 --- a/test.sh +++ b/test.sh @@ -1,2 +1 @@ -# python main.py dataset.json output.json /mnt/contest_ceph/zhanghao/models/BAAI/bge-large-zh-v1.5 --device npu -python main.py dataset.json output.json /mnt/contest_ceph/zhanghao/models/BAAI/bge-large-zh-v1.5 --device cuda +python main.py --json dataset.json --results output.json --model /mnt/contest_ceph/zhanghao/models/BAAI/bge-large-zh-v1.5 --device cuda