treat unk as blank (#299)

This commit is contained in:
Fangjun Kuang
2023-09-07 15:12:29 +08:00
committed by GitHub
parent ffeff3b8a3
commit a12ebfab22
5 changed files with 29 additions and 12 deletions

View File

@@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)), model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens), sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
if (config.decoding_method == "modified_beam_search") { if (config.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) { if (!config_.lm_config.model.empty()) {
lm_ = OnlineLM::Create(config.lm_config); lm_ = OnlineLM::Create(config.lm_config);
@@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths, model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale); config_.lm_config.scale, unk_id_);
} else if (config.decoding_method == "greedy_search") { } else if (config.decoding_method == "greedy_search") {
decoder_ = decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); model_.get(), unk_id_);
} else { } else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s", SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str()); config.decoding_method.c_str());
@@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(mgr, config.model_config)), model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens), sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
if (config.decoding_method == "modified_beam_search") { if (config.decoding_method == "modified_beam_search") {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths, model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale); config_.lm_config.scale, unk_id_);
} else if (config.decoding_method == "greedy_search") { } else if (config.decoding_method == "greedy_search") {
decoder_ = decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); model_.get(), unk_id_);
} else { } else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s", SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str()); config.decoding_method.c_str());
@@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineTransducerDecoder> decoder_; std::unique_ptr<OnlineTransducerDecoder> decoder_;
SymbolTable sym_; SymbolTable sym_;
Endpoint endpoint_; Endpoint endpoint_;
int32_t unk_id_ = -1;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
static_cast<const float *>(p_logit), static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit), std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size))); static_cast<const float *>(p_logit) + vocab_size)));
if (y != 0) { // blank id is hardcoded to 0
// also, it treats unk as blank
if (y != 0 && y != unk_id_) {
emitted = true; emitted = true;
r.tokens.push_back(y); r.tokens.push_back(y);
r.timestamps.push_back(t + r.frame_offset); r.timestamps.push_back(t + r.frame_offset);

View File

@@ -14,8 +14,9 @@ namespace sherpa_onnx {
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
public: public:
explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
: model_(model) {} int32_t unk_id)
: model_(model), unk_id_(unk_id) {}
OnlineTransducerDecoderResult GetEmptyResult() const override; OnlineTransducerDecoderResult GetEmptyResult() const override;
@@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
private: private:
OnlineTransducerModel *model_; // Not owned OnlineTransducerModel *model_; // Not owned
int32_t unk_id_;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
float context_score = 0; float context_score = 0;
auto context_state = new_hyp.context_state; auto context_state = new_hyp.context_state;
if (new_token != 0) { // blank is hardcoded to 0
// also, it treats unk as blank
if (new_token != 0 && new_token != unk_id_) {
new_hyp.ys.push_back(new_token); new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset); new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0; new_hyp.num_trailing_blanks = 0;

View File

@@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
OnlineLM *lm, OnlineLM *lm,
int32_t max_active_paths, int32_t max_active_paths,
float lm_scale) float lm_scale, int32_t unk_id)
: model_(model), : model_(model),
lm_(lm), lm_(lm),
max_active_paths_(max_active_paths), max_active_paths_(max_active_paths),
lm_scale_(lm_scale) {} lm_scale_(lm_scale),
unk_id_(unk_id) {}
OnlineTransducerDecoderResult GetEmptyResult() const override; OnlineTransducerDecoderResult GetEmptyResult() const override;
@@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
int32_t max_active_paths_; int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr float lm_scale_; // used only when lm_ is not nullptr
int32_t unk_id_;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx