Support Agglomerative clustering. (#1384)
We use the open-source implementation from https://github.com/cdalitz/hclust-cpp
This commit is contained in:
69
sherpa-onnx/csrc/fast-clustering-test.cc
Normal file
69
sherpa-onnx/csrc/fast-clustering-test.cc
Normal file
@@ -0,0 +1,69 @@
|
||||
// sherpa-onnx/csrc/fast-clustering-test.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/fast-clustering.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(FastClustering, TestTwoClusters) {
|
||||
std::vector<float> features = {
|
||||
// point 0
|
||||
0.1,
|
||||
0.1,
|
||||
// point 2
|
||||
0.4,
|
||||
-0.5,
|
||||
// point 3
|
||||
0.6,
|
||||
-0.7,
|
||||
// point 1
|
||||
0.2,
|
||||
0.3,
|
||||
};
|
||||
|
||||
FastClusteringConfig config;
|
||||
config.num_clusters = 2;
|
||||
|
||||
FastClustering clustering(config);
|
||||
auto labels = clustering.Cluster(features.data(), 4, 2);
|
||||
int32_t k = 0;
|
||||
for (auto i : labels) {
|
||||
std::cout << "point " << k << ": label " << i << "\n";
|
||||
++k;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FastClustering, TestClusteringWithThreshold) {
|
||||
std::vector<float> features = {
|
||||
// point 0
|
||||
0.1,
|
||||
0.1,
|
||||
// point 2
|
||||
0.4,
|
||||
-0.5,
|
||||
// point 3
|
||||
0.6,
|
||||
-0.7,
|
||||
// point 1
|
||||
0.2,
|
||||
0.3,
|
||||
};
|
||||
|
||||
FastClusteringConfig config;
|
||||
config.threshold = 0.5;
|
||||
|
||||
FastClustering clustering(config);
|
||||
auto labels = clustering.Cluster(features.data(), 4, 2);
|
||||
int32_t k = 0;
|
||||
for (auto i : labels) {
|
||||
std::cout << "point " << k << ": label " << i << "\n";
|
||||
++k;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
Reference in New Issue
Block a user