offline transducer: treat unk as blank (#1005)
Co-authored-by: chungyi.li <chungyi.li@ailabs.tw>
This commit is contained in:
@@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
config_(config),
|
config_(config),
|
||||||
symbol_table_(config_.model_config.tokens),
|
symbol_table_(config_.model_config.tokens),
|
||||||
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
||||||
|
if (symbol_table_.Contains("<unk>")) {
|
||||||
|
unk_id_ = symbol_table_["<unk>"];
|
||||||
|
}
|
||||||
|
|
||||||
if (config_.decoding_method == "greedy_search") {
|
if (config_.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||||
model_.get(), config_.blank_penalty);
|
model_.get(), unk_id_, config_.blank_penalty);
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OfflineLM::Create(config.lm_config);
|
lm_ = OfflineLM::Create(config.lm_config);
|
||||||
@@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale, config_.blank_penalty);
|
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config_.decoding_method.c_str());
|
config_.decoding_method.c_str());
|
||||||
@@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
symbol_table_(mgr, config_.model_config.tokens),
|
symbol_table_(mgr, config_.model_config.tokens),
|
||||||
model_(std::make_unique<OfflineTransducerModel>(mgr,
|
model_(std::make_unique<OfflineTransducerModel>(mgr,
|
||||||
config_.model_config)) {
|
config_.model_config)) {
|
||||||
|
if (symbol_table_.Contains("<unk>")) {
|
||||||
|
unk_id_ = symbol_table_["<unk>"];
|
||||||
|
}
|
||||||
|
|
||||||
if (config_.decoding_method == "greedy_search") {
|
if (config_.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||||
model_.get(), config_.blank_penalty);
|
model_.get(), unk_id_, config_.blank_penalty);
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||||
@@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale, config_.blank_penalty);
|
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config_.decoding_method.c_str());
|
config_.decoding_method.c_str());
|
||||||
@@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
std::unique_ptr<OfflineTransducerModel> model_;
|
std::unique_ptr<OfflineTransducerModel> model_;
|
||||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||||
std::unique_ptr<OfflineLM> lm_;
|
std::unique_ptr<OfflineLM> lm_;
|
||||||
|
int32_t unk_id_ = -1;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
|||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
static_cast<const float *>(p_logit) + vocab_size)));
|
static_cast<const float *>(p_logit) + vocab_size)));
|
||||||
p_logit += vocab_size;
|
p_logit += vocab_size;
|
||||||
if (y != 0) {
|
// blank id is hardcoded to 0
|
||||||
|
// also, it treats unk as blank
|
||||||
|
if (y != 0 && y != unk_id_) {
|
||||||
ans[i].tokens.push_back(y);
|
ans[i].tokens.push_back(y);
|
||||||
ans[i].timestamps.push_back(t);
|
ans[i].timestamps.push_back(t);
|
||||||
emitted = true;
|
emitted = true;
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ namespace sherpa_onnx {
|
|||||||
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
||||||
|
int32_t unk_id,
|
||||||
float blank_penalty)
|
float blank_penalty)
|
||||||
: model_(model), blank_penalty_(blank_penalty) {}
|
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
@@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineTransducerModel *model_; // Not owned
|
OfflineTransducerModel *model_; // Not owned
|
||||||
|
int32_t unk_id_;
|
||||||
float blank_penalty_;
|
float blank_penalty_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
|
|
||||||
float context_score = 0;
|
float context_score = 0;
|
||||||
auto context_state = new_hyp.context_state;
|
auto context_state = new_hyp.context_state;
|
||||||
if (new_token != 0) {
|
// blank is hardcoded to 0
|
||||||
// blank id is fixed to 0
|
// also, it treats unk as blank
|
||||||
|
if (new_token != 0 && new_token != unk_id_) {
|
||||||
new_hyp.ys.push_back(new_token);
|
new_hyp.ys.push_back(new_token);
|
||||||
new_hyp.timestamps.push_back(t);
|
new_hyp.timestamps.push_back(t);
|
||||||
if (context_graphs[i] != nullptr) {
|
if (context_graphs[i] != nullptr) {
|
||||||
|
|||||||
@@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
|||||||
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
|
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
|
||||||
OfflineLM *lm,
|
OfflineLM *lm,
|
||||||
int32_t max_active_paths,
|
int32_t max_active_paths,
|
||||||
float lm_scale,
|
float lm_scale, int32_t unk_id,
|
||||||
float blank_penalty)
|
float blank_penalty)
|
||||||
: model_(model),
|
: model_(model),
|
||||||
lm_(lm),
|
lm_(lm),
|
||||||
max_active_paths_(max_active_paths),
|
max_active_paths_(max_active_paths),
|
||||||
lm_scale_(lm_scale),
|
lm_scale_(lm_scale),
|
||||||
|
unk_id_(unk_id),
|
||||||
blank_penalty_(blank_penalty) {}
|
blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
@@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
|||||||
|
|
||||||
int32_t max_active_paths_;
|
int32_t max_active_paths_;
|
||||||
float lm_scale_; // used only when lm_ is not nullptr
|
float lm_scale_; // used only when lm_ is not nullptr
|
||||||
|
int32_t unk_id_;
|
||||||
float blank_penalty_;
|
float blank_penalty_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user