#!/usr/bin/env python3 # -*- coding: utf-8 -*- import argparse import json import time from pathlib import Path import numpy as np from sentence_transformers import SentenceTransformer import torch def parse_args(): p = argparse.ArgumentParser( description="Encode sentences with SentenceTransformer and output embeddings & pairwise cosine similarity." ) p.add_argument("--json", help="输入文件路径(JSON,形如 ['句子1','句子2', ...])") p.add_argument("--results", help="输出文件路径(JSON)") p.add_argument("--model", help="模型路径或模型名,如 BAAI/bge-large-zh-v1.5 或本地目录") p.add_argument("--device", default=None, help="设备:cuda / cpu / npu;默认自动检测(优先 cuda,其次 cpu;也可显式传 npu)") p.add_argument("--batch-size", type=int, default=32, help="encode 批大小,默认 32") p.add_argument("--no-normalize", action="store_true", help="不做 L2 归一化(默认会归一化)") args, _ = p.parse_known_args() return args def auto_device(user_device: str | None) -> str: if user_device: if user_device == "cuda" and not torch.cuda.is_available(): if torch.npu.is_available(): return "npu" return user_device try: if torch.cuda.is_available(): return "cuda" if torch.npu.is_available(): return "npu" except Exception: pass return "cpu" def main(): args = parse_args() inp_path = Path(args.json) out_path = Path(args.results) model_path = args.model device = auto_device(args.device) normalize = not args.no_normalize # 读取输入 with inp_path.open("r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): raise ValueError("输入 JSON 必须是数组格式,如:['句子1', '句子2', ...]") sentences = [str(x) for x in data] # 加载模型 model = SentenceTransformer(model_path, device=device) # 编码并计时 t0 = time.time() embeddings = model.encode( sentences, batch_size=args.batch_size, normalize_embeddings=normalize, convert_to_numpy=True, device=device ) encode_time = time.time() - t0 # 若未归一化,则计算相似度前先做归一化(保证 similarity 为余弦相似度) if not normalize: norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12 embeddings = embeddings / norms # 两两相似度(余弦)——已归一化则点积即余弦 similarity = embeddings @ embeddings.T avg_latency = encode_time / len(sentences) if sentences else 0 # 组织输出 result = { "model_path": model_path, "device": device, "count": len(sentences), "dim": int(embeddings.shape[1]) if len(embeddings.shape) == 2 else None, "total_elapsed_seconds": round(float(encode_time), 6), "avg_latency": avg_latency, "sentences": sentences, "embeddings": embeddings.tolist(), # [N, D] "similarity": similarity.tolist() # [N, N] } # 保存 out_path.parent.mkdir(parents=True, exist_ok=True) with out_path.open("w", encoding="utf-8") as f: json.dump(result, f, ensure_ascii=False, indent=2) print(f"✅ Done. Saved to: {out_path}") if __name__ == "__main__": main()