add blank_penalty for online transducer (#548)
This commit is contained in:
@@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_);
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(), unk_id_);
|
||||
model_.get(), unk_id_, config_.blank_penalty);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config.decoding_method.c_str());
|
||||
@@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_);
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(), unk_id_);
|
||||
model_.get(), unk_id_, config_.blank_penalty);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config.decoding_method.c_str());
|
||||
|
||||
@@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("blank-penalty", &blank_penalty,
|
||||
"The penalty applied on blank symbol during decoding. "
|
||||
"Note: It is a positive value. "
|
|
"Increasing value will lead to lower deletion at the cost "
|
||||
"of higher insertions. "
|
||||
"Currently only applicable for transducer models.");
|
||||
po->Register("hotwords-score", &hotwords_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
@@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "blank_penalty=" << blank_penalty << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -83,6 +83,8 @@ struct OnlineRecognizerConfig {
|
||||
float hotwords_score = 1.5;
|
||||
std::string hotwords_file;
|
||||
|
||||
float blank_penalty = 0.0;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
@@ -92,7 +94,8 @@ struct OnlineRecognizerConfig {
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score)
|
||||
const std::string &hotwords_file, float hotwords_score,
|
||||
float blank_penalty)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -101,7 +104,8 @@ struct OnlineRecognizerConfig {
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
hotwords_score(hotwords_score),
|
||||
hotwords_file(hotwords_file) {}
|
||||
hotwords_file(hotwords_file),
|
||||
blank_penalty(blank_penalty) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
const float *p_logit = logit.GetTensorData<float>();
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
|
||||
bool emitted = false;
|
||||
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
||||
auto &r = (*result)[i];
|
||||
if (blank_penalty_ > 0.0) {
|
||||
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||
}
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
|
||||
@@ -15,8 +15,9 @@ namespace sherpa_onnx {
|
||||
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||
public:
|
||||
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
||||
int32_t unk_id)
|
||||
: model_(model), unk_id_(unk_id) {}
|
||||
int32_t unk_id,
|
||||
float blank_penalty)
|
||||
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
@@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||
private:
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
int32_t unk_id_;
|
||||
float blank_penalty_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
if (blank_penalty_ > 0.0) {
|
||||
// assuming blank id is 0
|
||||
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||
}
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
|
||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||
|
||||
@@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
||||
OnlineLM *lm,
|
||||
int32_t max_active_paths,
|
||||
float lm_scale, int32_t unk_id)
|
||||
float lm_scale, int32_t unk_id,
|
||||
float blank_penalty)
|
||||
: model_(model),
|
||||
lm_(lm),
|
||||
max_active_paths_(max_active_paths),
|
||||
lm_scale_(lm_scale),
|
||||
unk_id_(unk_id) {}
|
||||
unk_id_(unk_id),
|
||||
blank_penalty_(blank_penalty) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
@@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
int32_t max_active_paths_;
|
||||
float lm_scale_; // used only when lm_ is not nullptr
|
||||
int32_t unk_id_;
|
||||
float blank_penalty_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user