Add Python API for clustering (#1385)
This commit is contained in:
@@ -26,11 +26,13 @@ void FastClusteringConfig::Register(ParseOptions *po) {
|
||||
|
||||
p.Register("num-clusters", &num_clusters,
|
||||
"Number of cluster. If greater than 0, then --cluster-thresold is "
|
||||
"ignored");
|
||||
"ignored. Please provide it if you know the actual number of "
|
||||
"clusters in advance.");
|
||||
|
||||
p.Register("cluster-threshold", &threshold,
|
||||
"If --num-clusters is not specified, then it specifies the "
|
||||
"distance threshold for clustering.");
|
||||
"distance threshold for clustering. smaller value -> more "
|
||||
"clusters. larger value -> fewer clusters");
|
||||
}
|
||||
|
||||
bool FastClusteringConfig::Validate() const {
|
||||
|
||||
@@ -12,12 +12,23 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct FastClusteringConfig {
|
||||
// If greater than 0, then threshold is ignored
|
||||
// If greater than 0, then threshold is ignored.
|
||||
//
|
||||
// We strongly recommend that you set it if you know the number of clusters
|
||||
// in advance
|
||||
int32_t num_clusters = -1;
|
||||
|
||||
// distance threshold
|
||||
// distance threshold.
|
||||
//
|
||||
// The lower, the more clusters it will generate.
|
||||
// The higher, the fewer clusters it will generate.
|
||||
float threshold = 0.5;
|
||||
|
||||
FastClusteringConfig() = default;
|
||||
|
||||
FastClusteringConfig(int32_t num_clusters, float threshold)
|
||||
: num_clusters(num_clusters), threshold(threshold) {}
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
@@ -16,7 +16,7 @@ class FastClustering::Impl {
|
||||
explicit Impl(const FastClusteringConfig &config) : config_(config) {}
|
||||
|
||||
std::vector<int32_t> Cluster(float *features, int32_t num_rows,
|
||||
int32_t num_cols) {
|
||||
int32_t num_cols) const {
|
||||
if (num_rows <= 0) {
|
||||
return {};
|
||||
}
|
||||
@@ -77,7 +77,7 @@ FastClustering::FastClustering(const FastClusteringConfig &config)
|
||||
FastClustering::~FastClustering() = default;
|
||||
|
||||
std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows,
|
||||
int32_t num_cols) {
|
||||
int32_t num_cols) const {
|
||||
return impl_->Cluster(features, num_rows, num_cols);
|
||||
}
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -32,7 +32,7 @@ class FastClustering {
|
||||
* matrix.
|
||||
*/
|
||||
std::vector<int32_t> Cluster(float *features, int32_t num_rows,
|
||||
int32_t num_cols);
|
||||
int32_t num_cols) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
|
||||
Reference in New Issue
Block a user