change arg name
This commit is contained in:
15
main.py
15
main.py
@@ -12,14 +12,15 @@ def parse_args():
|
|||||||
p = argparse.ArgumentParser(
|
p = argparse.ArgumentParser(
|
||||||
description="Encode sentences with SentenceTransformer and output embeddings & pairwise cosine similarity."
|
description="Encode sentences with SentenceTransformer and output embeddings & pairwise cosine similarity."
|
||||||
)
|
)
|
||||||
p.add_argument("input_json", help="输入文件路径(JSON,形如 ['句子1','句子2', ...])")
|
p.add_argument("--json", help="输入文件路径(JSON,形如 ['句子1','句子2', ...])")
|
||||||
p.add_argument("output_json", help="输出文件路径(JSON)")
|
p.add_argument("--results", help="输出文件路径(JSON)")
|
||||||
p.add_argument("model_path", help="模型路径或模型名,如 BAAI/bge-large-zh-v1.5 或本地目录")
|
p.add_argument("--model", help="模型路径或模型名,如 BAAI/bge-large-zh-v1.5 或本地目录")
|
||||||
p.add_argument("--device", default=None,
|
p.add_argument("--device", default=None,
|
||||||
help="设备:cuda / cpu / npu;默认自动检测(优先 cuda,其次 cpu;也可显式传 npu)")
|
help="设备:cuda / cpu / npu;默认自动检测(优先 cuda,其次 cpu;也可显式传 npu)")
|
||||||
p.add_argument("--batch-size", type=int, default=32, help="encode 批大小,默认 32")
|
p.add_argument("--batch-size", type=int, default=32, help="encode 批大小,默认 32")
|
||||||
p.add_argument("--no-normalize", action="store_true", help="不做 L2 归一化(默认会归一化)")
|
p.add_argument("--no-normalize", action="store_true", help="不做 L2 归一化(默认会归一化)")
|
||||||
return p.parse_args()
|
args, _ = p.parse_known_args()
|
||||||
|
return args
|
||||||
|
|
||||||
def auto_device(user_device: str | None) -> str:
|
def auto_device(user_device: str | None) -> str:
|
||||||
if user_device:
|
if user_device:
|
||||||
@@ -34,9 +35,9 @@ def auto_device(user_device: str | None) -> str:
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
inp_path = Path(args.input_json)
|
inp_path = Path(args.json)
|
||||||
out_path = Path(args.output_json)
|
out_path = Path(args.results)
|
||||||
model_path = args.model_path
|
model_path = args.model
|
||||||
device = auto_device(args.device)
|
device = auto_device(args.device)
|
||||||
normalize = not args.no_normalize
|
normalize = not args.no_normalize
|
||||||
|
|
||||||
|
|||||||
3
test.sh
3
test.sh
@@ -1,2 +1 @@
|
|||||||
# python main.py dataset.json output.json /mnt/contest_ceph/zhanghao/models/BAAI/bge-large-zh-v1.5 --device npu
|
python main.py --json dataset.json --results output.json --model /mnt/contest_ceph/zhanghao/models/BAAI/bge-large-zh-v1.5 --device cuda
|
||||||
python main.py dataset.json output.json /mnt/contest_ceph/zhanghao/models/BAAI/bge-large-zh-v1.5 --device cuda
|
|
||||||
|
|||||||
Reference in New Issue
Block a user