Add Python API for clustering (#1385)

This commit is contained in:
Fangjun Kuang
2024-09-30 11:33:15 +08:00
committed by GitHub
parent 70568c2df7
commit b965f14cf0
26 changed files with 326 additions and 15 deletions

View File

@@ -26,11 +26,13 @@ void FastClusteringConfig::Register(ParseOptions *po) {
p.Register("num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then --cluster-thresold is "
"ignored");
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");
p.Register("cluster-threshold", &threshold,
"If --num-clusters is not specified, then it specifies the "
"distance threshold for clustering.");
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
}
bool FastClusteringConfig::Validate() const {

View File

@@ -12,12 +12,23 @@
namespace sherpa_onnx {
struct FastClusteringConfig {
// If greater than 0, then threshold is ignored
// If greater than 0, then threshold is ignored.
//
// We strongly recommend that you set it if you know the number of clusters
// in advance
int32_t num_clusters = -1;
// distance threshold
// distance threshold.
//
// The lower, the more clusters it will generate.
// The higher, the fewer clusters it will generate.
float threshold = 0.5;
FastClusteringConfig() = default;
FastClusteringConfig(int32_t num_clusters, float threshold)
: num_clusters(num_clusters), threshold(threshold) {}
std::string ToString() const;
void Register(ParseOptions *po);

View File

@@ -16,7 +16,7 @@ class FastClustering::Impl {
explicit Impl(const FastClusteringConfig &config) : config_(config) {}
std::vector<int32_t> Cluster(float *features, int32_t num_rows,
int32_t num_cols) {
int32_t num_cols) const {
if (num_rows <= 0) {
return {};
}
@@ -77,7 +77,7 @@ FastClustering::FastClustering(const FastClusteringConfig &config)
FastClustering::~FastClustering() = default;
std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows,
int32_t num_cols) {
int32_t num_cols) const {
return impl_->Cluster(features, num_rows, num_cols);
}
} // namespace sherpa_onnx

View File

@@ -32,7 +32,7 @@ class FastClustering {
* matrix.
*/
std::vector<int32_t> Cluster(float *features, int32_t num_rows,
int32_t num_cols);
int32_t num_cols) const;
private:
class Impl;