add blank_penalty for offline transducer (#542)
This commit is contained in:
@@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
def add_blank_penalty_args(parser: argparse.ArgumentParser) -> None:
    """Attach the ``--blank-penalty`` command-line option to ``parser``.

    The option is parsed as a float and falls back to 0.0 when it is not
    given on the command line.

    Args:
      parser:
        The argument parser that receives the new option.
    """
    parser.add_argument(
        "--blank-penalty",
        default=0.0,
        type=float,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )
|
||||
|
||||
|
||||
def check_args(args):
|
||||
if not Path(args.tokens).is_file():
|
||||
@@ -414,6 +427,7 @@ def get_args():
|
||||
add_feature_config_args(parser)
|
||||
add_decoding_args(parser)
|
||||
add_hotwords_args(parser)
|
||||
add_blank_penalty_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
@@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
max_active_paths=args.max_active_paths,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
blank_penalty=args.blank_penalty,
|
||||
provider=args.provider,
|
||||
)
|
||||
elif args.paraformer:
|
||||
|
||||
@@ -231,6 +231,18 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@@ -335,6 +347,7 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
hotwords_file=args.hotwords_file,
|
||||
hotwords_score=args.hotwords_score,
|
||||
blank_penalty=args.blank_penalty,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.paraformer:
|
||||
|
||||
@@ -177,6 +177,18 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
blank_penalty=args.blank_penalty,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.paraformer:
|
||||
|
||||
@@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
|
||||
}
|
||||
}
|
||||
|
||||
// Subtracts `blank_penalty` from one entry per row of a row-major buffer.
//
// @param in            Pointer to the first element; modified in place.
// @param w             Distance, in elements, between the starts of
//                      consecutive rows.
// @param h             Number of rows to visit.
// @param blank_idx     Offset inside each row of the entry to penalize.
// @param blank_penalty Amount subtracted from that entry.
template <typename T>
void SubtractBlank(T *in, int32_t w, int32_t h,
                   int32_t blank_idx, float blank_penalty) {
  T *row = in;
  for (int32_t r = 0; r != h; ++r, row += w) {
    row[blank_idx] -= blank_penalty;
  }
}
|
||||
|
||||
template <class T>
|
||||
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
||||
std::vector<int32_t> vec_index(size);
|
||||
|
||||
@@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
} else if (config_.decoding_method == "modified_beam_search") {
|
||||
if (!config_.lm_config.model.empty()) {
|
||||
lm_ = OfflineLM::Create(config.lm_config);
|
||||
@@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale);
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
@@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
config_.model_config)) {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||
model_.get(), config_.blank_penalty);
|
||||
} else if (config_.decoding_method == "modified_beam_search") {
|
||||
if (!config_.lm_config.model.empty()) {
|
||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||
@@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale);
|
||||
config_.lm_config.scale, config_.blank_penalty);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
|
||||
@@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
|
||||
// Register the --blank-penalty option, stored into `blank_penalty`.
// Adjacent string literals are concatenated verbatim, so every fragment
// except the last must end with a space to keep the help text readable.
po->Register("blank-penalty", &blank_penalty,
             "The penalty applied on blank symbol during decoding. "
             "Note: It is a positive value. "
             "Increasing value will lead to lower deletion at the cost "
             "of higher insertions. "
             "Currently only applicable for transducer models.");
|
||||
|
||||
po->Register(
|
||||
"hotwords-file", &hotwords_file,
|
||||
"The file containing hotwords, one words/phrases per line, and for each"
|
||||
@@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
os << "hotwords_score=" << hotwords_score << ")";
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
os << "blank_penalty=" << blank_penalty << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -37,6 +37,8 @@ struct OfflineRecognizerConfig {
|
||||
std::string hotwords_file;
|
||||
float hotwords_score = 1.5;
|
||||
|
||||
float blank_penalty = 0.0;
|
||||
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
@@ -46,7 +48,8 @@ struct OfflineRecognizerConfig {
|
||||
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
|
||||
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
const std::string &decoding_method, int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score)
|
||||
const std::string &hotwords_file, float hotwords_score,
|
||||
float blank_penalty)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -54,7 +57,8 @@ struct OfflineRecognizerConfig {
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
hotwords_file(hotwords_file),
|
||||
hotwords_score(hotwords_score) {}
|
||||
hotwords_score(hotwords_score),
|
||||
blank_penalty(blank_penalty) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
||||
start += n;
|
||||
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
|
||||
std::move(cur_decoder_out));
|
||||
const float *p_logit = logit.GetTensorData<float>();
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
bool emitted = false;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
if (blank_penalty_ > 0.0) {
|
||||
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||
}
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
|
||||
@@ -14,8 +14,10 @@ namespace sherpa_onnx {
|
||||
|
||||
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||
public:
|
||||
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model)
|
||||
: model_(model) {}
|
||||
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
||||
float blank_penalty)
|
||||
: model_(model),
|
||||
blank_penalty_(blank_penalty) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
@@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||
|
||||
private:
|
||||
OfflineTransducerModel *model_; // Not owned
|
||||
float blank_penalty_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
if (blank_penalty_ > 0.0) {
|
||||
// assuming blank id is 0
|
||||
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||
}
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
|
||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||
|
||||
@@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
||||
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
|
||||
OfflineLM *lm,
|
||||
int32_t max_active_paths,
|
||||
float lm_scale)
|
||||
float lm_scale,
|
||||
float blank_penalty)
|
||||
: model_(model),
|
||||
lm_(lm),
|
||||
max_active_paths_(max_active_paths),
|
||||
lm_scale_(lm_scale) {}
|
||||
lm_scale_(lm_scale),
|
||||
blank_penalty_(blank_penalty) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
@@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
||||
|
||||
int32_t max_active_paths_;
|
||||
float lm_scale_; // used only when lm_ is not nullptr
|
||||
float blank_penalty_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||
const OfflineCtcFstDecoderConfig &, const std::string &,
|
||||
int32_t, const std::string &, float>(),
|
||||
int32_t, const std::string &, float, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 1.5)
|
||||
py::arg("hotwords_score") = 1.5,
|
||||
py::arg("blank_penalty") = 0.0)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
@@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ class OfflineRecognizer(object):
|
||||
max_active_paths: int = 4,
|
||||
hotwords_file: str = "",
|
||||
hotwords_score: float = 1.5,
|
||||
blank_penalty: float = 0.0,
|
||||
debug: bool = False,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
@@ -81,6 +82,8 @@ class OfflineRecognizer(object):
|
||||
max_active_paths:
|
||||
Maximum number of active paths to keep. Used only when
|
||||
decoding_method is modified_beam_search.
|
||||
blank_penalty:
|
||||
The penalty applied on blank symbol during decoding.
|
||||
debug:
|
||||
True to show debug messages.
|
||||
provider:
|
||||
@@ -117,6 +120,7 @@ class OfflineRecognizer(object):
|
||||
decoding_method=decoding_method,
|
||||
hotwords_file=hotwords_file,
|
||||
hotwords_score=hotwords_score,
|
||||
blank_penalty=blank_penalty,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
self.config = recognizer_config
|
||||
|
||||
Reference in New Issue
Block a user