treat unk as blank (#299)
This commit is contained in:
@@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||||
sym_(config.model_config.tokens),
|
sym_(config.model_config.tokens),
|
||||||
endpoint_(config_.endpoint_config) {
|
endpoint_(config_.endpoint_config) {
|
||||||
|
if (sym_.contains("<unk>")) {
|
||||||
|
unk_id_ = sym_["<unk>"];
|
||||||
|
}
|
||||||
|
|
||||||
if (config.decoding_method == "modified_beam_search") {
|
if (config.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OnlineLM::Create(config.lm_config);
|
lm_ = OnlineLM::Create(config.lm_config);
|
||||||
@@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale);
|
config_.lm_config.scale, unk_id_);
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ =
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
model_.get(), unk_id_);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config.decoding_method.c_str());
|
config.decoding_method.c_str());
|
||||||
@@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||||
sym_(mgr, config.model_config.tokens),
|
sym_(mgr, config.model_config.tokens),
|
||||||
endpoint_(config_.endpoint_config) {
|
endpoint_(config_.endpoint_config) {
|
||||||
|
if (sym_.contains("<unk>")) {
|
||||||
|
unk_id_ = sym_["<unk>"];
|
||||||
|
}
|
||||||
|
|
||||||
if (config.decoding_method == "modified_beam_search") {
|
if (config.decoding_method == "modified_beam_search") {
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale);
|
config_.lm_config.scale, unk_id_);
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ =
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
model_.get(), unk_id_);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config.decoding_method.c_str());
|
config.decoding_method.c_str());
|
||||||
@@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||||
SymbolTable sym_;
|
SymbolTable sym_;
|
||||||
Endpoint endpoint_;
|
Endpoint endpoint_;
|
||||||
|
int32_t unk_id_ = -1;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
static_cast<const float *>(p_logit),
|
static_cast<const float *>(p_logit),
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
static_cast<const float *>(p_logit) + vocab_size)));
|
static_cast<const float *>(p_logit) + vocab_size)));
|
||||||
if (y != 0) {
|
// blank id is hardcoded to 0
|
||||||
|
// also, it treats unk as blank
|
||||||
|
if (y != 0 && y != unk_id_) {
|
||||||
emitted = true;
|
emitted = true;
|
||||||
r.tokens.push_back(y);
|
r.tokens.push_back(y);
|
||||||
r.timestamps.push_back(t + r.frame_offset);
|
r.timestamps.push_back(t + r.frame_offset);
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
|
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
||||||
: model_(model) {}
|
int32_t unk_id)
|
||||||
|
: model_(model), unk_id_(unk_id) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
OnlineTransducerModel *model_; // Not owned
|
OnlineTransducerModel *model_; // Not owned
|
||||||
|
int32_t unk_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
float context_score = 0;
|
float context_score = 0;
|
||||||
auto context_state = new_hyp.context_state;
|
auto context_state = new_hyp.context_state;
|
||||||
|
|
||||||
if (new_token != 0) {
|
// blank is hardcoded to 0
|
||||||
|
// also, it treats unk as blank
|
||||||
|
if (new_token != 0 && new_token != unk_id_) {
|
||||||
new_hyp.ys.push_back(new_token);
|
new_hyp.ys.push_back(new_token);
|
||||||
new_hyp.timestamps.push_back(t + frame_offset);
|
new_hyp.timestamps.push_back(t + frame_offset);
|
||||||
new_hyp.num_trailing_blanks = 0;
|
new_hyp.num_trailing_blanks = 0;
|
||||||
|
|||||||
@@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
||||||
OnlineLM *lm,
|
OnlineLM *lm,
|
||||||
int32_t max_active_paths,
|
int32_t max_active_paths,
|
||||||
float lm_scale)
|
float lm_scale, int32_t unk_id)
|
||||||
: model_(model),
|
: model_(model),
|
||||||
lm_(lm),
|
lm_(lm),
|
||||||
max_active_paths_(max_active_paths),
|
max_active_paths_(max_active_paths),
|
||||||
lm_scale_(lm_scale) {}
|
lm_scale_(lm_scale),
|
||||||
|
unk_id_(unk_id) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
|
|
||||||
int32_t max_active_paths_;
|
int32_t max_active_paths_;
|
||||||
float lm_scale_; // used only when lm_ is not nullptr
|
float lm_scale_; // used only when lm_ is not nullptr
|
||||||
|
int32_t unk_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
Reference in New Issue
Block a user