Treat <unk> as blank (#299)

This commit is contained in:
Fangjun Kuang
2023-09-07 15:12:29 +08:00
committed by GitHub
parent ffeff3b8a3
commit a12ebfab22
5 changed files with 29 additions and 12 deletions

View File

@@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
if (config.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OnlineLM::Create(config.lm_config);
@@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
config_.lm_config.scale, unk_id_);
} else if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(), unk_id_);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
@@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
if (config.decoding_method == "modified_beam_search") {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
config_.lm_config.scale, unk_id_);
} else if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(), unk_id_);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
@@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineTransducerDecoder> decoder_;
SymbolTable sym_;
Endpoint endpoint_;
int32_t unk_id_ = -1;
};
} // namespace sherpa_onnx