Files
2025-09-03 12:06:44 +08:00

102 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import time
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
def parse_args():
p = argparse.ArgumentParser(
description="Encode sentences with SentenceTransformer and output embeddings & pairwise cosine similarity."
)
p.add_argument("--json", help="输入文件路径JSON形如 ['句子1','句子2', ...]")
p.add_argument("--results", help="输出文件路径JSON")
p.add_argument("--model", help="模型路径或模型名,如 BAAI/bge-large-zh-v1.5 或本地目录")
p.add_argument("--device", default=None,
help="设备cuda / cpu / npu默认自动检测优先 cuda其次 cpu也可显式传 npu")
p.add_argument("--batch-size", type=int, default=32, help="encode 批大小,默认 32")
p.add_argument("--no-normalize", action="store_true", help="不做 L2 归一化(默认会归一化)")
args, _ = p.parse_known_args()
return args
def auto_device(user_device: str | None) -> str:
if user_device:
if user_device == "cuda" and not torch.cuda.is_available():
if torch.npu.is_available():
return "npu"
return user_device
try:
if torch.cuda.is_available():
return "cuda"
if torch.npu.is_available():
return "npu"
except Exception:
pass
return "cpu"
def _load_sentences(path: Path) -> list[str]:
    """Load the input sentences from a UTF-8 JSON file.

    Args:
        path: File whose top-level JSON value must be an array.

    Returns:
        Every array element converted with ``str()``.

    Raises:
        ValueError: If the top-level JSON value is not an array.
    """
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("输入 JSON 必须是数组格式,如:['句子1', '句子2', ...]")
    return [str(x) for x in data]


def _cosine_similarity(embeddings: np.ndarray, *, assume_normalized: bool) -> np.ndarray:
    """Return the pairwise cosine-similarity matrix of the rows of a 2-D array.

    Args:
        embeddings: Array of shape ``(N, D)``; it is not modified.
        assume_normalized: True when the rows are already unit length, in which
            case the plain Gram matrix already equals the cosine matrix.

    Returns:
        Array of shape ``(N, N)``.
    """
    if assume_normalized:
        unit = embeddings
    else:
        # Normalise a *copy* so the caller keeps the raw embeddings; the
        # epsilon avoids division by zero for an all-zero row.
        unit = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)
    return unit @ unit.T


def main():
    """Encode sentences from ``--json`` and write embeddings + similarities to ``--results``.

    The output JSON holds: model path, resolved device, sentence count,
    embedding dimension (``None`` for empty input), total encode time, mean
    per-sentence encode time, the input sentences, the embeddings (``[N, D]``,
    L2-normalised unless ``--no-normalize`` is given) and the ``[N, N]``
    cosine-similarity matrix (always cosine, regardless of that flag).
    """
    args = parse_args()
    inp_path = Path(args.json)
    out_path = Path(args.results)
    model_path = args.model
    device = auto_device(args.device)
    normalize = not args.no_normalize

    sentences = _load_sentences(inp_path)

    model = SentenceTransformer(model_path, device=device)

    if sentences:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        # NOTE: this is the first encode call, so the figure may include
        # device warm-up.
        t0 = time.perf_counter()
        embeddings = model.encode(
            sentences,
            batch_size=args.batch_size,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
            device=device,
        )
        encode_time = time.perf_counter() - t0
        # Similarity is computed on a normalised view; the raw embeddings are
        # kept intact so --no-normalize actually yields un-normalised output.
        similarity = _cosine_similarity(embeddings, assume_normalized=normalize)
        dim = int(embeddings.shape[1])
        embeddings_out = embeddings.tolist()
        similarity_out = similarity.tolist()
    else:
        # Nothing to encode: emit empty matrices instead of running numpy on a
        # degenerate 1-D result (which gave a scalar or an AxisError).
        encode_time = 0.0
        dim = None
        embeddings_out = []
        similarity_out = []

    avg_latency = encode_time / len(sentences) if sentences else 0.0

    result = {
        "model_path": model_path,
        "device": device,
        "count": len(sentences),
        "dim": dim,
        "total_elapsed_seconds": round(float(encode_time), 6),
        "avg_latency": avg_latency,
        "sentences": sentences,
        "embeddings": embeddings_out,  # [N, D]
        "similarity": similarity_out,  # [N, N]
    }

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"✅ Done. Saved to: {out_path}")


if __name__ == "__main__":
    main()