add blank_penalty for offline transducer (#542)

This commit is contained in:
chiiyeh
2024-01-25 15:00:09 +08:00
committed by GitHub
parent a9e7747736
commit 3bb3849ec5
13 changed files with 97 additions and 14 deletions

View File

@@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
""", """,
) )
def add_blank_penalty_args(parser: argparse.ArgumentParser):
    """Register the ``--blank-penalty`` command-line option on ``parser``.

    Args:
      parser:
        The argument parser that receives the option. It is modified in
        place; nothing is returned.

    The option is parsed as a float and defaults to 0.0.
    """
    # The help text is assembled line by line so that its content does not
    # depend on how this function body is indented.
    blank_penalty_help = (
        "\n"
        "The penalty applied on blank symbol during decoding.\n"
        "Note: It is a positive value that would be applied to logits like\n"
        "this `logits[:, 0] -= blank_penalty` (suppose logits.shape is\n"
        "[batch_size, vocab] and blank id is 0).\n"
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help=blank_penalty_help,
    )
def check_args(args): def check_args(args):
if not Path(args.tokens).is_file(): if not Path(args.tokens).is_file():
@@ -414,6 +427,7 @@ def get_args():
add_feature_config_args(parser) add_feature_config_args(parser)
add_decoding_args(parser) add_decoding_args(parser)
add_hotwords_args(parser) add_hotwords_args(parser)
add_blank_penalty_args(parser)
parser.add_argument( parser.add_argument(
"--port", "--port",
@@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
max_active_paths=args.max_active_paths, max_active_paths=args.max_active_paths,
hotwords_file=args.hotwords_file, hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score, hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
provider=args.provider, provider=args.provider,
) )
elif args.paraformer: elif args.paraformer:

View File

@@ -231,6 +231,18 @@ def get_args():
""", """,
) )
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@@ -335,6 +347,7 @@ def main():
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
hotwords_file=args.hotwords_file, hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score, hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
debug=args.debug, debug=args.debug,
) )
elif args.paraformer: elif args.paraformer:

View File

@@ -177,6 +177,18 @@ def get_args():
""", """,
) )
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
feature_dim=args.feature_dim, feature_dim=args.feature_dim,
decoding_method=args.decoding_method, decoding_method=args.decoding_method,
blank_penalty=args.blank_penalty,
debug=args.debug, debug=args.debug,
) )
elif args.paraformer: elif args.paraformer:

View File

@@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
} }
} }
// Subtracts blank_penalty from one entry in each of h consecutive rows.
//
// The buffer is walked in steps of w elements; within each step the element
// at offset blank_idx is decreased by blank_penalty. All other elements are
// left untouched.
//
// @param in            Pointer to the first element. Modified in place.
// @param w             Number of elements between the starts of two
//                      consecutive rows.
// @param h             Number of rows to process.
// @param blank_idx     Offset inside each row of the entry to penalize.
// @param blank_penalty Value subtracted from the selected entry of each row.
template <typename T>
void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx,
                   float blank_penalty) {
  T *row = in;
  for (int32_t r = 0; r != h; ++r, row += w) {
    row[blank_idx] -= blank_penalty;
  }
}
template <class T> template <class T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
std::vector<int32_t> vec_index(size); std::vector<int32_t> vec_index(size);

View File

@@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
} }
if (config_.decoding_method == "greedy_search") { if (config_.decoding_method == "greedy_search") {
decoder_ = decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
} else if (config_.decoding_method == "modified_beam_search") { } else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) { if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(config.lm_config); lm_ = OfflineLM::Create(config.lm_config);
@@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths, model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale); config_.lm_config.scale, config_.blank_penalty);
} else { } else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s", SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str()); config_.decoding_method.c_str());
@@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
config_.model_config)) { config_.model_config)) {
if (config_.decoding_method == "greedy_search") { if (config_.decoding_method == "greedy_search") {
decoder_ = decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
} else if (config_.decoding_method == "modified_beam_search") { } else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) { if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(mgr, config.lm_config); lm_ = OfflineLM::Create(mgr, config.lm_config);
@@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths, model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale); config_.lm_config.scale, config_.blank_penalty);
} else { } else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s", SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str()); config_.decoding_method.c_str());

View File

@@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("max-active-paths", &max_active_paths, po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search"); "Used only when decoding_method is modified_beam_search");
// Binds the --blank-penalty command-line option to the blank_penalty member.
// Adjacent string literals are concatenated by the compiler, so each piece
// must carry its own trailing space to keep words from running together.
po->Register("blank-penalty", &blank_penalty,
             "The penalty applied on blank symbol during decoding. "
             "Note: It is a positive value. "
             "Increasing value will lead to lower deletion at the cost "
             "of higher insertions. "
             "Currently only applicable for transducer models.");
po->Register( po->Register(
"hotwords-file", &hotwords_file, "hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each" "The file containing hotwords, one words/phrases per line, and for each"
@@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "decoding_method=\"" << decoding_method << "\", "; os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ", "; os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", "; os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "hotwords_score=" << hotwords_score << ")"; os << "hotwords_score=" << hotwords_score << ", ";
os << "blank_penalty=" << blank_penalty << ")";
return os.str(); return os.str();
} }

View File

@@ -37,6 +37,8 @@ struct OfflineRecognizerConfig {
std::string hotwords_file; std::string hotwords_file;
float hotwords_score = 1.5; float hotwords_score = 1.5;
float blank_penalty = 0.0;
// only greedy_search is implemented // only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search // TODO(fangjun): Implement modified_beam_search
@@ -46,7 +48,8 @@ struct OfflineRecognizerConfig {
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
const std::string &decoding_method, int32_t max_active_paths, const std::string &decoding_method, int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score) const std::string &hotwords_file, float hotwords_score,
float blank_penalty)
: feat_config(feat_config), : feat_config(feat_config),
model_config(model_config), model_config(model_config),
lm_config(lm_config), lm_config(lm_config),
@@ -54,7 +57,8 @@ struct OfflineRecognizerConfig {
decoding_method(decoding_method), decoding_method(decoding_method),
max_active_paths(max_active_paths), max_active_paths(max_active_paths),
hotwords_file(hotwords_file), hotwords_file(hotwords_file),
hotwords_score(hotwords_score) {} hotwords_score(hotwords_score),
blank_penalty(blank_penalty) {}
void Register(ParseOptions *po); void Register(ParseOptions *po);
bool Validate() const; bool Validate() const;

View File

@@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
start += n; start += n;
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
std::move(cur_decoder_out)); std::move(cur_decoder_out));
const float *p_logit = logit.GetTensorData<float>(); float *p_logit = logit.GetTensorMutableData<float>();
bool emitted = false; bool emitted = false;
for (int32_t i = 0; i != n; ++i) { for (int32_t i = 0; i != n; ++i) {
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}
auto y = static_cast<int32_t>(std::distance( auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit), static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit), std::max_element(static_cast<const float *>(p_logit),

View File

@@ -14,8 +14,10 @@ namespace sherpa_onnx {
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
public: public:
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model) explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
: model_(model) {} float blank_penalty)
: model_(model),
blank_penalty_(blank_penalty) {}
std::vector<OfflineTransducerDecoderResult> Decode( std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length, Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
private: private:
OfflineTransducerModel *model_; // Not owned OfflineTransducerModel *model_; // Not owned
float blank_penalty_;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>(); float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty_ > 0.0) {
// assuming blank id is 0
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
}
LogSoftmax(p_logit, vocab_size, num_hyps); LogSoftmax(p_logit, vocab_size, num_hyps);
// now p_logit contains log_softmax output, we rename it to p_logprob // now p_logit contains log_softmax output, we rename it to p_logprob

View File

@@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
OfflineLM *lm, OfflineLM *lm,
int32_t max_active_paths, int32_t max_active_paths,
float lm_scale) float lm_scale,
float blank_penalty)
: model_(model), : model_(model),
lm_(lm), lm_(lm),
max_active_paths_(max_active_paths), max_active_paths_(max_active_paths),
lm_scale_(lm_scale) {} lm_scale_(lm_scale),
blank_penalty_(blank_penalty) {}
std::vector<OfflineTransducerDecoderResult> Decode( std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length, Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
int32_t max_active_paths_; int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr float lm_scale_; // used only when lm_ is not nullptr
float blank_penalty_;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def(py::init<const OfflineFeatureExtractorConfig &, .def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &, const OfflineModelConfig &, const OfflineLMConfig &,
const OfflineCtcFstDecoderConfig &, const std::string &, const OfflineCtcFstDecoderConfig &, const std::string &,
int32_t, const std::string &, float>(), int32_t, const std::string &, float, float>(),
py::arg("feat_config"), py::arg("model_config"), py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(), py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search", py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5) py::arg("hotwords_score") = 1.5,
py::arg("blank_penalty") = 0.0)
.def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config) .def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("lm_config", &PyClass::lm_config)
@@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score) .def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

View File

@@ -48,6 +48,7 @@ class OfflineRecognizer(object):
max_active_paths: int = 4, max_active_paths: int = 4,
hotwords_file: str = "", hotwords_file: str = "",
hotwords_score: float = 1.5, hotwords_score: float = 1.5,
blank_penalty: float = 0.0,
debug: bool = False, debug: bool = False,
provider: str = "cpu", provider: str = "cpu",
): ):
@@ -81,6 +82,8 @@ class OfflineRecognizer(object):
max_active_paths: max_active_paths:
Maximum number of active paths to keep. Used only when Maximum number of active paths to keep. Used only when
decoding_method is modified_beam_search. decoding_method is modified_beam_search.
blank_penalty:
The penalty applied on blank symbol during decoding.
debug: debug:
True to show debug messages. True to show debug messages.
provider: provider:
@@ -117,6 +120,7 @@ class OfflineRecognizer(object):
decoding_method=decoding_method, decoding_method=decoding_method,
hotwords_file=hotwords_file, hotwords_file=hotwords_file,
hotwords_score=hotwords_score, hotwords_score=hotwords_score,
blank_penalty=blank_penalty,
) )
self.recognizer = _Recognizer(recognizer_config) self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config self.config = recognizer_config