decoder for open vocabulary keyword spotting (#505)

* various fixes to ContextGraph to support open vocabulary keywords decoder

* Add keyword spotter runtime

* Add binary

* First version works

* Minor fixes

* update text2token

* default values

* Add jni for kws

* add kws android project

* Minor fixes

* Remove unused interface

* Minor fixes

* Add workflow

* handle extra info in texts

* Minor fixes

* Add more comments

* Fix ci

* fix cpp style

* Add input box in android demo so that users can specify their keywords

* Fix cpp style

* Fix comments

* Minor fixes

* Minor fixes

* minor fixes

* Minor fixes

* Minor fixes

* Add CI

* Fix code style

* cpplint

* Fix comments

* Fix error
This commit is contained in:
Wei Kang
2024-01-20 22:52:41 +08:00
committed by GitHub
parent bf1dd3daf6
commit b6c020901a
77 changed files with 3316 additions and 68 deletions

View File

@@ -19,6 +19,8 @@ set(sources
features.cc
file-utils.cc
hypothesis.cc
keyword-spotter-impl.cc
keyword-spotter.cc
offline-ctc-fst-decoder-config.cc
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
@@ -87,6 +89,7 @@ set(sources
stack.cc
symbol-table.cc
text-utils.cc
transducer-keyword-decoder.cc
transpose.cc
unbind.cc
utils.cc
@@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
endif()
add_executable(sherpa-onnx sherpa-onnx.cc)
add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
set(main_exes
sherpa-onnx
sherpa-onnx-keyword-spotter
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-tts

View File

@@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <chrono> // NOLINT
#include <cmath>
#include <map>
#include <random>
#include <string>
@@ -15,27 +16,25 @@
namespace sherpa_onnx {
TEST(ContextGraph, TestBasic) {
static void TestHelper(const std::map<std::string, float> &queries, float score,
bool strict_mode) {
std::vector<std::string> contexts_str(
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
std::vector<std::vector<int32_t>> contexts;
std::vector<float> scores;
for (int32_t i = 0; i < contexts_str.size(); ++i) {
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
}
auto context_graph = ContextGraph(contexts, 1);
auto queries = std::map<std::string, float>{
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
auto context_graph = ContextGraph(contexts, 1, scores);
for (const auto &iter : queries) {
float total_scores = 0;
auto state = context_graph.Root();
for (auto q : iter.first) {
auto res = context_graph.ForwardOneStep(state, q);
total_scores += res.first;
state = res.second;
auto res = context_graph.ForwardOneStep(state, q, strict_mode);
total_scores += std::get<0>(res);
state = std::get<1>(res);
}
auto res = context_graph.Finalize(state);
EXPECT_EQ(res.second->token, -1);
@@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) {
}
}
// Strict matching, with a score of 0 handed to TestHelper.
TEST(ContextGraph, TestBasic) {
  // query string -> value passed to TestHelper for that query
  const std::map<std::string, float> expected = {
      {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
      {"SHED", 6},      {"SHELF", 6},   {"HELL", 2},
      {"HELLO", 7},     {"DHRHISQ", 4}, {"THEN", 2}};
  TestHelper(expected, 0, true);
}
// Non-strict matching, with a score of 0 handed to TestHelper.
TEST(ContextGraph, TestBasicNonStrict) {
  // query string -> value passed to TestHelper for that query
  const std::map<std::string, float> expected = {
      {"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5},
      {"SHED", 3},     {"SHELF", 3},  {"HELL", 2},
      {"HELLO", 2},    {"DHRHISQ", 3}, {"THEN", 2}};
  TestHelper(expected, 0, false);
}
// Strict matching, with a score of 5 handed to TestHelper.
TEST(ContextGraph, TestCustomize) {
  // query string -> value passed to TestHelper for that query
  const std::map<std::string, float> expected = {
      {"HEHERSHE", 35.84}, {"HERSHE", 30.84},  {"HISHE", 24.18},
      {"SHED", 18.34},     {"SHELF", 18.34},   {"HELL", 5},
      {"HELLO", 13},       {"DHRHISQ", 10.84}, {"THEN", 5}};
  TestHelper(expected, 5, true);
}
// Non-strict matching, with a score of 5 handed to TestHelper.
TEST(ContextGraph, TestCustomizeNonStrict) {
  // query string -> value passed to TestHelper for that query
  const std::map<std::string, float> expected = {
      {"HEHERSHE", 20}, {"HERSHE", 15},   {"HISHE", 10.84},
      {"SHED", 10},     {"SHELF", 10},    {"HELL", 5},
      {"HELLO", 5},     {"DHRHISQ", 5.84}, {"THEN", 5}};
  TestHelper(expected, 5, false);
}
TEST(ContextGraph, Benchmark) {
std::random_device rd;
std::mt19937 mt(rd());

View File

@@ -4,22 +4,59 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <algorithm>
#include <cassert>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const {
if (!scores.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
}
if (!phrases.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
}
if (!ac_thresholds.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
}
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
float score = scores.empty() ? 0.0f : scores[i];
score = score == 0.0f ? context_score_ : score;
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
std::string phrase = phrases.empty() ? std::string() : phrases[i];
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? node->node_score + context_score_ : 0, is_end);
token, score, node->node_score + score,
is_end ? node->node_score + score : 0, j + 1,
is_end ? ac_threshold : 0.0f, is_end,
is_end ? phrase : std::string());
} else {
float token_score = std::max(score, node->next[token]->token_score);
node->next[token]->token_score = token_score;
float node_score = node->node_score + token_score;
node->next[token]->node_score = node_score;
bool is_end =
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
node->next[token]->output_score = is_end ? node_score : 0.0f;
node->next[token]->is_end = is_end;
if (j == token_ids[i].size() - 1) {
node->next[token]->phrase = phrase;
node->next[token]->ac_threshold = ac_threshold;
}
}
node = node->next[token].get();
}
@@ -27,8 +64,9 @@ void ContextGraph::Build(
FillFailOutput();
}
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
std::tuple<float, const ContextState *, const ContextState *>
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
bool strict_mode /*= true*/) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
@@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
}
score = node->node_score - state->node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
return std::make_pair(score + node->output_score, node);
const ContextState *matched_node =
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
if (!strict_mode && node->output_score != 0) {
SHERPA_ONNX_CHECK(nullptr != matched_node);
float output_score =
node->is_end ? node->node_score
: (node->output != nullptr ? node->output->node_score
: node->node_score);
return std::make_tuple(score + output_score - node->node_score, root_.get(),
matched_node);
}
return std::make_tuple(score + node->output_score, node, matched_node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
@@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
return std::make_pair(score, root_.get());
}
std::pair<bool, const ContextState *> ContextGraph::IsMatched(
const ContextState *state) const {
bool status = false;
const ContextState *node = nullptr;
if (state->is_end) {
status = true;
node = state;
} else {
if (state->output != nullptr) {
status = true;
node = state->output;
}
}
return std::make_pair(status, node);
}
void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {

View File

@@ -6,6 +6,8 @@
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -22,34 +24,55 @@ struct ContextState {
float token_score;
float node_score;
float output_score;
int32_t level;
float ac_threshold;
bool is_end;
std::string phrase;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;
ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float output_score, bool is_end)
float output_score, int32_t level = 0, float ac_threshold = 0.0f,
bool is_end = false, const std::string &phrase = {})
: token(token),
token_score(token_score),
node_score(node_score),
output_score(output_score),
is_end(is_end) {}
level(level),
ac_threshold(ac_threshold),
is_end(is_end),
phrase(phrase) {}
};
class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score)
: context_score_(context_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
float context_score, float ac_threshold,
const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {},
const std::vector<float> &ac_thresholds = {})
: context_score_(context_score), ac_threshold_(ac_threshold) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
root_->fail = root_.get();
Build(token_ids);
Build(token_ids, scores, phrases, ac_thresholds);
}
std::pair<float, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id) const;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score, const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {})
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
std::vector<float>()) {}
std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id,
bool strict_mode = true) const;
std::pair<bool, const ContextState *> IsMatched(
const ContextState *state) const;
std::pair<float, const ContextState *> Finalize(
const ContextState *state) const;
@@ -57,8 +80,12 @@ class ContextGraph {
private:
float context_score_;
float ac_threshold_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const;
void FillFailOutput() const;
};

View File

@@ -28,6 +28,10 @@ struct Hypothesis {
// on which ys[i] is decoded.
std::vector<int32_t> timestamps;
// The acoustic probability for each token in ys.
// Only used for keyword spotting task.
std::vector<float> ys_probs;
// The total score of ys in log space.
// It contains only acoustic scores
double log_prob = 0;

View File

@@ -0,0 +1,33 @@
// sherpa-onnx/csrc/keyword-spotter-impl.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h"
namespace sherpa_onnx {
// Creates the implementation matching `config`.
//
// A transducer implementation is returned when a transducer encoder is
// configured; otherwise an error is logged and the process exits.
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
    const KeywordSpotterConfig &config) {
  const bool has_transducer = !config.model_config.transducer.encoder.empty();
  if (!has_transducer) {
    SHERPA_ONNX_LOGE("Please specify a model");
    exit(-1);
  }

  return std::make_unique<KeywordSpotterTransducerImpl>(config);
}
#if __ANDROID_API__ >= 9
// Android variant of Create(): `mgr` is forwarded to the implementation.
//
// A transducer implementation is returned when a transducer encoder is
// configured; otherwise an error is logged and the process exits.
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
    AAssetManager *mgr, const KeywordSpotterConfig &config) {
  const bool has_transducer = !config.model_config.transducer.encoder.empty();
  if (!has_transducer) {
    SHERPA_ONNX_LOGE("Please specify a model");
    exit(-1);
  }

  return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config);
}
#endif
} // namespace sherpa_onnx

View File

@@ -0,0 +1,48 @@
// sherpa-onnx/csrc/keyword-spotter-impl.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/online-stream.h"
namespace sherpa_onnx {
// Abstract interface for keyword spotter implementations.
//
// Concrete instances are obtained through the static Create() factories.
class KeywordSpotterImpl {
public:
// Creates an implementation selected from `config`.
static std::unique_ptr<KeywordSpotterImpl> Create(
const KeywordSpotterConfig &config);
#if __ANDROID_API__ >= 9
// Same as above, but additionally takes an Android asset manager.
static std::unique_ptr<KeywordSpotterImpl> Create(
AAssetManager *mgr, const KeywordSpotterConfig &config);
#endif
virtual ~KeywordSpotterImpl() = default;
// Creates a stream without stream-specific keywords.
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
// Creates a stream from the keyword specification given in `keywords`.
virtual std::unique_ptr<OnlineStream> CreateStream(
const std::string &keywords) const = 0;
// Returns whether stream `s` can currently be decoded.
virtual bool IsReady(OnlineStream *s) const = 0;
// Decodes the `n` streams pointed to by `ss`.
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
// Returns the keyword result for stream `s`.
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_

View File

@@ -0,0 +1,323 @@
// sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
#include <algorithm>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
// Converts a decoder result into a user-facing KeywordResult.
//
// @param src                 Decoder result holding token IDs, timestamps
//                            (in frames after subsampling) and, optionally,
//                            the keyword text.
// @param sym_table           Maps token IDs to their symbol strings.
// @param frame_shift_ms      Frame shift of the feature extractor, in ms.
// @param subsampling_factor  Subsampling factor applied to the frames.
// @param frames_since_start  Number of feature frames since the stream
//                            started; used to compute `start_time`.
// @return The converted result. If `src.keyword` is empty, the keyword is
//         rebuilt by concatenating the token symbols.
static KeywordResult Convert(const TransducerKeywordResult &src,
                             const SymbolTable &sym_table, float frame_shift_ms,
                             int32_t subsampling_factor,
                             int32_t frames_since_start) {
  KeywordResult r;
  r.tokens.reserve(src.tokens.size());
  r.timestamps.reserve(src.timestamps.size());
  r.keyword = src.keyword;

  // Only rebuild the keyword from tokens when the decoder did not provide it.
  const bool from_tokens = src.keyword.empty();

  for (auto i : src.tokens) {
    auto sym = sym_table[i];
    if (from_tokens) {
      r.keyword.append(sym);
    }
    r.tokens.push_back(std::move(sym));
  }

  // Strip a single leading space from the rebuilt keyword.
  // NOTE(review): presumably word-initial symbols start with a space; confirm
  // against SymbolTable. The first byte must not be dropped unconditionally,
  // otherwise a keyword whose first symbol has no leading space (e.g. a
  // multi-byte CJK character) is corrupted.
  if (from_tokens && !r.keyword.empty() && r.keyword.front() == ' ') {
    r.keyword.erase(0, 1);
  }

  // Duration in seconds covered by one frame after subsampling.
  float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
  for (auto t : src.timestamps) {
    r.timestamps.push_back(frame_shift_s * t);
  }

  r.start_time = frames_since_start * frame_shift_ms / 1000.;

  return r;
}
class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
public:
// Builds the spotter from `config`: creates the model, loads the symbol
// table, encodes the keywords file into the default keyword graph and
// finally creates the decoder.
explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens) {
// unk_id_ keeps its default of -1 when the symbol table has no "<unk>".
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
// Must run after sym_ is initialized: the keywords are encoded with it.
InitKeywords();
decoder_ = std::make_unique<TransducerKeywordDecoder>(
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
unk_id_);
}
#if __ANDROID_API__ >= 9
// Android variant: the model, the symbol table and the keywords file are
// all loaded through the asset manager `mgr`.
KeywordSpotterTransducerImpl(AAssetManager *mgr,
const KeywordSpotterConfig &config)
: config_(config),
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens) {
// unk_id_ keeps its default of -1 when the symbol table has no "<unk>".
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
// Must run after sym_ is initialized: the keywords are encoded with it.
InitKeywords(mgr);
decoder_ = std::make_unique<TransducerKeywordDecoder>(
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
unk_id_);
}
#endif
// Creates a stream that uses the default keyword graph built from the
// keywords file.
std::unique_ptr<OnlineStream> CreateStream() const override {
  auto s = std::make_unique<OnlineStream>(config_.feat_config, keywords_graph_);
  InitOnlineStream(s.get());
  return s;
}
// Creates a stream whose keyword graph contains the keywords given in
// `keywords` followed by the default keywords from the keywords file.
//
// @param keywords Keywords separated by "/"; each keyword uses the same
//                 format as a line of the keywords file.
// @return The new stream, or nullptr if `keywords` cannot be encoded.
std::unique_ptr<OnlineStream> CreateStream(
    const std::string &keywords) const override {
  // EncodeKeywords() is fed one keyword per line, so turn the "/" separators
  // into newlines. A plain character replacement avoids constructing a
  // std::regex on every call.
  std::string kws = keywords;
  std::replace(kws.begin(), kws.end(), '/', '\n');
  std::istringstream is(kws);

  std::vector<std::vector<int32_t>> current_ids;
  std::vector<std::string> current_kws;
  std::vector<float> current_scores;
  std::vector<float> current_thresholds;

  if (!EncodeKeywords(is, sym_, &current_ids, &current_kws, &current_scores,
                      &current_thresholds)) {
    SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
    return nullptr;
  }

  const size_t num_kws = current_ids.size();
  const size_t num_default_kws = keywords_id_.size();

  current_ids.insert(current_ids.end(), keywords_id_.begin(),
                     keywords_id_.end());

  // Appends the per-keyword attribute of the default keywords after the one
  // of the stream-specific keywords. If only one of the two sides provides
  // the attribute, the other side is padded with `fill` so that the result
  // stays aligned with current_ids. If neither side provides it, the vector
  // is left empty.
  auto append_defaults = [num_kws, num_default_kws](auto *current,
                                                    const auto &defaults,
                                                    const auto &fill) {
    if (current->empty() && defaults.empty()) {
      return;
    }

    if (current->empty()) {
      current->insert(current->end(), num_kws, fill);
    }

    if (defaults.empty()) {
      current->insert(current->end(), num_default_kws, fill);
    } else {
      current->insert(current->end(), defaults.begin(), defaults.end());
    }
  };

  append_defaults(&current_kws, keywords_, std::string());
  append_defaults(&current_scores, boost_scores_, config_.keywords_score);
  append_defaults(&current_thresholds, thresholds_,
                  config_.keywords_threshold);

  auto keywords_graph = std::make_shared<ContextGraph>(
      current_ids, config_.keywords_score, config_.keywords_threshold,
      current_scores, current_kws, current_thresholds);

  auto stream =
      std::make_unique<OnlineStream>(config_.feat_config, keywords_graph);
  InitOnlineStream(stream.get());
  return stream;
}
// A stream is ready when the frames processed so far plus one model chunk
// are fewer than the frames that are ready in the stream.
bool IsReady(OnlineStream *s) const override {
  const auto frames_needed = s->GetNumProcessedFrames() + model_->ChunkSize();
  return frames_needed < s->NumFramesReady();
}
// Decodes one chunk of each of the `n` streams in `ss` as a single batch.
//
// For every stream, one chunk of features is fetched and the stream is
// advanced by the chunk shift. The batched features are run through the
// encoder, the keyword decoder updates the per-stream results, and the new
// results and encoder states are stored back into the streams.
//
// @param ss Pointer array containing the streams to decode. Every stream
//           must have a context graph.
// @param n  Number of streams in `ss`. Nothing is done if n <= 0.
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
  // Nothing to decode. This also avoids dereferencing ss[0] below.
  if (n <= 0) {
    return;
  }

  int32_t chunk_size = model_->ChunkSize();
  int32_t chunk_shift = model_->ChunkShift();

  int32_t feature_dim = ss[0]->FeatureDim();

  std::vector<TransducerKeywordResult> results(n);
  std::vector<float> features_vec(n * chunk_size * feature_dim);
  std::vector<std::vector<Ort::Value>> states_vec(n);
  std::vector<int64_t> all_processed_frames(n);

  for (int32_t i = 0; i != n; ++i) {
    SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr);

    const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
    std::vector<float> features =
        ss[i]->GetFrames(num_processed_frames, chunk_size);

    // Question: should num_processed_frames include chunk_shift?
    ss[i]->GetNumProcessedFrames() += chunk_shift;

    // Stream i occupies the i-th (chunk_size * feature_dim) slice of the
    // batch buffer.
    std::copy(features.begin(), features.end(),
              features_vec.data() + i * chunk_size * feature_dim);

    results[i] = std::move(ss[i]->GetKeywordResult());
    states_vec[i] = std::move(ss[i]->GetStates());
    // The value recorded here is the one from before the increment above.
    all_processed_frames[i] = num_processed_frames;
  }

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

  // The tensors below are created on top of the vectors' buffers, so
  // features_vec and all_processed_frames must outlive the encoder run.
  Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
                                          features_vec.size(), x_shape.data(),
                                          x_shape.size());

  std::array<int64_t, 1> processed_frames_shape{
      static_cast<int64_t>(all_processed_frames.size())};

  Ort::Value processed_frames = Ort::Value::CreateTensor(
      memory_info, all_processed_frames.data(), all_processed_frames.size(),
      processed_frames_shape.data(), processed_frames_shape.size());

  auto states = model_->StackStates(states_vec);

  // pair.first is passed to the decoder; pair.second holds the new states.
  auto pair = model_->RunEncoder(std::move(x), std::move(states),
                                 std::move(processed_frames));

  decoder_->Decode(std::move(pair.first), ss, &results);

  std::vector<std::vector<Ort::Value>> next_states =
      model_->UnStackStates(pair.second);

  for (int32_t i = 0; i != n; ++i) {
    ss[i]->SetKeywordResult(results[i]);
    ss[i]->SetStates(std::move(next_states[i]));
  }
}
// Returns the current keyword result of stream `s`, converted to the
// user-facing representation. Duplicate removal is requested from the stream.
KeywordResult GetResult(OnlineStream *s) const override {
  // TODO(fangjun): Remember to change these constants if needed
  constexpr int32_t kFrameShiftMs = 10;
  constexpr int32_t kSubsamplingFactor = 4;

  // Copy the result, since the stream returns a reference to its own state.
  TransducerKeywordResult raw = s->GetKeywordResult(true);

  return Convert(raw, sym_, kFrameShiftMs, kSubsamplingFactor,
                 s->GetNumFramesSinceStart());
}
private:
// Encodes the keywords read from `is` and builds the default keyword graph
// from them. Logs an error and exits the process if encoding fails.
void InitKeywords(std::istream &is) {
  const bool encoded = EncodeKeywords(is, sym_, &keywords_id_, &keywords_,
                                      &boost_scores_, &thresholds_);
  if (!encoded) {
    SHERPA_ONNX_LOGE("Encode keywords failed.");
    exit(-1);
  }

  keywords_graph_ = std::make_shared<ContextGraph>(
      keywords_id_, config_.keywords_score, config_.keywords_threshold,
      boost_scores_, keywords_, thresholds_);
}
// Opens config_.keywords_file and initializes the keywords from it.
// Logs an error and exits the process if the file cannot be opened.
void InitKeywords() {
  // each line in keywords_file contains space-separated words
  std::ifstream keywords_stream(config_.keywords_file);
  if (keywords_stream.fail()) {
    SHERPA_ONNX_LOGE("Open keywords file failed: %s",
                     config_.keywords_file.c_str());
    exit(-1);
  }

  InitKeywords(keywords_stream);
}
#if __ANDROID_API__ >= 9
// Android variant: reads config_.keywords_file through `mgr` and initializes
// the keywords from the in-memory contents. Logs an error and exits the
// process if the stream over the contents is in a failed state.
void InitKeywords(AAssetManager *mgr) {
  // each line in keywords_file contains space-separated words
  auto contents = ReadFile(mgr, config_.keywords_file);

  std::istrstream keywords_stream(contents.data(), contents.size());
  if (keywords_stream.fail()) {
    SHERPA_ONNX_LOGE("Open keywords file failed: %s",
                     config_.keywords_file.c_str());
    exit(-1);
  }

  InitKeywords(keywords_stream);
}
#endif
// Gives `stream` its initial decoding state: an empty keyword result whose
// single hypothesis starts at the root of the stream's context graph, plus
// the initial encoder states of the model.
void InitOnlineStream(OnlineStream *stream) const {
  auto empty_result = decoder_->GetEmptyResult();
  SHERPA_ONNX_CHECK_EQ(empty_result.hyps.size(), 1);
  SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr);

  auto &only_hyp = empty_result.hyps.begin()->second;
  only_hyp.context_state = stream->GetContextGraph()->Root();

  stream->SetKeywordResult(empty_result);
  stream->SetStates(model_->GetEncoderInitStates());
}
private:
KeywordSpotterConfig config_;
std::vector<std::vector<int32_t>> keywords_id_;
std::vector<float> boost_scores_;
std::vector<float> thresholds_;
std::vector<std::string> keywords_;
ContextGraphPtr keywords_graph_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<TransducerKeywordDecoder> decoder_;
SymbolTable sym_;
int32_t unk_id_ = -1;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_

View File

@@ -0,0 +1,152 @@
// sherpa-onnx/csrc/keyword-spotter.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include <assert.h>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
namespace sherpa_onnx {
// Escapes `s` so that it can be placed between double quotes in a JSON
// document: '"', '\\' and control characters (< 0x20) are escaped, all other
// bytes are copied unchanged.
static std::string EscapeForJson(const std::string &s) {
  static const char kHex[] = "0123456789abcdef";

  std::string out;
  out.reserve(s.size());
  for (unsigned char c : s) {
    switch (c) {
      case '"':
        out += "\\\"";
        break;
      case '\\':
        out += "\\\\";
        break;
      case '\b':
        out += "\\b";
        break;
      case '\f':
        out += "\\f";
        break;
      case '\n':
        out += "\\n";
        break;
      case '\r':
        out += "\\r";
        break;
      case '\t':
        out += "\\t";
        break;
      default:
        if (c < 0x20) {
          out += "\\u00";
          out += kHex[c >> 4];
          out += kHex[c & 0x0f];
        } else {
          out += static_cast<char>(c);
        }
        break;
    }
  }
  return out;
}

// Serializes this result as a JSON object with the fields, in this order:
// "start_time", "keyword", "timestamps" and "tokens".
//
// Times are printed with two decimal places. A token consisting of a single
// byte greater than 0x7f is printed as "<0xXX>" (uppercase hex). The keyword
// and all other tokens are JSON-escaped.
std::string KeywordResult::AsJsonString() const {
  std::ostringstream os;
  os << "{";
  os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
     << ", ";

  os << "\"keyword\": \"" << EscapeForJson(keyword) << "\", ";

  os << "\"timestamps\": [";
  std::string sep;
  for (auto t : timestamps) {
    os << sep << std::fixed << std::setprecision(2) << t;
    sep = ", ";
  }
  os << "], ";

  os << "\"tokens\":[";
  sep = "";

  // Saved so that the hex/uppercase formatting below can be undone.
  auto oldFlags = os.flags();
  for (const auto &t : tokens) {
    if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
      const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
      os << sep << "\"<0x" << std::hex << std::uppercase
         << static_cast<uint32_t>(p[0]) << ">\"";
      os.flags(oldFlags);
    } else {
      os << sep << "\"" << EscapeForJson(t) << "\"";
    }
    sep = ", ";
  }
  os << "]";
  os << "}";

  return os.str();
}
// Registers the command line options of the keyword spotter, including the
// ones of the feature extractor and of the model, with `po`.
void KeywordSpotterConfig::Register(ParseOptions *po) {
  feat_config.Register(po);
  model_config.Register(po);

  po->Register("max-active-paths", &max_active_paths,
               "beam size used in modified beam search.");
  po->Register("num-trailing-blanks", &num_trailing_blanks,
               "The number of trailing blanks should have after the keyword.");
  po->Register("keywords-score", &keywords_score,
               "The bonus score for each token in context word/phrase.");
  po->Register("keywords-threshold", &keywords_threshold,
               "The acoustic threshold (probability) to trigger the keywords.");
  // Adjacent string literals are concatenated as-is, so every piece carries
  // its own trailing separator.
  po->Register(
      "keywords-file", &keywords_file,
      "The file containing keywords, one word/phrase per line, and for each "
      "phrase the bpe/cjkchar are separated by a space. For example: "
      "▁HE LL O ▁WORLD, "
      "你 好 世 界");
}
// Returns true if a readable keywords file is configured and the model
// configuration is valid. Logs the reason when the keywords file is missing.
bool KeywordSpotterConfig::Validate() const {
  if (keywords_file.empty()) {
    SHERPA_ONNX_LOGE("Please provide --keywords-file.");
    return false;
  }

  // The temporary stream is closed again right after the check.
  const bool readable = std::ifstream(keywords_file.c_str()).good();
  if (!readable) {
    SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str());
    return false;
  }

  return model_config.Validate();
}
// Returns a human-readable description of all configuration fields.
std::string KeywordSpotterConfig::ToString() const {
  std::ostringstream os;

  os << "KeywordSpotterConfig("
     << "feat_config=" << feat_config.ToString() << ", "
     << "model_config=" << model_config.ToString() << ", "
     << "max_active_paths=" << max_active_paths << ", "
     << "num_trailing_blanks=" << num_trailing_blanks << ", "
     << "keywords_score=" << keywords_score << ", "
     << "keywords_threshold=" << keywords_threshold << ", "
     << "keywords_file=\"" << keywords_file << "\")";

  return os.str();
}
// Creates the implementation object selected by `config`.
KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config)
: impl_(KeywordSpotterImpl::Create(config)) {}
#if __ANDROID_API__ >= 9
// Android variant: `mgr` is forwarded to the implementation factory.
KeywordSpotter::KeywordSpotter(AAssetManager *mgr,
const KeywordSpotterConfig &config)
: impl_(KeywordSpotterImpl::Create(mgr, config)) {}
#endif
// Defaulted here, in the file that includes keyword-spotter-impl.h, so that
// KeywordSpotterImpl is a complete type when impl_ is destroyed.
KeywordSpotter::~KeywordSpotter() = default;
// Forwards to the implementation.
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const {
return impl_->CreateStream();
}
// Forwards `keywords` to the implementation.
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream(
const std::string &keywords) const {
return impl_->CreateStream(keywords);
}
// Forwards to the implementation.
bool KeywordSpotter::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
// Forwards to the implementation.
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
// Forwards to the implementation.
KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const {
return impl_->GetResult(s);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,148 @@
// sherpa-onnx/csrc/keyword-spotter.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
/// The result of keyword spotting for one stream.
struct KeywordResult {
/// The triggered keyword.
/// For English, it consists of space separated words.
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
std::string keyword;
/// Decoded results at the token level.
/// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
/// Starting time of this segment.
/// When an endpoint is detected, it will change
float start_time = 0;
/** Return a json string.
*
* The returned string contains the following fields, in this order:
* {
* "start_time": x,
* "keyword": "The triggered keyword",
* "timestamps": [x, x, x],
* "tokens": ["x", "x", "x"]
* }
*/
std::string AsJsonString() const;
};
/// Configuration of KeywordSpotter.
struct KeywordSpotterConfig {
/// Configuration of the feature extractor.
FeatureExtractorConfig feat_config;
/// Configuration of the model.
OnlineModelConfig model_config;
/// Beam size used in modified beam search.
int32_t max_active_paths = 4;
/// The number of trailing blanks that should follow a keyword.
int32_t num_trailing_blanks = 1;
/// The bonus score for each token of a keyword.
float keywords_score = 1.0;
/// The acoustic threshold (probability) to trigger a keyword.
float keywords_threshold = 0.25;
/// The file containing the keywords, one word/phrase per line. Within a
/// phrase the bpe/cjkchar tokens are separated by a space.
std::string keywords_file;
KeywordSpotterConfig() = default;
KeywordSpotterConfig(const FeatureExtractorConfig &feat_config,
const OnlineModelConfig &model_config,
int32_t max_active_paths, int32_t num_trailing_blanks,
float keywords_score, float keywords_threshold,
const std::string &keywords_file)
: feat_config(feat_config),
model_config(model_config),
max_active_paths(max_active_paths),
num_trailing_blanks(num_trailing_blanks),
keywords_score(keywords_score),
keywords_threshold(keywords_threshold),
keywords_file(keywords_file) {}
/// Registers all options, including those of feat_config and model_config,
/// with `po`.
void Register(ParseOptions *po);
/// Returns true if the configuration is usable.
bool Validate() const;
/// Returns a human-readable description of this configuration.
std::string ToString() const;
};
class KeywordSpotterImpl;
/// Public entry point for keyword spotting. All work is delegated to a
/// KeywordSpotterImpl held in `impl_`.
class KeywordSpotter {
public:
explicit KeywordSpotter(const KeywordSpotterConfig &config);
#if __ANDROID_API__ >= 9
KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config);
#endif
~KeywordSpotter();
/** Create a stream for decoding.
*
*/
std::unique_ptr<OnlineStream> CreateStream() const;
/** Create a stream for decoding.
*
* @param keywords The keywords for this stream. It might contain several
* keywords, separated by "/". In each of the keywords, there
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
* For example, the keywords I LOVE YOU and HELLO WORLD look like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
std::unique_ptr<OnlineStream> CreateStream(const std::string &keywords) const;
/**
* Return true if the given stream has enough frames for decoding.
* Return false otherwise
*/
bool IsReady(OnlineStream *s) const;
/** Decode a single stream. */
void DecodeStream(OnlineStream *s) const {
OnlineStream *ss[1] = {s};
DecodeStreams(ss, 1);
}
/** Decode multiple streams in parallel
*
* @param ss Pointer array containing streams to be decoded.
* @param n Number of streams in `ss`.
*/
void DecodeStreams(OnlineStream **ss, int32_t n) const;
/** Return the keyword spotting result of the given stream. */
KeywordResult GetResult(OnlineStream *s) const;
private:
std::unique_ptr<KeywordSpotterImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_

View File

@@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), View(&decoder_out));
Ort::Value logit =
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
@@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
}

View File

@@ -51,6 +51,25 @@ class OnlineStream::Impl {
OnlineTransducerDecoderResult &GetResult() { return result_; }
// Stores a copy of `r` as the current keyword result of this stream.
void SetKeywordResult(const TransducerKeywordResult &r) {
keyword_result_ = r;
}
// Returns the current keyword result.
//
// With `remove_duplicates` set, an empty result is returned instead when the
// current result starts no later than the end of the last result that was
// handed out this way; otherwise the current result is remembered as that
// last result.
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) {
  if (!remove_duplicates) {
    return keyword_result_;
  }

  const auto &prev_ts = prev_keyword_result_.timestamps;
  const auto &cur_ts = keyword_result_.timestamps;

  const bool overlaps_previous =
      !prev_ts.empty() && !cur_ts.empty() && cur_ts[0] <= prev_ts.back();
  if (overlaps_previous) {
    return empty_keyword_result_;
  }

  prev_keyword_result_ = keyword_result_;
  return keyword_result_;
}
OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
@@ -93,6 +112,9 @@ class OnlineStream::Impl {
int32_t start_frame_index_ = 0; // never reset
int32_t segment_ = 0;
OnlineTransducerDecoderResult result_;
TransducerKeywordResult prev_keyword_result_;
TransducerKeywordResult keyword_result_;
TransducerKeywordResult empty_keyword_result_;
OnlineCtcDecoderResult ctc_result_;
std::vector<Ort::Value> states_; // states for transducer or ctc models
std::vector<float> paraformer_feat_cache_;
@@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}
// Store `r` as the keyword-spotting result of this stream.
// Forwards to the private implementation object.
void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) {
impl_->SetKeywordResult(r);
}
// Return the keyword-spotting result of this stream.
// Forwards to the private implementation object; `remove_duplicates`
// is passed through unchanged (the default value is given in the
// declaration, see the inline note below).
TransducerKeywordResult &OnlineStream::GetKeywordResult(
bool remove_duplicates /*=false*/) {
return impl_->GetKeywordResult(remove_duplicates);
}
OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
return impl_->GetCtcResult();
}

View File

@@ -14,9 +14,11 @@
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace sherpa_onnx {
class TransducerKeywordResult;
class OnlineStream {
public:
explicit OnlineStream(const FeatureExtractorConfig &config = {},
@@ -76,6 +78,9 @@ class OnlineStream {
void SetResult(const OnlineTransducerDecoderResult &r);
OnlineTransducerDecoderResult &GetResult();
void SetKeywordResult(const TransducerKeywordResult &r);
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false);
void SetCtcResult(const OnlineCtcDecoderResult &r);
OnlineCtcDecoderResult &GetCtcResult();
@@ -92,7 +97,7 @@ class OnlineStream {
*/
const ContextGraphPtr &GetContextGraph() const;
// for streaming parformer
// for streaming paraformer
std::vector<float> &GetParaformerFeatCache();
std::vector<float> &GetParaformerEncoderOutCache();
std::vector<float> &GetParaformerAlphaCache();

View File

@@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
if (encoder_out_shape[0] != result->size()) {
fprintf(stderr,
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
SHERPA_ONNX_LOGE(
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
static_cast<int32_t>(encoder_out_shape[0]),
static_cast<int32_t>(result->size()));
exit(-1);
}
@@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
cur_encoder_out =
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), View(&decoder_out));
Ort::Value logit =
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
@@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);

View File

@@ -0,0 +1,122 @@
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include <stdio.h>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// One input wave file together with the decoding stream created for it.
struct Stream {
  // Decoding stream that the samples of `filename` were fed into.
  std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
  // Path of the wave file, printed next to any detected keyword.
  std::string filename;
};
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Usage:
(1) Streaming transducer
./bin/sherpa-onnx-keyword-spotter \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--provider=cpu \
--num-threads=2 \
--keywords-file=keywords.txt \
/path/to/foo.wav [bar.wav foobar.wav ...]
Note: It supports decoding multiple files in batches
Default value for num_threads is 2.
Valid values for provider: cpu (default), cuda, coreml.
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::KeywordSpotterConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() < 1) {
po.PrintUsage();
exit(EXIT_FAILURE);
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::KeywordSpotter keyword_spotter(config);
std::vector<Stream> ss;
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
const std::string wav_filename = po.GetArg(i);
int32_t sampling_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
return -1;
}
auto s = keyword_spotter.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
// Note: We can call AcceptWaveform() multiple times.
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
tail_paddings.size());
// Call InputFinished() to indicate that no audio samples are available
s->InputFinished();
ss.push_back({std::move(s), wav_filename});
}
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
for (;;) {
ready_streams.clear();
for (auto &s : ss) {
const auto p_ss = s.online_stream.get();
if (keyword_spotter.IsReady(p_ss)) {
ready_streams.push_back(p_ss);
}
std::ostringstream os;
const auto r = keyword_spotter.GetResult(p_ss);
if (!r.keyword.empty()) {
os << s.filename << "\n";
os << r.AsJsonString() << "\n\n";
fprintf(stderr, "%s", os.str().c_str());
}
}
if (ready_streams.empty()) {
break;
}
keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size());
}
return 0;
}

View File

@@ -0,0 +1,184 @@
// sherpa-onnx/csrc/transducer-keyword-decoder.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
namespace sherpa_onnx {
// Build the result a stream starts from: a single hypothesis with
// log-probability 0 whose token history has ContextSize() entries, all -1
// except the last one, which is the blank id.
TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const {
  constexpr int64_t kBlankId = 0;  // always 0

  std::vector<int64_t> initial_ys(model_->ContextSize(), -1);
  initial_ys.back() = kBlankId;

  TransducerKeywordResult ans;
  ans.hyps = Hypotheses({{initial_ys, 0}});
  return ans;
}
// Run beam search over `encoder_out` and report triggered keywords.
//
// @param encoder_out Encoder output; dimension 0 is the batch and
//                    dimension 1 is the number of frames to decode.
// @param ss          One stream per batch entry. The context graph of
//                    ss[b] is used to match keywords for entry b.
// @param result      One entry per batch entry, updated in place. Its size
//                    must equal the batch size, otherwise the process
//                    exits with -1. On return, `hyps`,
//                    `num_trailing_blanks` and `frame_offset` are updated
//                    and, if a keyword fired, so are `tokens`,
//                    `timestamps` and `keyword`.
//
// A keyword fires when the best hypothesis is in a matched state of the
// context graph, more than `num_trailing_blanks_` blanks have followed it,
// and the mean acoustic probability of its tokens reaches the threshold
// stored in the matched state. After firing, the hypotheses of that stream
// are reset to a single blank hypothesis at the graph root.
void TransducerKeywordDecoder::Decode(
    Ort::Value encoder_out, OnlineStream **ss,
    std::vector<TransducerKeywordResult> *result) {
  std::vector<int64_t> encoder_out_shape =
      encoder_out.GetTensorTypeAndShapeInfo().GetShape();

  if (encoder_out_shape[0] != static_cast<int64_t>(result->size())) {
    SHERPA_ONNX_LOGE(
        "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
        static_cast<int32_t>(encoder_out_shape[0]),
        static_cast<int32_t>(result->size()));
    exit(-1);
  }

  int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
  int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
  int32_t vocab_size = model_->VocabSize();
  int32_t context_size = model_->ContextSize();

  std::vector<int64_t> blanks(context_size, -1);
  blanks.back() = 0;  // blank_id is hardcoded to 0

  std::vector<Hypotheses> cur;
  cur.reserve(result->size());
  for (auto &r : *result) {
    cur.push_back(std::move(r.hyps));
  }
  std::vector<Hypothesis> prev;

  for (int32_t t = 0; t != num_frames; ++t) {
    // Due to merging paths with identical token sequences,
    // not all utterances have "num_active_paths" paths.
    auto hyps_row_splits = GetHypsRowSplits(cur);
    int32_t num_hyps =
        hyps_row_splits.back();  // total num hyps for all utterance

    prev.clear();
    for (auto &hyps : cur) {
      for (auto &h : hyps) {
        prev.push_back(std::move(h.second));
      }
    }
    cur.clear();
    cur.reserve(batch_size);

    Ort::Value decoder_input = model_->BuildDecoderInput(prev);
    Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));

    Ort::Value cur_encoder_out =
        GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
    cur_encoder_out =
        Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
    Ort::Value logit =
        model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));

    float *p_logit = logit.GetTensorMutableData<float>();
    LogSoftmax(p_logit, vocab_size, num_hyps);

    // The acoustic logprobs for current frame. They are copied before the
    // hypothesis scores are added below, so that per-token probabilities
    // can still be recovered.
    std::vector<float> logprobs(p_logit, p_logit + vocab_size * num_hyps);

    // now p_logit contains log_softmax output, we rename it to p_logprob
    // to match what it actually contains
    float *p_logprob = p_logit;

    // add log_prob of each hypothesis to p_logprob before taking top_k
    for (int32_t i = 0; i != num_hyps; ++i) {
      float log_prob = prev[i].log_prob;
      for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
        *p_logprob += log_prob;
      }
    }
    p_logprob = p_logit;  // we changed p_logprob in the above for loop

    for (int32_t b = 0; b != batch_size; ++b) {
      int32_t frame_offset = (*result)[b].frame_offset;
      int32_t start = hyps_row_splits[b];
      int32_t end = hyps_row_splits[b + 1];
      auto topk =
          TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);

      Hypotheses hyps;
      for (auto k : topk) {
        int32_t hyp_index = k / vocab_size + start;
        int32_t new_token = k % vocab_size;

        Hypothesis new_hyp = prev[hyp_index];
        float context_score = 0;
        auto context_state = new_hyp.context_state;

        // blank is hardcoded to 0
        // also, it treats unk as blank
        if (new_token != 0 && new_token != unk_id_) {
          new_hyp.ys.push_back(new_token);
          new_hyp.timestamps.push_back(t + frame_offset);
          new_hyp.ys_probs.push_back(
              std::exp(logprobs[hyp_index * vocab_size + new_token]));

          new_hyp.num_trailing_blanks = 0;
          auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
              context_state, new_token);
          context_score = std::get<0>(context_res);
          new_hyp.context_state = std::get<1>(context_res);
          // Start matching from the start state, forget the decoder history.
          if (new_hyp.context_state->token == -1) {
            new_hyp.ys = blanks;
            new_hyp.timestamps.clear();
            new_hyp.ys_probs.clear();
          }
        } else {
          ++new_hyp.num_trailing_blanks;
        }
        new_hyp.log_prob = p_logprob[k] + context_score;
        hyps.Add(std::move(new_hyp));
      }  // for (auto k : topk)

      auto best_hyp = hyps.GetMostProbable(false);

      auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state);
      bool matched = std::get<0>(status);
      const ContextState *matched_state = std::get<1>(status);

      if (matched) {
        const int32_t level = matched_state->level;
        const int32_t length = static_cast<int32_t>(best_hyp.ys_probs.size());

        // The keyword consists of the LAST `level` tokens of the hypothesis
        // (the same slice that is taken from ys/timestamps below), so the
        // probabilities must be read from the tail of ys_probs. ys_probs is
        // only cleared when the graph returns to its start state, so it may
        // hold more than `level` entries. Indexing [1, level] from the
        // front would skip entry 0 and read one past the end whenever
        // length == level.
        if (level > 0 && length >= level) {
          float ys_prob = 0.0f;
          for (int32_t i = 1; i <= level; ++i) {
            ys_prob += best_hyp.ys_probs[length - i];
          }
          ys_prob /= level;

          if (best_hyp.num_trailing_blanks > num_trailing_blanks_ &&
              ys_prob >= matched_state->ac_threshold) {
            auto &r = (*result)[b];
            r.tokens = {best_hyp.ys.end() - level, best_hyp.ys.end()};
            r.timestamps = {best_hyp.timestamps.end() - level,
                            best_hyp.timestamps.end()};
            r.keyword = matched_state->phrase;

            // Restart the search for this stream from a clean state.
            hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}});
          }
        }
      }
      cur.push_back(std::move(hyps));
      p_logprob += (end - start) * vocab_size;
    }  // for (int32_t b = 0; b != batch_size; ++b)
  }

  for (int32_t b = 0; b != batch_size; ++b) {
    auto &hyps = cur[b];
    auto best_hyp = hyps.GetMostProbable(false);
    auto &r = (*result)[b];
    r.hyps = std::move(hyps);
    r.num_trailing_blanks = best_hyp.num_trailing_blanks;
    r.frame_offset += num_frames;
  }
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,62 @@
// sherpa-onnx/csrc/transducer-keyword-decoder.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
namespace sherpa_onnx {
// Per-stream decoding state and output of TransducerKeywordDecoder.
// The decoder reads and updates it in place on every Decode() call.
struct TransducerKeywordResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
/// The decoded token IDs for keywords
std::vector<int64_t> tokens;
/// The triggered keyword
std::string keyword;
/// number of trailing blank frames decoded so far
int32_t num_trailing_blanks = 0;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std::vector<int32_t> timestamps;
// used only in modified beam_search
// (the active hypotheses carried over from one Decode() call to the next)
Hypotheses hyps;
};
// Beam-search decoder for keyword spotting with a transducer model.
class TransducerKeywordDecoder {
public:
// @param model               Transducer model used for decoding. Not owned.
// @param max_active_paths    Number of top-scoring candidates kept per
//                            stream for each decoded frame.
// @param num_trailing_blanks A matched keyword is reported only after more
//                            than this many blanks have followed it.
// @param unk_id              Token id that is treated the same way as blank.
TransducerKeywordDecoder(OnlineTransducerModel *model,
int32_t max_active_paths,
int32_t num_trailing_blanks, int32_t unk_id)
: model_(model),
max_active_paths_(max_active_paths),
num_trailing_blanks_(num_trailing_blanks),
unk_id_(unk_id) {}
// Return the initial result for a new stream.
TransducerKeywordResult GetEmptyResult() const;
// Decode `encoder_out` for a batch of streams.
//
// @param encoder_out Encoder output for the batch.
// @param ss          Array of streams, one per batch entry.
// @param result      One entry per stream; read and updated in place.
void Decode(Ort::Value encoder_out, OnlineStream **ss,
std::vector<TransducerKeywordResult> *result);
private:
OnlineTransducerModel *model_; // Not owned
int32_t max_active_paths_;
int32_t num_trailing_blanks_;
int32_t unk_id_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_

View File

@@ -15,16 +15,31 @@
namespace sherpa_onnx {
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords) {
hotwords->clear();
std::vector<int32_t> tmp;
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *ids,
std::vector<std::string> *phrases,
std::vector<float> *scores,
std::vector<float> *thresholds) {
SHERPA_ONNX_CHECK(ids != nullptr);
ids->clear();
std::vector<int32_t> tmp_ids;
std::vector<float> tmp_scores;
std::vector<float> tmp_thresholds;
std::vector<std::string> tmp_phrases;
std::string line;
std::string word;
bool has_scores = false;
bool has_thresholds = false;
bool has_phrases = false;
while (std::getline(is, line)) {
float score = 0;
float threshold = 0;
std::string phrase = "";
std::istringstream iss(line);
std::vector<std::string> syms;
while (iss >> word) {
if (word.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
@@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
}
}
if (symbol_table.contains(word)) {
int32_t number = symbol_table[word];
tmp.push_back(number);
int32_t id = symbol_table[word];
tmp_ids.push_back(id);
} else {
SHERPA_ONNX_LOGE(
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
"the "
"same line are separated by spaces)",
word.c_str(), line.c_str());
return false;
switch (word[0]) {
case ':': // boosting score for current keyword
score = std::stof(word.substr(1));
has_scores = true;
break;
case '#': // triggering threshold (probability) for current keyword
threshold = std::stof(word.substr(1));
has_thresholds = true;
break;
case '@': // the original keyword string
phrase = word.substr(1);
has_phrases = true;
break;
default:
SHERPA_ONNX_LOGE(
"Cannot find ID for token %s at line: %s. (Hint: words on "
"the same line are separated by spaces)",
word.c_str(), line.c_str());
return false;
}
}
}
hotwords->push_back(std::move(tmp));
ids->push_back(std::move(tmp_ids));
tmp_scores.push_back(score);
tmp_phrases.push_back(phrase);
tmp_thresholds.push_back(threshold);
}
if (scores != nullptr) {
if (has_scores) {
scores->swap(tmp_scores);
} else {
scores->clear();
}
}
if (phrases != nullptr) {
if (has_phrases) {
*phrases = std::move(tmp_phrases);
} else {
phrases->clear();
}
}
if (thresholds != nullptr) {
if (has_thresholds) {
thresholds->swap(tmp_thresholds);
} else {
thresholds->clear();
}
}
return true;
}
// Encode the hotwords read from `is` into token ids.
// Forwards to EncodeBase() and requests only the ids: the phrase, score
// and threshold outputs are passed as nullptr.
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords) {
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
}
// Encode the keywords read from `is` into token ids, together with the
// per-keyword phrase, boosting score and triggering threshold.
// Forwards all output pointers to EncodeBase() unchanged.
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *keywords_id,
std::vector<std::string> *keywords,
std::vector<float> *boost_scores,
std::vector<float> *threshold) {
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
threshold);
}
} // namespace sherpa_onnx

View File

@@ -26,7 +26,32 @@ namespace sherpa_onnx {
* otherwise returns false.
*/
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords);
std::vector<std::vector<int32_t>> *hotwords_id);
/* Encode the keywords in an input stream to be tokens ids.
*
* @param is The input stream, it contains several lines, one hotword for each
* line. For each hotword, the tokens (cjkchar or bpe) are separated
* by spaces, it might contain boosting score (starting with :),
* triggering threshold (starting with #) and keyword string (starting
* with @) too.
* @param symbol_table The tokens table mapping symbols to ids. All the symbols
* in the stream should be in the symbol_table, if not this
* function returns false.
*
* @param keywords_id The encoded ids to be written to.
* @param keywords The original keyword string to be written to.
* @param boost_scores The boosting score for each keyword to be written to.
* @param threshold The triggering threshold for each keyword to be written to.
*
* @return If all the symbols from ``is`` are in the symbol_table, returns true
* otherwise returns false.
*/
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *keywords_id,
std::vector<std::string> *keywords,
std::vector<float> *boost_scores,
std::vector<float> *threshold);
} // namespace sherpa_onnx

View File

@@ -21,6 +21,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-tts.h"
@@ -140,6 +141,73 @@ class SherpaOnnxVad {
VoiceActivityDetector vad_;
};
// Bundles a KeywordSpotter with the single decoding stream it serves, so
// that the JNI entry points only need to carry one native handle.
class SherpaOnnxKws {
 public:
#if __ANDROID_API__ >= 9
  // Create the spotter from `config`, passing `mgr` through to the
  // KeywordSpotter constructor.
  SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)
      : keyword_spotter_(mgr, config),
        stream_(keyword_spotter_.CreateStream()) {}
#endif

  // Create the spotter from `config`.
  explicit SherpaOnnxKws(const KeywordSpotterConfig &config)
      : keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}

  // Feed `n` samples to the stream. The sample rate of the first call is
  // remembered and used later for the tail padding in InputFinished().
  void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
    const bool rate_unknown = (input_sample_rate_ == -1);
    if (rate_unknown) {
      input_sample_rate_ = sample_rate;
    }

    stream_->AcceptWaveform(sample_rate, samples, n);
  }

  // Append 0.6 seconds of zeros (at the remembered sample rate) and mark
  // the input of the stream as finished.
  void InputFinished() const {
    std::vector<float> silence(input_sample_rate_ * 0.6, 0);
    stream_->AcceptWaveform(input_sample_rate_, silence.data(),
                            silence.size());
    stream_->InputFinished();
  }

  // If keywords is an empty string, it just recreates the decoding stream
  // and always returns true.
  // If keywords is not empty, it creates a new decoding stream with
  // the given keywords appended to the default keywords.
  // Returns false if errors occurred when adding keywords (the current
  // stream is kept in that case), true otherwise.
  bool Reset(const std::string &keywords = {}) {
    if (keywords.empty()) {
      stream_ = keyword_spotter_.CreateStream();
      return true;
    }

    auto replacement = keyword_spotter_.CreateStream(keywords);
    if (replacement == nullptr) {
      // Setting the new keywords failed; stream_ is left untouched.
      return false;
    }

    stream_ = std::move(replacement);
    return true;
  }

  // Keyword of the current result of the stream.
  std::string GetKeyword() const {
    return keyword_spotter_.GetResult(stream_.get()).keyword;
  }

  // Tokens of the current result of the stream.
  std::vector<std::string> GetTokens() const {
    return keyword_spotter_.GetResult(stream_.get()).tokens;
  }

  bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }

  void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }

 private:
  // Declared before stream_: the stream is created from the spotter in the
  // constructor initializer lists.
  KeywordSpotter keyword_spotter_;
  std::unique_ptr<OnlineStream> stream_;
  int32_t input_sample_rate_ = -1;  // -1 until the first AcceptWaveform()
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
@@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
return ans;
}
// Convert a Java keyword-spotter config object into a KeywordSpotterConfig.
//
// Reads the decoding fields directly from `config`, then the nested
// `featConfig` and `modelConfig` objects (including `modelConfig.transducer`).
//
// See
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
  // Small readers for the field kinds used below. Each one looks up the
  // field id and reads the value from `obj`.
  const auto read_int = [env](jobject obj, jclass klass, const char *field) {
    jfieldID id = env->GetFieldID(klass, field, "I");
    return env->GetIntField(obj, id);
  };
  const auto read_float = [env](jobject obj, jclass klass, const char *field) {
    jfieldID id = env->GetFieldID(klass, field, "F");
    return env->GetFloatField(obj, id);
  };
  const auto read_object = [env](jobject obj, jclass klass, const char *field,
                                 const char *signature) {
    jfieldID id = env->GetFieldID(klass, field, signature);
    return env->GetObjectField(obj, id);
  };
  // Copies the Java string into a std::string and releases the JNI buffer.
  const auto read_string = [env](jobject obj, jclass klass,
                                 const char *field) {
    jfieldID id = env->GetFieldID(klass, field, "Ljava/lang/String;");
    jstring java_str = static_cast<jstring>(env->GetObjectField(obj, id));
    const char *utf = env->GetStringUTFChars(java_str, nullptr);
    std::string value = utf;
    env->ReleaseStringUTFChars(java_str, utf);
    return value;
  };

  KeywordSpotterConfig ans;
  jclass cls = env->GetObjectClass(config);

  //---------- decoding ----------
  ans.max_active_paths = read_int(config, cls, "maxActivePaths");
  ans.keywords_file = read_string(config, cls, "keywordsFile");
  ans.keywords_score = read_float(config, cls, "keywordsScore");
  ans.keywords_threshold = read_float(config, cls, "keywordsThreshold");
  ans.num_trailing_blanks = read_int(config, cls, "numTrailingBlanks");

  //---------- feat config ----------
  jobject feat_config = read_object(config, cls, "featConfig",
                                    "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
  jclass feat_config_cls = env->GetObjectClass(feat_config);

  ans.feat_config.sampling_rate =
      read_int(feat_config, feat_config_cls, "sampleRate");
  ans.feat_config.feature_dim =
      read_int(feat_config, feat_config_cls, "featureDim");

  //---------- model config ----------
  jobject model_config = read_object(
      config, cls, "modelConfig", "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
  jclass model_config_cls = env->GetObjectClass(model_config);

  // transducer
  jobject transducer_config =
      read_object(model_config, model_config_cls, "transducer",
                  "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
  jclass transducer_config_cls = env->GetObjectClass(transducer_config);

  ans.model_config.transducer.encoder =
      read_string(transducer_config, transducer_config_cls, "encoder");
  ans.model_config.transducer.decoder =
      read_string(transducer_config, transducer_config_cls, "decoder");
  ans.model_config.transducer.joiner =
      read_string(transducer_config, transducer_config_cls, "joiner");

  ans.model_config.tokens =
      read_string(model_config, model_config_cls, "tokens");
  ans.model_config.num_threads =
      read_int(model_config, model_config_cls, "numThreads");

  jfieldID debug_fid = env->GetFieldID(model_config_cls, "debug", "Z");
  ans.model_config.debug = env->GetBooleanField(model_config, debug_fid);

  ans.model_config.provider =
      read_string(model_config, model_config_cls, "provider");
  ans.model_config.model_type =
      read_string(model_config, model_config_cls, "modelType");

  return ans;
}
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
VadModelConfig ans;
@@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
jclass stringClass = env->FindClass("java/lang/String");
// convert C++ list into jni string array
jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
for (int32_t i = 0; i < size; i++) {
// Convert the C++ string to a C string
const char *cstr = tokens[i].c_str();
// Convert the C string to a jstring
jstring jstr = env->NewStringUTF(cstr);
// Set the array element
env->SetObjectArrayElement(result, i, jstr);
}
return result;
}
// Create a native SherpaOnnxKws from a Java config object and return it as
// an opaque handle. On Android the asset manager is forwarded to the
// constructor; a failure to obtain it is only logged.
// The handle must be released with the matching `delete` entry point.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  }
#endif

  auto kws_config = sherpa_onnx::GetKwsConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", kws_config.ToString().c_str());

#if __ANDROID_API__ >= 9
  auto *kws = new sherpa_onnx::SherpaOnnxKws(mgr, kws_config);
#else
  auto *kws = new sherpa_onnx::SherpaOnnxKws(kws_config);
#endif

  return reinterpret_cast<jlong>(kws);
}
// Create a native SherpaOnnxKws from a Java config object, without an
// asset manager, and return it as an opaque handle.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  auto kws_config = sherpa_onnx::GetKwsConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", kws_config.ToString().c_str());

  auto *kws = new sherpa_onnx::SherpaOnnxKws(kws_config);
  return reinterpret_cast<jlong>(kws);
}
// Destroy the native object behind the handle `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *kws = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  delete kws;
}
// Tell whether the stream behind the handle `ptr` is ready to be decoded.
//
// The return type is jboolean, the JNI native type for a Java `boolean`,
// rather than C++ `bool`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *kws = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  return kws->IsReady() ? JNI_TRUE : JNI_FALSE;
}
// Run one decoding step on the stream behind the handle `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Decode();
}
// Feed the samples of a Java float array to the stream behind `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jint sample_rate) {
  auto *kws = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);

  const jsize num_samples = env->GetArrayLength(samples);
  jfloat *data = env->GetFloatArrayElements(samples, nullptr);

  kws->AcceptWaveform(sample_rate, data, num_samples);

  // JNI_ABORT: the samples are only read, so nothing is copied back.
  env->ReleaseFloatArrayElements(samples, data, JNI_ABORT);
}
// Signal that no more audio will be fed to the stream behind `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *kws = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  kws->InputFinished();
}
// Return the keyword of the current result as a Java string.
//
// see
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *kws = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
  auto keyword = kws->GetKeyword();
  return env->NewStringUTF(keyword.c_str());
}
// Recreate the decoding stream behind `ptr`, optionally with extra keywords.
//
// @param keywords Extra keywords; an empty string just recreates the stream.
// @return JNI_TRUE on success. JNI_FALSE if the keywords could not be
//         applied, if `keywords` is null, or if its characters could not be
//         obtained from the JVM; the current stream is kept in those cases.
//
// The return type is jboolean, the JNI native type for a Java `boolean`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
  if (keywords == nullptr) {
    return JNI_FALSE;
  }

  const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
  if (p_keywords == nullptr) {
    // GetStringUTFChars failed; constructing a std::string from a null
    // pointer would be undefined behavior.
    return JNI_FALSE;
  }

  // Copy first so the JNI buffer can be released before doing any work.
  const std::string keywords_str = p_keywords;
  env->ReleaseStringUTFChars(keywords, p_keywords);

  const bool status =
      reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);
  return status ? JNI_TRUE : JNI_FALSE;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
jlong ptr) {
auto tokens =
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();
int32_t size = tokens.size();
jclass stringClass = env->FindClass("java/lang/String");
// convert C++ list into jni string array
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
for (int32_t i = 0; i < size; i++) {
// Convert the C++ string to a C string
const char *cstr = tokens[i].c_str();

View File

@@ -28,9 +28,14 @@ def cli():
)
@click.option(
"--tokens-type",
type=str,
type=click.Choice(
["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
),
required=True,
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
ppinyin means partial pinyin, it splits pinyin into initial and final,
""",
)
@click.option(
"--bpe-model",
@@ -42,14 +47,56 @@ def encode_text(
):
"""
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
Each line in the texts contains the original phrase, it might also contain some
extra items, for example, the boosting score (startting with :), the triggering
threshold (startting with #, only used in keyword spotting task) and the original
phrase (startting with @). Note: the extra items will be kept same in the output.
example input 1 (tokens_type = ppinyin):
小爱同学 :2.0 #0.6 @小爱同学
你好问问 :3.5 @你好问问
小艺小艺 #0.6 @小艺小艺
example output 1:
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
n ǐ h ǎo w èn w èn :3.5 @你好问问
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
example input 2 (tokens_type = bpe):
HELLO WORLD :1.5 #0.4
HI GOOGLE :2.0 #0.8
HEY SIRI #0.35
example output 2:
▁HE LL O ▁WORLD :1.5 #0.4
▁HI ▁GO O G LE :2.0 #0.8
▁HE Y ▁S I RI #0.35
"""
texts = []
# extra information like boosting score (start with :), triggering threshold (start with #)
# original keyword (start with @)
extra_info = []
with open(input, "r", encoding="utf8") as f:
for line in f:
texts.append(line.strip())
extra = []
text = []
toks = line.strip().split()
for tok in toks:
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
extra.append(tok)
else:
text.append(tok)
texts.append(" ".join(text))
extra_info.append(extra)
encoded_texts = text2token(
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
)
with open(output, "w", encoding="utf8") as f:
for txt in encoded_texts:
for i, txt in enumerate(encoded_texts):
txt += extra_info[i]
f.write(" ".join(txt) + "\n")

View File

@@ -6,6 +6,9 @@ from typing import List, Optional, Union
import sentencepiece as spm
from pypinyin import pinyin
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
def text2token(
texts: List[str],
@@ -23,7 +26,9 @@ def text2token(
tokens:
The path of the tokens.txt.
tokens_type:
The valid values are cjkchar, bpe, cjkchar+bpe.
The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
ppinyin means partial pinyin, it splits pinyin into initial and final,
bpe_model:
The path of the bpe model. Only required when tokens_type is bpe or
cjkchar+bpe.
@@ -53,6 +58,24 @@ def text2token(
texts_list = [list("".join(text.split())) for text in texts]
elif tokens_type == "bpe":
texts_list = sp.encode(texts, out_type=str)
elif "pinyin" in tokens_type:
for txt in texts:
py = [x[0] for x in pinyin(txt)]
if "ppinyin" == tokens_type:
res = []
for x in py:
initial = to_initials(x, strict=False)
final = to_finals_tone(x, strict=False)
if initial == "" and final == "":
res.append(x)
else:
if initial != "":
res.append(initial)
if final != "":
res.append(final)
texts_list.append(res)
else:
texts_list.append(py)
else:
assert (
tokens_type == "cjkchar+bpe"