Files

102 lines
3.4 KiB
Python
Raw Permalink Normal View History

2025-09-03 10:43:22 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import time
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
2025-09-03 11:56:13 +08:00
import torch
2025-09-03 10:43:22 +08:00
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the embedding script.

    ``--json``, ``--results`` and ``--model`` are mandatory. Leaving any of
    them out makes argparse print a usage error and exit with status 2,
    instead of letting a ``None`` value reach the code that builds paths and
    loads the model.

    Unknown options are tolerated (``parse_known_args``), so the script can be
    launched by a wrapper that passes additional flags.

    Returns:
        Namespace with the attributes ``json``, ``results``, ``model``,
        ``device``, ``batch_size`` and ``no_normalize``.
    """
    p = argparse.ArgumentParser(
        description="Encode sentences with SentenceTransformer and output embeddings & pairwise cosine similarity."
    )
    p.add_argument("--json", required=True,
                   help="输入文件路径JSON形如 ['句子1','句子2', ...]")
    p.add_argument("--results", required=True, help="输出文件路径JSON")
    p.add_argument("--model", required=True,
                   help="模型路径或模型名,如 BAAI/bge-large-zh-v1.5 或本地目录")
    p.add_argument("--device", default=None,
                   help="设备cuda / cpu / npu默认自动检测优先 cuda其次 cpu也可显式传 npu")
    p.add_argument("--batch-size", type=int, default=32, help="encode 批大小,默认 32")
    p.add_argument("--no-normalize", action="store_true", help="不做 L2 归一化(默认会归一化)")
    # Extra, unrecognised arguments are deliberately ignored.
    args, _ = p.parse_known_args()
    return args
2025-09-03 10:43:22 +08:00
def auto_device(user_device: str | None) -> str:
if user_device:
2025-09-03 11:56:13 +08:00
if user_device == "cuda" and not torch.cuda.is_available():
if torch.npu.is_available():
return "npu"
2025-09-03 10:43:22 +08:00
return user_device
try:
if torch.cuda.is_available():
return "cuda"
2025-09-03 11:56:13 +08:00
if torch.npu.is_available():
return "npu"
2025-09-03 10:43:22 +08:00
except Exception:
pass
return "cpu"
def _pairwise_cosine(vectors, already_normalized):
    """Return the N x N cosine-similarity matrix for an (N, D) array."""
    if already_normalized:
        unit = vectors
    else:
        # The epsilon keeps all-zero rows from dividing by zero.
        norms = np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-12
        unit = vectors / norms
    # For unit-length rows the dot product equals the cosine similarity.
    return unit @ unit.T


def main():
    """Encode the input sentences and write embeddings plus similarities.

    Reads a JSON array from ``--json``, encodes every item (converted to
    ``str``) with the SentenceTransformer model given by ``--model``, and
    writes a JSON report to ``--results`` containing the model path, device,
    sentence count, embedding dimension, timing figures, the sentences, their
    embeddings and the pairwise cosine-similarity matrix.

    With ``--no-normalize`` the stored embeddings are the raw model output;
    normalization is applied only to a temporary copy used for the
    similarity matrix.

    Raises:
        SystemExit: If ``--json``, ``--results`` or ``--model`` is missing.
        ValueError: If the input JSON is not an array.
    """
    args = parse_args()

    missing = [
        flag
        for flag, value in (
            ("--json", args.json),
            ("--results", args.results),
            ("--model", args.model),
        )
        if not value
    ]
    if missing:
        raise SystemExit(f"missing required argument(s): {', '.join(missing)}")

    inp_path = Path(args.json)
    out_path = Path(args.results)
    model_path = args.model
    device = auto_device(args.device)
    normalize = not args.no_normalize

    # Read the input.
    with inp_path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("输入 JSON 必须是数组格式,如:['句子1', '句子2', ...]")
    sentences = [str(x) for x in data]

    # Load the model.
    model = SentenceTransformer(model_path, device=device)

    # Encode and time the call.
    if sentences:
        # perf_counter is monotonic, so it is the right clock for intervals.
        t0 = time.perf_counter()
        embeddings = model.encode(
            sentences,
            batch_size=args.batch_size,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
            device=device,
        )
        encode_time = time.perf_counter() - t0
    else:
        # Nothing to encode: keep a well-formed empty 2-D array so that the
        # report contains empty lists instead of failing or emitting scalars.
        embeddings = np.zeros((0, 0), dtype=np.float32)
        encode_time = 0.0

    # Pairwise cosine similarity; `embeddings` itself is left untouched so
    # that --no-normalize really yields raw vectors in the output.
    similarity = _pairwise_cosine(embeddings, already_normalized=normalize)

    avg_latency = encode_time / len(sentences) if sentences else 0

    # Assemble the report.
    has_dim = bool(sentences) and embeddings.ndim == 2
    result = {
        "model_path": model_path,
        "device": device,
        "count": len(sentences),
        "dim": int(embeddings.shape[1]) if has_dim else None,
        "total_elapsed_seconds": round(float(encode_time), 6),
        "avg_latency": avg_latency,
        "sentences": sentences,
        "embeddings": embeddings.tolist(),  # [N, D]
        "similarity": similarity.tolist(),  # [N, N]
    }

    # Save.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"✅ Done. Saved to: {out_path}")
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()