decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder * Add keyword spotter runtime * Add binary * First version works * Minor fixes * update text2token * default values * Add jni for kws * add kws android project * Minor fixes * Remove unused interface * Minor fixes * Add workflow * handle extra info in texts * Minor fixes * Add more comments * Fix ci * fix cpp style * Add input box in android demo so that users can specify their keywords * Fix cpp style * Fix comments * Minor fixes * Minor fixes * minor fixes * Minor fixes * Minor fixes * Add CI * Fix code style * cpplint * Fix comments * Fix error
This commit is contained in:
@@ -19,6 +19,8 @@ set(sources
|
||||
features.cc
|
||||
file-utils.cc
|
||||
hypothesis.cc
|
||||
keyword-spotter-impl.cc
|
||||
keyword-spotter.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-ctc-fst-decoder.cc
|
||||
offline-ctc-greedy-search-decoder.cc
|
||||
@@ -87,6 +89,7 @@ set(sources
|
||||
stack.cc
|
||||
symbol-table.cc
|
||||
text-utils.cc
|
||||
transducer-keyword-decoder.cc
|
||||
transpose.cc
|
||||
unbind.cc
|
||||
utils.cc
|
||||
@@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
|
||||
endif()
|
||||
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
|
||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
|
||||
set(main_exes
|
||||
sherpa-onnx
|
||||
sherpa-onnx-keyword-spotter
|
||||
sherpa-onnx-offline
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-tts
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <string>
|
||||
@@ -15,27 +16,25 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
static void TestHelper(const std::map<std::string, float> &queries, float score,
|
||||
bool strict_mode) {
|
||||
std::vector<std::string> contexts_str(
|
||||
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||
std::vector<std::vector<int32_t>> contexts;
|
||||
std::vector<float> scores;
|
||||
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||
scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
|
||||
}
|
||||
auto context_graph = ContextGraph(contexts, 1);
|
||||
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
||||
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
||||
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
auto context_graph = ContextGraph(contexts, 1, scores);
|
||||
|
||||
for (const auto &iter : queries) {
|
||||
float total_scores = 0;
|
||||
auto state = context_graph.Root();
|
||||
for (auto q : iter.first) {
|
||||
auto res = context_graph.ForwardOneStep(state, q);
|
||||
total_scores += res.first;
|
||||
state = res.second;
|
||||
auto res = context_graph.ForwardOneStep(state, q, strict_mode);
|
||||
total_scores += std::get<0>(res);
|
||||
state = std::get<1>(res);
|
||||
}
|
||||
auto res = context_graph.Finalize(state);
|
||||
EXPECT_EQ(res.second->token, -1);
|
||||
@@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
||||
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
||||
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
TestHelper(queries, 0, true);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestBasicNonStrict) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3},
|
||||
{"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}};
|
||||
TestHelper(queries, 0, false);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestCustomize) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18},
|
||||
{"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5},
|
||||
{"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}};
|
||||
TestHelper(queries, 5, true);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestCustomizeNonStrict) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84},
|
||||
{"SHED", 10}, {"SHELF", 10}, {"HELL", 5},
|
||||
{"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}};
|
||||
TestHelper(queries, 5, false);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, Benchmark) {
|
||||
std::random_device rd;
|
||||
std::mt19937 mt(rd());
|
||||
|
||||
@@ -4,22 +4,59 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
void ContextGraph::Build(
|
||||
const std::vector<std::vector<int32_t>> &token_ids) const {
|
||||
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
const std::vector<float> &scores,
|
||||
const std::vector<std::string> &phrases,
|
||||
const std::vector<float> &ac_thresholds) const {
|
||||
if (!scores.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
|
||||
}
|
||||
if (!phrases.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
|
||||
}
|
||||
if (!ac_thresholds.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
|
||||
}
|
||||
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
||||
auto node = root_.get();
|
||||
float score = scores.empty() ? 0.0f : scores[i];
|
||||
score = score == 0.0f ? context_score_ : score;
|
||||
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
|
||||
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
|
||||
std::string phrase = phrases.empty() ? std::string() : phrases[i];
|
||||
|
||||
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
||||
int32_t token = token_ids[i][j];
|
||||
if (0 == node->next.count(token)) {
|
||||
bool is_end = j == token_ids[i].size() - 1;
|
||||
node->next[token] = std::make_unique<ContextState>(
|
||||
token, context_score_, node->node_score + context_score_,
|
||||
is_end ? node->node_score + context_score_ : 0, is_end);
|
||||
token, score, node->node_score + score,
|
||||
is_end ? node->node_score + score : 0, j + 1,
|
||||
is_end ? ac_threshold : 0.0f, is_end,
|
||||
is_end ? phrase : std::string());
|
||||
} else {
|
||||
float token_score = std::max(score, node->next[token]->token_score);
|
||||
node->next[token]->token_score = token_score;
|
||||
float node_score = node->node_score + token_score;
|
||||
node->next[token]->node_score = node_score;
|
||||
bool is_end =
|
||||
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
|
||||
node->next[token]->output_score = is_end ? node_score : 0.0f;
|
||||
node->next[token]->is_end = is_end;
|
||||
if (j == token_ids[i].size() - 1) {
|
||||
node->next[token]->phrase = phrase;
|
||||
node->next[token]->ac_threshold = ac_threshold;
|
||||
}
|
||||
}
|
||||
node = node->next[token].get();
|
||||
}
|
||||
@@ -27,8 +64,9 @@ void ContextGraph::Build(
|
||||
FillFailOutput();
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
const ContextState *state, int32_t token) const {
|
||||
std::tuple<float, const ContextState *, const ContextState *>
|
||||
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
|
||||
bool strict_mode /*= true*/) const {
|
||||
const ContextState *node;
|
||||
float score;
|
||||
if (1 == state->next.count(token)) {
|
||||
@@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
}
|
||||
score = node->node_score - state->node_score;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_CHECK(nullptr != node);
|
||||
return std::make_pair(score + node->output_score, node);
|
||||
|
||||
const ContextState *matched_node =
|
||||
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
|
||||
|
||||
if (!strict_mode && node->output_score != 0) {
|
||||
SHERPA_ONNX_CHECK(nullptr != matched_node);
|
||||
float output_score =
|
||||
node->is_end ? node->node_score
|
||||
: (node->output != nullptr ? node->output->node_score
|
||||
: node->node_score);
|
||||
return std::make_tuple(score + output_score - node->node_score, root_.get(),
|
||||
matched_node);
|
||||
}
|
||||
return std::make_tuple(score + node->output_score, node, matched_node);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
@@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
return std::make_pair(score, root_.get());
|
||||
}
|
||||
|
||||
std::pair<bool, const ContextState *> ContextGraph::IsMatched(
|
||||
const ContextState *state) const {
|
||||
bool status = false;
|
||||
const ContextState *node = nullptr;
|
||||
if (state->is_end) {
|
||||
status = true;
|
||||
node = state;
|
||||
} else {
|
||||
if (state->output != nullptr) {
|
||||
status = true;
|
||||
node = state->output;
|
||||
}
|
||||
}
|
||||
return std::make_pair(status, node);
|
||||
}
|
||||
|
||||
void ContextGraph::FillFailOutput() const {
|
||||
std::queue<const ContextState *> node_queue;
|
||||
for (auto &kv : root_->next) {
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -22,34 +24,55 @@ struct ContextState {
|
||||
float token_score;
|
||||
float node_score;
|
||||
float output_score;
|
||||
int32_t level;
|
||||
float ac_threshold;
|
||||
bool is_end;
|
||||
std::string phrase;
|
||||
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
||||
const ContextState *fail = nullptr;
|
||||
const ContextState *output = nullptr;
|
||||
|
||||
ContextState() = default;
|
||||
ContextState(int32_t token, float token_score, float node_score,
|
||||
float output_score, bool is_end)
|
||||
float output_score, int32_t level = 0, float ac_threshold = 0.0f,
|
||||
bool is_end = false, const std::string &phrase = {})
|
||||
: token(token),
|
||||
token_score(token_score),
|
||||
node_score(node_score),
|
||||
output_score(output_score),
|
||||
is_end(is_end) {}
|
||||
level(level),
|
||||
ac_threshold(ac_threshold),
|
||||
is_end(is_end),
|
||||
phrase(phrase) {}
|
||||
};
|
||||
|
||||
class ContextGraph {
|
||||
public:
|
||||
ContextGraph() = default;
|
||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
float context_score)
|
||||
: context_score_(context_score) {
|
||||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
|
||||
float context_score, float ac_threshold,
|
||||
const std::vector<float> &scores = {},
|
||||
const std::vector<std::string> &phrases = {},
|
||||
const std::vector<float> &ac_thresholds = {})
|
||||
: context_score_(context_score), ac_threshold_(ac_threshold) {
|
||||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
|
||||
root_->fail = root_.get();
|
||||
Build(token_ids);
|
||||
Build(token_ids, scores, phrases, ac_thresholds);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ForwardOneStep(
|
||||
const ContextState *state, int32_t token_id) const;
|
||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
float context_score, const std::vector<float> &scores = {},
|
||||
const std::vector<std::string> &phrases = {})
|
||||
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
|
||||
std::vector<float>()) {}
|
||||
|
||||
std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
|
||||
const ContextState *state, int32_t token_id,
|
||||
bool strict_mode = true) const;
|
||||
|
||||
std::pair<bool, const ContextState *> IsMatched(
|
||||
const ContextState *state) const;
|
||||
|
||||
std::pair<float, const ContextState *> Finalize(
|
||||
const ContextState *state) const;
|
||||
|
||||
@@ -57,8 +80,12 @@ class ContextGraph {
|
||||
|
||||
private:
|
||||
float context_score_;
|
||||
float ac_threshold_;
|
||||
std::unique_ptr<ContextState> root_;
|
||||
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
|
||||
void Build(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
const std::vector<float> &scores,
|
||||
const std::vector<std::string> &phrases,
|
||||
const std::vector<float> &ac_thresholds) const;
|
||||
void FillFailOutput() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -28,6 +28,10 @@ struct Hypothesis {
|
||||
// on which ys[i] is decoded.
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
// The acoustic probability for each token in ys.
|
||||
// Only used for keyword spotting task.
|
||||
std::vector<float> ys_probs;
|
||||
|
||||
// The total score of ys in log space.
|
||||
// It contains only acoustic scores
|
||||
double log_prob = 0;
|
||||
|
||||
33
sherpa-onnx/csrc/keyword-spotter-impl.cc
Normal file
33
sherpa-onnx/csrc/keyword-spotter-impl.cc
Normal file
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter-impl.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
|
||||
const KeywordSpotterConfig &config) {
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
return std::make_unique<KeywordSpotterTransducerImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a model");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
|
||||
AAssetManager *mgr, const KeywordSpotterConfig &config) {
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a model");
|
||||
exit(-1);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
48
sherpa-onnx/csrc/keyword-spotter-impl.h
Normal file
48
sherpa-onnx/csrc/keyword-spotter-impl.h
Normal file
@@ -0,0 +1,48 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter-impl.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class KeywordSpotterImpl {
|
||||
public:
|
||||
static std::unique_ptr<KeywordSpotterImpl> Create(
|
||||
const KeywordSpotterConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<KeywordSpotterImpl> Create(
|
||||
AAssetManager *mgr, const KeywordSpotterConfig &config);
|
||||
#endif
|
||||
|
||||
virtual ~KeywordSpotterImpl() = default;
|
||||
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::string &keywords) const = 0;
|
||||
|
||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||
|
||||
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||
|
||||
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||
323
sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Normal file
323
sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Normal file
@@ -0,0 +1,323 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static KeywordResult Convert(const TransducerKeywordResult &src,
|
||||
const SymbolTable &sym_table, float frame_shift_ms,
|
||||
int32_t subsampling_factor,
|
||||
int32_t frames_since_start) {
|
||||
KeywordResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
r.keyword = src.keyword;
|
||||
bool from_tokens = src.keyword.empty();
|
||||
|
||||
for (auto i : src.tokens) {
|
||||
auto sym = sym_table[i];
|
||||
if (from_tokens) {
|
||||
r.keyword.append(sym);
|
||||
}
|
||||
r.tokens.push_back(std::move(sym));
|
||||
}
|
||||
if (from_tokens && r.keyword.size()) {
|
||||
r.keyword = r.keyword.substr(1);
|
||||
}
|
||||
|
||||
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||
for (auto t : src.timestamps) {
|
||||
float time = frame_shift_s * t;
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
public:
|
||||
explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
InitKeywords();
|
||||
|
||||
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
|
||||
unk_id_);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
KeywordSpotterTransducerImpl(AAssetManager *mgr,
|
||||
const KeywordSpotterConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||
sym_(mgr, config.model_config.tokens) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
InitKeywords(mgr);
|
||||
|
||||
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
|
||||
unk_id_);
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph_);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::string &keywords) const override {
|
||||
auto kws = std::regex_replace(keywords, std::regex("/"), "\n");
|
||||
std::istringstream is(kws);
|
||||
|
||||
std::vector<std::vector<int32_t>> current_ids;
|
||||
std::vector<std::string> current_kws;
|
||||
std::vector<float> current_scores;
|
||||
std::vector<float> current_thresholds;
|
||||
|
||||
if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores,
|
||||
¤t_thresholds)) {
|
||||
SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t num_kws = current_ids.size();
|
||||
int32_t num_default_kws = keywords_id_.size();
|
||||
|
||||
current_ids.insert(current_ids.end(), keywords_id_.begin(),
|
||||
keywords_id_.end());
|
||||
|
||||
if (!current_kws.empty() && !keywords_.empty()) {
|
||||
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
|
||||
} else if (!current_kws.empty() && keywords_.empty()) {
|
||||
current_kws.insert(current_kws.end(), num_default_kws, std::string());
|
||||
} else if (current_kws.empty() && !keywords_.empty()) {
|
||||
current_kws.insert(current_kws.end(), num_kws, std::string());
|
||||
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
if (!current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else if (!current_scores.empty() && boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_default_kws,
|
||||
config_.keywords_score);
|
||||
} else if (current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_kws,
|
||||
config_.keywords_score);
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
if (!current_thresholds.empty() && !thresholds_.empty()) {
|
||||
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
|
||||
thresholds_.end());
|
||||
} else if (!current_thresholds.empty() && thresholds_.empty()) {
|
||||
current_thresholds.insert(current_thresholds.end(), num_default_kws,
|
||||
config_.keywords_threshold);
|
||||
} else if (current_thresholds.empty() && !thresholds_.empty()) {
|
||||
current_thresholds.insert(current_thresholds.end(), num_kws,
|
||||
config_.keywords_threshold);
|
||||
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
|
||||
thresholds_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
auto keywords_graph = std::make_shared<ContextGraph>(
|
||||
current_ids, config_.keywords_score, config_.keywords_threshold,
|
||||
current_scores, current_kws, current_thresholds);
|
||||
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
bool IsReady(OnlineStream *s) const override {
|
||||
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||
s->NumFramesReady();
|
||||
}
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
int32_t feature_dim = ss[0]->FeatureDim();
|
||||
|
||||
std::vector<TransducerKeywordResult> results(n);
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||
std::vector<int64_t> all_processed_frames(n);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr);
|
||||
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||
|
||||
// Question: should num_processed_frames include chunk_shift?
|
||||
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||
|
||||
std::copy(features.begin(), features.end(),
|
||||
features_vec.data() + i * chunk_size * feature_dim);
|
||||
|
||||
results[i] = std::move(ss[i]->GetKeywordResult());
|
||||
states_vec[i] = std::move(ss[i]->GetStates());
|
||||
all_processed_frames[i] = num_processed_frames;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||
features_vec.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
|
||||
std::array<int64_t, 1> processed_frames_shape{
|
||||
static_cast<int64_t>(all_processed_frames.size())};
|
||||
|
||||
Ort::Value processed_frames = Ort::Value::CreateTensor(
|
||||
memory_info, all_processed_frames.data(), all_processed_frames.size(),
|
||||
processed_frames_shape.data(), processed_frames_shape.size());
|
||||
|
||||
auto states = model_->StackStates(states_vec);
|
||||
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||
std::move(processed_frames));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), ss, &results);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(pair.second);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
ss[i]->SetKeywordResult(results[i]);
|
||||
ss[i]->SetStates(std::move(next_states[i]));
|
||||
}
|
||||
}
|
||||
|
||||
KeywordResult GetResult(OnlineStream *s) const override {
|
||||
TransducerKeywordResult decoder_result = s->GetKeywordResult(true);
|
||||
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetNumFramesSinceStart());
|
||||
}
|
||||
|
||||
private:
|
||||
void InitKeywords(std::istream &is) {
|
||||
if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_,
|
||||
&thresholds_)) {
|
||||
SHERPA_ONNX_LOGE("Encode keywords failed.");
|
||||
exit(-1);
|
||||
}
|
||||
keywords_graph_ = std::make_shared<ContextGraph>(
|
||||
keywords_id_, config_.keywords_score, config_.keywords_threshold,
|
||||
boost_scores_, keywords_, thresholds_);
|
||||
}
|
||||
|
||||
void InitKeywords() {
|
||||
// each line in keywords_file contains space-separated words
|
||||
|
||||
std::ifstream is(config_.keywords_file);
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
|
||||
config_.keywords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
InitKeywords(is);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
void InitKeywords(AAssetManager *mgr) {
|
||||
// each line in keywords_file contains space-separated words
|
||||
|
||||
auto buf = ReadFile(mgr, config_.keywords_file);
|
||||
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
|
||||
config_.keywords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
InitKeywords(is);
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
SHERPA_ONNX_CHECK_EQ(r.hyps.size(), 1);
|
||||
|
||||
SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr);
|
||||
r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root();
|
||||
|
||||
stream->SetKeywordResult(r);
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
|
||||
private:
|
||||
KeywordSpotterConfig config_;
|
||||
std::vector<std::vector<int32_t>> keywords_id_;
|
||||
std::vector<float> boost_scores_;
|
||||
std::vector<float> thresholds_;
|
||||
std::vector<std::string> keywords_;
|
||||
ContextGraphPtr keywords_graph_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<TransducerKeywordDecoder> decoder_;
|
||||
SymbolTable sym_;
|
||||
int32_t unk_id_ = -1;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||
152
sherpa-onnx/csrc/keyword-spotter.cc
Normal file
152
sherpa-onnx/csrc/keyword-spotter.cc
Normal file
@@ -0,0 +1,152 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string KeywordResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{";
|
||||
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
|
||||
os << "\"keyword\""
|
||||
<< ": ";
|
||||
os << "\"" << keyword << "\""
|
||||
<< ", ";
|
||||
|
||||
os << "\""
|
||||
<< "timestamps"
|
||||
<< "\""
|
||||
<< ": ";
|
||||
os << "[";
|
||||
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
os << "\""
|
||||
<< "tokens"
|
||||
<< "\""
|
||||
<< ":";
|
||||
os << "[";
|
||||
|
||||
sep = "";
|
||||
auto oldFlags = os.flags();
|
||||
for (const auto &t : tokens) {
|
||||
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
|
||||
os << sep << "\""
|
||||
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
|
||||
<< ">"
|
||||
<< "\"";
|
||||
os.flags(oldFlags);
|
||||
} else {
|
||||
os << sep << "\"" << t << "\"";
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
os << "}";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void KeywordSpotterConfig::Register(ParseOptions *po) {
|
||||
feat_config.Register(po);
|
||||
model_config.Register(po);
|
||||
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("num-trailing-blanks", &num_trailing_blanks,
|
||||
"The number of trailing blanks should have after the keyword.");
|
||||
po->Register("keywords-score", &keywords_score,
|
||||
"The bonus score for each token in context word/phrase.");
|
||||
po->Register("keywords-threshold", &keywords_threshold,
|
||||
"The acoustic threshold (probability) to trigger the keywords.");
|
||||
po->Register(
|
||||
"keywords-file", &keywords_file,
|
||||
"The file containing keywords, one word/phrase per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
}
|
||||
|
||||
bool KeywordSpotterConfig::Validate() const {
|
||||
if (keywords_file.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --keywords-file.");
|
||||
return false;
|
||||
}
|
||||
if (!std::ifstream(keywords_file.c_str()).good()) {
|
||||
SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
std::string KeywordSpotterConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "KeywordSpotterConfig(";
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "num_trailing_blanks=" << num_trailing_blanks << ", ";
|
||||
os << "keywords_score=" << keywords_score << ", ";
|
||||
os << "keywords_threshold=" << keywords_threshold << ", ";
|
||||
os << "keywords_file=\"" << keywords_file << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config)
|
||||
: impl_(KeywordSpotterImpl::Create(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
KeywordSpotter::KeywordSpotter(AAssetManager *mgr,
|
||||
const KeywordSpotterConfig &config)
|
||||
: impl_(KeywordSpotterImpl::Create(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
KeywordSpotter::~KeywordSpotter() = default;
|
||||
|
||||
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream(
|
||||
const std::string &keywords) const {
|
||||
return impl_->CreateStream(keywords);
|
||||
}
|
||||
|
||||
bool KeywordSpotter::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const {
|
||||
return impl_->GetResult(s);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
148
sherpa-onnx/csrc/keyword-spotter.h
Normal file
148
sherpa-onnx/csrc/keyword-spotter.h
Normal file
@@ -0,0 +1,148 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct KeywordResult {
|
||||
/// The triggered keyword.
|
||||
/// For English, it consists of space separated words.
|
||||
/// For Chinese, it consists of Chinese words without spaces.
|
||||
/// Example 1: "hello world"
|
||||
/// Example 2: "你好世界"
|
||||
std::string keyword;
|
||||
|
||||
/// Decoded results at the token level.
|
||||
/// For instance, for BPE-based models it consists of a list of BPE tokens.
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
/// timestamps.size() == tokens.size()
|
||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||
std::vector<float> timestamps;
|
||||
|
||||
/// Starting time of this segment.
|
||||
/// When an endpoint is detected, it will change
|
||||
float start_time = 0;
|
||||
|
||||
/** Return a json string.
|
||||
*
|
||||
* The returned string contains:
|
||||
* {
|
||||
* "keyword": "The triggered keyword",
|
||||
* "tokens": [x, x, x],
|
||||
* "timestamps": [x, x, x],
|
||||
* "start_time": x,
|
||||
* }
|
||||
*/
|
||||
std::string AsJsonString() const;
|
||||
};
|
||||
|
||||
struct KeywordSpotterConfig {
|
||||
FeatureExtractorConfig feat_config;
|
||||
OnlineModelConfig model_config;
|
||||
|
||||
int32_t max_active_paths = 4;
|
||||
|
||||
int32_t num_trailing_blanks = 1;
|
||||
|
||||
float keywords_score = 1.0;
|
||||
|
||||
float keywords_threshold = 0.25;
|
||||
|
||||
std::string keywords_file;
|
||||
|
||||
KeywordSpotterConfig() = default;
|
||||
|
||||
KeywordSpotterConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config,
|
||||
int32_t max_active_paths, int32_t num_trailing_blanks,
|
||||
float keywords_score, float keywords_threshold,
|
||||
const std::string &keywords_file)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
max_active_paths(max_active_paths),
|
||||
num_trailing_blanks(num_trailing_blanks),
|
||||
keywords_score(keywords_score),
|
||||
keywords_threshold(keywords_threshold),
|
||||
keywords_file(keywords_file) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class KeywordSpotterImpl;
|
||||
|
||||
class KeywordSpotter {
|
||||
public:
|
||||
explicit KeywordSpotter(const KeywordSpotterConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config);
|
||||
#endif
|
||||
|
||||
~KeywordSpotter();
|
||||
|
||||
/** Create a stream for decoding.
|
||||
*
|
||||
*/
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
/** Create a stream for decoding.
|
||||
*
|
||||
* @param The keywords for this string, it might contain several keywords,
|
||||
* the keywords are separated by "/". In each of the keywords, there
|
||||
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
|
||||
* For example, keywords I LOVE YOU and HELLO WORLD, looks like:
|
||||
*
|
||||
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
|
||||
*/
|
||||
std::unique_ptr<OnlineStream> CreateStream(const std::string &keywords) const;
|
||||
|
||||
/**
|
||||
* Return true if the given stream has enough frames for decoding.
|
||||
* Return false otherwise
|
||||
*/
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
/** Decode a single stream. */
|
||||
void DecodeStream(OnlineStream *s) const {
|
||||
OnlineStream *ss[1] = {s};
|
||||
DecodeStreams(ss, 1);
|
||||
}
|
||||
|
||||
/** Decode multiple streams in parallel
|
||||
*
|
||||
* @param ss Pointer array containing streams to be decoded.
|
||||
* @param n Number of streams in `ss`.
|
||||
*/
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const;
|
||||
|
||||
KeywordResult GetResult(OnlineStream *s) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<KeywordSpotterImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||
@@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
||||
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), View(&decoder_out));
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
@@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,25 @@ class OnlineStream::Impl {
|
||||
|
||||
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
||||
|
||||
void SetKeywordResult(const TransducerKeywordResult &r) {
|
||||
keyword_result_ = r;
|
||||
}
|
||||
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) {
|
||||
if (remove_duplicates) {
|
||||
if (!prev_keyword_result_.timestamps.empty() &&
|
||||
!keyword_result_.timestamps.empty() &&
|
||||
keyword_result_.timestamps[0] <=
|
||||
prev_keyword_result_.timestamps.back()) {
|
||||
return empty_keyword_result_;
|
||||
} else {
|
||||
prev_keyword_result_ = keyword_result_;
|
||||
}
|
||||
return keyword_result_;
|
||||
} else {
|
||||
return keyword_result_;
|
||||
}
|
||||
}
|
||||
|
||||
OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
|
||||
|
||||
void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
|
||||
@@ -93,6 +112,9 @@ class OnlineStream::Impl {
|
||||
int32_t start_frame_index_ = 0; // never reset
|
||||
int32_t segment_ = 0;
|
||||
OnlineTransducerDecoderResult result_;
|
||||
TransducerKeywordResult prev_keyword_result_;
|
||||
TransducerKeywordResult keyword_result_;
|
||||
TransducerKeywordResult empty_keyword_result_;
|
||||
OnlineCtcDecoderResult ctc_result_;
|
||||
std::vector<Ort::Value> states_; // states for transducer or ctc models
|
||||
std::vector<float> paraformer_feat_cache_;
|
||||
@@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) {
|
||||
impl_->SetKeywordResult(r);
|
||||
}
|
||||
|
||||
TransducerKeywordResult &OnlineStream::GetKeywordResult(
|
||||
bool remove_duplicates /*=false*/) {
|
||||
return impl_->GetKeywordResult(remove_duplicates);
|
||||
}
|
||||
|
||||
OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
|
||||
return impl_->GetCtcResult();
|
||||
}
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class TransducerKeywordResult;
|
||||
class OnlineStream {
|
||||
public:
|
||||
explicit OnlineStream(const FeatureExtractorConfig &config = {},
|
||||
@@ -76,6 +78,9 @@ class OnlineStream {
|
||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||
OnlineTransducerDecoderResult &GetResult();
|
||||
|
||||
void SetKeywordResult(const TransducerKeywordResult &r);
|
||||
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false);
|
||||
|
||||
void SetCtcResult(const OnlineCtcDecoderResult &r);
|
||||
OnlineCtcDecoderResult &GetCtcResult();
|
||||
|
||||
@@ -92,7 +97,7 @@ class OnlineStream {
|
||||
*/
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
// for streaming parformer
|
||||
// for streaming paraformer
|
||||
std::vector<float> &GetParaformerFeatCache();
|
||||
std::vector<float> &GetParaformerEncoderOutCache();
|
||||
std::vector<float> &GetParaformerAlphaCache();
|
||||
|
||||
@@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (encoder_out_shape[0] != result->size()) {
|
||||
fprintf(stderr,
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
@@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
cur_encoder_out =
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), View(&decoder_out));
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
@@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
}
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
|
||||
122
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
Normal file
122
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
Normal file
@@ -0,0 +1,122 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
typedef struct {
|
||||
std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
|
||||
std::string filename;
|
||||
} Stream;
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Usage:
|
||||
|
||||
(1) Streaming transducer
|
||||
|
||||
./bin/sherpa-onnx-keyword-spotter \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--provider=cpu \
|
||||
--num-threads=2 \
|
||||
--keywords-file=keywords.txt \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
Note: It supports decoding multiple files in batches
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for provider: cpu (default), cuda, coreml.
|
||||
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::KeywordSpotterConfig config;
|
||||
|
||||
config.Register(&po);
|
||||
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() < 1) {
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::KeywordSpotter keyword_spotter(config);
|
||||
|
||||
std::vector<Stream> ss;
|
||||
|
||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
||||
const std::string wav_filename = po.GetArg(i);
|
||||
int32_t sampling_rate = -1;
|
||||
|
||||
bool is_ok = false;
|
||||
const std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
auto s = keyword_spotter.CreateStream();
|
||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
|
||||
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
|
||||
// Note: We can call AcceptWaveform() multiple times.
|
||||
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
|
||||
// Call InputFinished() to indicate that no audio samples are available
|
||||
s->InputFinished();
|
||||
ss.push_back({std::move(s), wav_filename});
|
||||
}
|
||||
|
||||
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
|
||||
for (;;) {
|
||||
ready_streams.clear();
|
||||
for (auto &s : ss) {
|
||||
const auto p_ss = s.online_stream.get();
|
||||
if (keyword_spotter.IsReady(p_ss)) {
|
||||
ready_streams.push_back(p_ss);
|
||||
}
|
||||
std::ostringstream os;
|
||||
const auto r = keyword_spotter.GetResult(p_ss);
|
||||
if (!r.keyword.empty()) {
|
||||
os << s.filename << "\n";
|
||||
os << r.AsJsonString() << "\n\n";
|
||||
fprintf(stderr, "%s", os.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (ready_streams.empty()) {
|
||||
break;
|
||||
}
|
||||
keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
184
sherpa-onnx/csrc/transducer-keyword-decoder.cc
Normal file
184
sherpa-onnx/csrc/transducer-keyword-decoder.cc
Normal file
@@ -0,0 +1,184 @@
|
||||
// sherpa-onnx/csrc/transducer-keywords-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
int32_t blank_id = 0; // always 0
|
||||
TransducerKeywordResult r;
|
||||
std::vector<int64_t> blanks(context_size, -1);
|
||||
blanks.back() = blank_id;
|
||||
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
r.hyps = std::move(blank_hyp);
|
||||
return r;
|
||||
}
|
||||
|
||||
void TransducerKeywordDecoder::Decode(
|
||||
Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<TransducerKeywordResult> *result) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (encoder_out_shape[0] != result->size()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||
|
||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
int32_t context_size = model_->ContextSize();
|
||||
std::vector<int64_t> blanks(context_size, -1);
|
||||
blanks.back() = 0; // blank_id is hardcoded to 0
|
||||
|
||||
std::vector<Hypotheses> cur;
|
||||
for (auto &r : *result) {
|
||||
cur.push_back(std::move(r.hyps));
|
||||
}
|
||||
std::vector<Hypothesis> prev;
|
||||
|
||||
for (int32_t t = 0; t != num_frames; ++t) {
|
||||
// Due to merging paths with identical token sequences,
|
||||
// not all utterances have "num_active_paths" paths.
|
||||
auto hyps_row_splits = GetHypsRowSplits(cur);
|
||||
int32_t num_hyps =
|
||||
hyps_row_splits.back(); // total num hyps for all utterance
|
||||
prev.clear();
|
||||
for (auto &hyps : cur) {
|
||||
for (auto &h : hyps) {
|
||||
prev.push_back(std::move(h.second));
|
||||
}
|
||||
}
|
||||
cur.clear();
|
||||
cur.reserve(batch_size);
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
|
||||
Ort::Value cur_encoder_out =
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
cur_encoder_out =
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
|
||||
// The acoustic logprobs for current frame
|
||||
std::vector<float> logprobs(vocab_size * num_hyps);
|
||||
std::memcpy(logprobs.data(), p_logit,
|
||||
sizeof(float) * vocab_size * num_hyps);
|
||||
|
||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||
// to match what it actually contains
|
||||
float *p_logprob = p_logit;
|
||||
|
||||
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||
float log_prob = prev[i].log_prob;
|
||||
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||
*p_logprob += log_prob;
|
||||
}
|
||||
}
|
||||
p_logprob = p_logit; // we changed p_logprob in the above for loop
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
int32_t frame_offset = (*result)[b].frame_offset;
|
||||
int32_t start = hyps_row_splits[b];
|
||||
int32_t end = hyps_row_splits[b + 1];
|
||||
auto topk =
|
||||
TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
|
||||
|
||||
Hypotheses hyps;
|
||||
for (auto k : topk) {
|
||||
int32_t hyp_index = k / vocab_size + start;
|
||||
int32_t new_token = k % vocab_size;
|
||||
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
|
||||
// blank is hardcoded to 0
|
||||
// also, it treats unk as blank
|
||||
if (new_token != 0 && new_token != unk_id_) {
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t + frame_offset);
|
||||
new_hyp.ys_probs.push_back(
|
||||
exp(logprobs[hyp_index * vocab_size + new_token]));
|
||||
|
||||
new_hyp.num_trailing_blanks = 0;
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
// Start matching from the start state, forget the decoder history.
|
||||
if (new_hyp.context_state->token == -1) {
|
||||
new_hyp.ys = blanks;
|
||||
new_hyp.timestamps.clear();
|
||||
new_hyp.ys_probs.clear();
|
||||
}
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
}
|
||||
new_hyp.log_prob = p_logprob[k] + context_score;
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
|
||||
auto best_hyp = hyps.GetMostProbable(false);
|
||||
|
||||
auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state);
|
||||
bool matched = std::get<0>(status);
|
||||
const ContextState *matched_state = std::get<1>(status);
|
||||
|
||||
if (matched) {
|
||||
float ys_prob = 0.0;
|
||||
int32_t length = best_hyp.ys_probs.size();
|
||||
for (int32_t i = 1; i <= matched_state->level; ++i) {
|
||||
ys_prob += best_hyp.ys_probs[i];
|
||||
}
|
||||
ys_prob /= matched_state->level;
|
||||
if (best_hyp.num_trailing_blanks > num_trailing_blanks_ &&
|
||||
ys_prob >= matched_state->ac_threshold) {
|
||||
auto &r = (*result)[b];
|
||||
r.tokens = {best_hyp.ys.end() - matched_state->level,
|
||||
best_hyp.ys.end()};
|
||||
r.timestamps = {best_hyp.timestamps.end() - matched_state->level,
|
||||
best_hyp.timestamps.end()};
|
||||
r.keyword = matched_state->phrase;
|
||||
|
||||
hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}});
|
||||
}
|
||||
}
|
||||
cur.push_back(std::move(hyps));
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(false);
|
||||
auto &r = (*result)[b];
|
||||
r.hyps = std::move(hyps);
|
||||
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||
r.frame_offset += num_frames;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
62
sherpa-onnx/csrc/transducer-keyword-decoder.h
Normal file
62
sherpa-onnx/csrc/transducer-keyword-decoder.h
Normal file
@@ -0,0 +1,62 @@
|
||||
// sherpa-onnx/csrc/transducer-keywords-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct TransducerKeywordResult {
|
||||
/// Number of frames after subsampling we have decoded so far
|
||||
int32_t frame_offset = 0;
|
||||
|
||||
/// The decoded token IDs for keywords
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
/// The triggered keyword
|
||||
std::string keyword;
|
||||
|
||||
/// number of trailing blank frames decoded so far
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
// used only in modified beam_search
|
||||
Hypotheses hyps;
|
||||
};
|
||||
|
||||
class TransducerKeywordDecoder {
|
||||
public:
|
||||
TransducerKeywordDecoder(OnlineTransducerModel *model,
|
||||
int32_t max_active_paths,
|
||||
int32_t num_trailing_blanks, int32_t unk_id)
|
||||
: model_(model),
|
||||
max_active_paths_(max_active_paths),
|
||||
num_trailing_blanks_(num_trailing_blanks),
|
||||
unk_id_(unk_id) {}
|
||||
|
||||
TransducerKeywordResult GetEmptyResult() const;
|
||||
|
||||
void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<TransducerKeywordResult> *result);
|
||||
|
||||
private:
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
|
||||
int32_t max_active_paths_;
|
||||
int32_t num_trailing_blanks_;
|
||||
int32_t unk_id_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||
@@ -15,16 +15,31 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
hotwords->clear();
|
||||
std::vector<int32_t> tmp;
|
||||
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *ids,
|
||||
std::vector<std::string> *phrases,
|
||||
std::vector<float> *scores,
|
||||
std::vector<float> *thresholds) {
|
||||
SHERPA_ONNX_CHECK(ids != nullptr);
|
||||
ids->clear();
|
||||
|
||||
std::vector<int32_t> tmp_ids;
|
||||
std::vector<float> tmp_scores;
|
||||
std::vector<float> tmp_thresholds;
|
||||
std::vector<std::string> tmp_phrases;
|
||||
|
||||
std::string line;
|
||||
std::string word;
|
||||
bool has_scores = false;
|
||||
bool has_thresholds = false;
|
||||
bool has_phrases = false;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
float score = 0;
|
||||
float threshold = 0;
|
||||
std::string phrase = "";
|
||||
|
||||
std::istringstream iss(line);
|
||||
std::vector<std::string> syms;
|
||||
while (iss >> word) {
|
||||
if (word.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
@@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
}
|
||||
}
|
||||
if (symbol_table.contains(word)) {
|
||||
int32_t number = symbol_table[word];
|
||||
tmp.push_back(number);
|
||||
int32_t id = symbol_table[word];
|
||||
tmp_ids.push_back(id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
|
||||
"the "
|
||||
"same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
switch (word[0]) {
|
||||
case ':': // boosting score for current keyword
|
||||
score = std::stof(word.substr(1));
|
||||
has_scores = true;
|
||||
break;
|
||||
case '#': // triggering threshold (probability) for current keyword
|
||||
threshold = std::stof(word.substr(1));
|
||||
has_thresholds = true;
|
||||
break;
|
||||
case '@': // the original keyword string
|
||||
phrase = word.substr(1);
|
||||
has_phrases = true;
|
||||
break;
|
||||
default:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Cannot find ID for token %s at line: %s. (Hint: words on "
|
||||
"the same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
hotwords->push_back(std::move(tmp));
|
||||
ids->push_back(std::move(tmp_ids));
|
||||
tmp_scores.push_back(score);
|
||||
tmp_phrases.push_back(phrase);
|
||||
tmp_thresholds.push_back(threshold);
|
||||
}
|
||||
if (scores != nullptr) {
|
||||
if (has_scores) {
|
||||
scores->swap(tmp_scores);
|
||||
} else {
|
||||
scores->clear();
|
||||
}
|
||||
}
|
||||
if (phrases != nullptr) {
|
||||
if (has_phrases) {
|
||||
*phrases = std::move(tmp_phrases);
|
||||
} else {
|
||||
phrases->clear();
|
||||
}
|
||||
}
|
||||
if (thresholds != nullptr) {
|
||||
if (has_thresholds) {
|
||||
thresholds->swap(tmp_thresholds);
|
||||
} else {
|
||||
thresholds->clear();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *keywords_id,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold) {
|
||||
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
|
||||
threshold);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -26,7 +26,32 @@ namespace sherpa_onnx {
|
||||
* otherwise returns false.
|
||||
*/
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords);
|
||||
std::vector<std::vector<int32_t>> *hotwords_id);
|
||||
|
||||
/* Encode the keywords in an input stream to be tokens ids.
|
||||
*
|
||||
* @param is The input stream, it contains several lines, one hotword for each
|
||||
* line. For each hotword, the tokens (cjkchar or bpe) are separated
|
||||
* by spaces, it might contain boosting score (starting with :),
|
||||
* triggering threshold (starting with #) and keyword string (starting
|
||||
* with @) too.
|
||||
* @param symbol_table The tokens table mapping symbols to ids. All the symbols
|
||||
* in the stream should be in the symbol_table, if not this
|
||||
* function returns fasle.
|
||||
*
|
||||
* @param keywords_id The encoded ids to be written to.
|
||||
* @param keywords The original keyword string to be written to.
|
||||
* @param boost_scores The boosting score for each keyword to be written to.
|
||||
* @param threshold The triggering threshold for each keyword to be written to.
|
||||
*
|
||||
* @return If all the symbols from ``is`` are in the symbol_table, returns true
|
||||
* otherwise returns false.
|
||||
*/
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *keywords_id,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
@@ -140,6 +141,73 @@ class SherpaOnnxVad {
|
||||
VoiceActivityDetector vad_;
|
||||
};
|
||||
|
||||
class SherpaOnnxKws {
|
||||
public:
|
||||
#if __ANDROID_API__ >= 9
|
||||
SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)
|
||||
: keyword_spotter_(mgr, config),
|
||||
stream_(keyword_spotter_.CreateStream()) {}
|
||||
#endif
|
||||
|
||||
explicit SherpaOnnxKws(const KeywordSpotterConfig &config)
|
||||
: keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}
|
||||
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
|
||||
if (input_sample_rate_ == -1) {
|
||||
input_sample_rate_ = sample_rate;
|
||||
}
|
||||
|
||||
stream_->AcceptWaveform(sample_rate, samples, n);
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
|
||||
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
|
||||
tail_padding.size());
|
||||
stream_->InputFinished();
|
||||
}
|
||||
|
||||
// If keywords is an empty string, it just recreates the decoding stream
|
||||
// always returns true in this case.
|
||||
// If keywords is not empty, it will create a new decoding stream with
|
||||
// the given keywords appended to the default keywords.
|
||||
// Return false if errors occurred when adding keywords, true otherwise.
|
||||
bool Reset(const std::string &keywords = {}) {
|
||||
if (keywords.empty()) {
|
||||
stream_ = keyword_spotter_.CreateStream();
|
||||
return true;
|
||||
} else {
|
||||
auto stream = keyword_spotter_.CreateStream(keywords);
|
||||
// Set new keywords failed, the stream_ will not be updated.
|
||||
if (stream == nullptr) {
|
||||
return false;
|
||||
} else {
|
||||
stream_ = std::move(stream);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetKeyword() const {
|
||||
auto result = keyword_spotter_.GetResult(stream_.get());
|
||||
return result.keyword;
|
||||
}
|
||||
|
||||
std::vector<std::string> GetTokens() const {
|
||||
auto result = keyword_spotter_.GetResult(stream_.get());
|
||||
return result.tokens;
|
||||
}
|
||||
|
||||
bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }
|
||||
|
||||
void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }
|
||||
|
||||
private:
|
||||
KeywordSpotter keyword_spotter_;
|
||||
std::unique_ptr<OnlineStream> stream_;
|
||||
int32_t input_sample_rate_ = -1;
|
||||
};
|
||||
|
||||
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
OnlineRecognizerConfig ans;
|
||||
|
||||
@@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
|
||||
KeywordSpotterConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
jfieldID fid;
|
||||
|
||||
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
||||
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
||||
|
||||
//---------- decoding ----------
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.keywords_file = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsScore", "F");
|
||||
ans.keywords_score = env->GetFloatField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsThreshold", "F");
|
||||
ans.keywords_threshold = env->GetFloatField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
|
||||
ans.num_trailing_blanks = env->GetIntField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
jobject feat_config = env->GetObjectField(config, fid);
|
||||
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
|
||||
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||
|
||||
//---------- model config ----------
|
||||
fid = env->GetFieldID(cls, "modelConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
|
||||
jobject model_config = env->GetObjectField(config, fid);
|
||||
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||
|
||||
// transducer
|
||||
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||
jobject transducer_config = env->GetObjectField(model_config, fid);
|
||||
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.encoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.decoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.joiner = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
|
||||
VadModelConfig ans;
|
||||
|
||||
@@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
|
||||
jclass stringClass = env->FindClass("java/lang/String");
|
||||
|
||||
// convert C++ list into jni string array
|
||||
jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
|
||||
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
|
||||
for (int32_t i = 0; i < size; i++) {
|
||||
// Convert the C++ string to a C string
|
||||
const char *cstr = tokens[i].c_str();
|
||||
|
||||
// Convert the C string to a jstring
|
||||
jstring jstr = env->NewStringUTF(cstr);
|
||||
|
||||
// Set the array element
|
||||
env->SetObjectArrayElement(result, i, jstr);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(
|
||||
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||
if (!mgr) {
|
||||
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||
}
|
||||
#endif
|
||||
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
auto model = new sherpa_onnx::SherpaOnnxKws(
|
||||
#if __ANDROID_API__ >= 9
|
||||
mgr,
|
||||
#endif
|
||||
config);
|
||||
|
||||
return (jlong)model;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(
|
||||
JNIEnv *env, jobject /*obj*/, jobject _config) {
|
||||
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
auto model = new sherpa_onnx::SherpaOnnxKws(config);
|
||||
|
||||
return (jlong)model;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
return model->IsReady();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
model->Decode();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|
||||
jint sample_rate) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
|
||||
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
||||
jsize n = env->GetArrayLength(samples);
|
||||
|
||||
model->AcceptWaveform(sample_rate, p, n);
|
||||
|
||||
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
// see
|
||||
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
|
||||
auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword();
|
||||
return env->NewStringUTF(text.c_str());
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
|
||||
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
|
||||
|
||||
std::string keywords_str = p_keywords;
|
||||
|
||||
bool status =
|
||||
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);
|
||||
env->ReleaseStringUTFChars(keywords, p_keywords);
|
||||
return status;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
|
||||
jlong ptr) {
|
||||
auto tokens =
|
||||
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();
|
||||
int32_t size = tokens.size();
|
||||
jclass stringClass = env->FindClass("java/lang/String");
|
||||
|
||||
// convert C++ list into jni string array
|
||||
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
|
||||
for (int32_t i = 0; i < size; i++) {
|
||||
// Convert the C++ string to a C string
|
||||
const char *cstr = tokens[i].c_str();
|
||||
|
||||
@@ -28,9 +28,14 @@ def cli():
|
||||
)
|
||||
@click.option(
|
||||
"--tokens-type",
|
||||
type=str,
|
||||
type=click.Choice(
|
||||
["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
|
||||
),
|
||||
required=True,
|
||||
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
|
||||
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
|
||||
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||
""",
|
||||
)
|
||||
@click.option(
|
||||
"--bpe-model",
|
||||
@@ -42,14 +47,56 @@ def encode_text(
|
||||
):
|
||||
"""
|
||||
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
|
||||
Each line in the texts contains the original phrase, it might also contain some
|
||||
extra items, for example, the boosting score (startting with :), the triggering
|
||||
threshold (startting with #, only used in keyword spotting task) and the original
|
||||
phrase (startting with @). Note: the extra items will be kept same in the output.
|
||||
|
||||
example input 1 (tokens_type = ppinyin):
|
||||
|
||||
小爱同学 :2.0 #0.6 @小爱同学
|
||||
你好问问 :3.5 @你好问问
|
||||
小艺小艺 #0.6 @小艺小艺
|
||||
|
||||
example output 1:
|
||||
|
||||
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
|
||||
n ǐ h ǎo w èn w èn :3.5 @你好问问
|
||||
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
|
||||
|
||||
example input 2 (tokens_type = bpe):
|
||||
|
||||
HELLO WORLD :1.5 #0.4
|
||||
HI GOOGLE :2.0 #0.8
|
||||
HEY SIRI #0.35
|
||||
|
||||
example output 2:
|
||||
|
||||
▁HE LL O ▁WORLD :1.5 #0.4
|
||||
▁HI ▁GO O G LE :2.0 #0.8
|
||||
▁HE Y ▁S I RI #0.35
|
||||
"""
|
||||
texts = []
|
||||
# extra information like boosting score (start with :), triggering threshold (start with #)
|
||||
# original keyword (start with @)
|
||||
extra_info = []
|
||||
with open(input, "r", encoding="utf8") as f:
|
||||
for line in f:
|
||||
texts.append(line.strip())
|
||||
extra = []
|
||||
text = []
|
||||
toks = line.strip().split()
|
||||
for tok in toks:
|
||||
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
|
||||
extra.append(tok)
|
||||
else:
|
||||
text.append(tok)
|
||||
texts.append(" ".join(text))
|
||||
extra_info.append(extra)
|
||||
|
||||
encoded_texts = text2token(
|
||||
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
|
||||
)
|
||||
with open(output, "w", encoding="utf8") as f:
|
||||
for txt in encoded_texts:
|
||||
for i, txt in enumerate(encoded_texts):
|
||||
txt += extra_info[i]
|
||||
f.write(" ".join(txt) + "\n")
|
||||
|
||||
@@ -6,6 +6,9 @@ from typing import List, Optional, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from pypinyin import pinyin
|
||||
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
|
||||
|
||||
|
||||
def text2token(
|
||||
texts: List[str],
|
||||
@@ -23,7 +26,9 @@ def text2token(
|
||||
tokens:
|
||||
The path of the tokens.txt.
|
||||
tokens_type:
|
||||
The valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
|
||||
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||
bpe_model:
|
||||
The path of the bpe model. Only required when tokens_type is bpe or
|
||||
cjkchar+bpe.
|
||||
@@ -53,6 +58,24 @@ def text2token(
|
||||
texts_list = [list("".join(text.split())) for text in texts]
|
||||
elif tokens_type == "bpe":
|
||||
texts_list = sp.encode(texts, out_type=str)
|
||||
elif "pinyin" in tokens_type:
|
||||
for txt in texts:
|
||||
py = [x[0] for x in pinyin(txt)]
|
||||
if "ppinyin" == tokens_type:
|
||||
res = []
|
||||
for x in py:
|
||||
initial = to_initials(x, strict=False)
|
||||
final = to_finals_tone(x, strict=False)
|
||||
if initial == "" and final == "":
|
||||
res.append(x)
|
||||
else:
|
||||
if initial != "":
|
||||
res.append(initial)
|
||||
if final != "":
|
||||
res.append(final)
|
||||
texts_list.append(res)
|
||||
else:
|
||||
texts_list.append(py)
|
||||
else:
|
||||
assert (
|
||||
tokens_type == "cjkchar+bpe"
|
||||
|
||||
Reference in New Issue
Block a user