add blank_penalty for online transducer (#548)
This commit is contained in:
@@ -216,6 +216,18 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"sound_files",
|
"sound_files",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -290,6 +302,7 @@ def main():
|
|||||||
lm_scale=args.lm_scale,
|
lm_scale=args.lm_scale,
|
||||||
hotwords_file=args.hotwords_file,
|
hotwords_file=args.hotwords_file,
|
||||||
hotwords_score=args.hotwords_score,
|
hotwords_score=args.hotwords_score,
|
||||||
|
blank_penalty=args.blank_penalty,
|
||||||
)
|
)
|
||||||
elif args.zipformer2_ctc:
|
elif args.zipformer2_ctc:
|
||||||
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
|
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
|
||||||
|
|||||||
@@ -102,6 +102,17 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@@ -130,6 +141,7 @@ def create_recognizer(args):
|
|||||||
provider=args.provider,
|
provider=args.provider,
|
||||||
hotwords_file=args.hotwords_file,
|
hotwords_file=args.hotwords_file,
|
||||||
hotwords_score=args.hotwords_score,
|
hotwords_score=args.hotwords_score,
|
||||||
|
blank_penalty=args.blank_penalty,
|
||||||
)
|
)
|
||||||
return recognizer
|
return recognizer
|
||||||
|
|
||||||
|
|||||||
@@ -111,6 +111,17 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@@ -136,6 +147,7 @@ def create_recognizer(args):
|
|||||||
provider=args.provider,
|
provider=args.provider,
|
||||||
hotwords_file=args.hotwords_file,
|
hotwords_file=args.hotwords_file,
|
||||||
hotwords_score=args.hotwords_score,
|
hotwords_score=args.hotwords_score,
|
||||||
|
blank_penalty=args.blank_penalty,
|
||||||
)
|
)
|
||||||
return recognizer
|
return recognizer
|
||||||
|
|
||||||
|
|||||||
@@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def add_blank_penalty_args(parser: argparse.ArgumentParser):
    """Register the ``--blank-penalty`` command-line option on ``parser``.

    The option is parsed as a float and defaults to 0.0, so the parsed
    namespace always carries a ``blank_penalty`` attribute.

    Args:
      parser:
        The argument parser to extend; it is modified in place.
    """
    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )
|
||||||
|
|
||||||
def add_endpointing_args(parser: argparse.ArgumentParser):
|
def add_endpointing_args(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -284,6 +296,7 @@ def get_args():
|
|||||||
add_decoding_args(parser)
|
add_decoding_args(parser)
|
||||||
add_endpointing_args(parser)
|
add_endpointing_args(parser)
|
||||||
add_hotwords_args(parser)
|
add_hotwords_args(parser)
|
||||||
|
add_blank_penalty_args(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port",
|
"--port",
|
||||||
@@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
|
|||||||
max_active_paths=args.num_active_paths,
|
max_active_paths=args.num_active_paths,
|
||||||
hotwords_score=args.hotwords_score,
|
hotwords_score=args.hotwords_score,
|
||||||
hotwords_file=args.hotwords_file,
|
hotwords_file=args.hotwords_file,
|
||||||
|
blank_penalty=args.blank_penalty,
|
||||||
enable_endpoint_detection=args.use_endpoint != 0,
|
enable_endpoint_detection=args.use_endpoint != 0,
|
||||||
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
||||||
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
||||||
|
|||||||
@@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale, unk_id_);
|
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
model_.get(), unk_id_);
|
model_.get(), unk_id_, config_.blank_penalty);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config.decoding_method.c_str());
|
config.decoding_method.c_str());
|
||||||
@@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale, unk_id_);
|
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
model_.get(), unk_id_);
|
model_.get(), unk_id_, config_.blank_penalty);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config.decoding_method.c_str());
|
config.decoding_method.c_str());
|
||||||
|
|||||||
@@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
|||||||
"True to enable endpoint detection. False to disable it.");
|
"True to enable endpoint detection. False to disable it.");
|
||||||
po->Register("max-active-paths", &max_active_paths,
|
po->Register("max-active-paths", &max_active_paths,
|
||||||
"beam size used in modified beam search.");
|
"beam size used in modified beam search.");
|
||||||
|
// Register the --blank-penalty option, stored in blank_penalty.
// Adjacent string literals are joined by the compiler with nothing in
// between, so every fragment followed by another must end with a space
// (otherwise the help text reads "...at the costof higher insertions").
po->Register("blank-penalty", &blank_penalty,
             "The penalty applied on blank symbol during decoding. "
             "Note: It is a positive value. "
             "Increasing value will lead to lower deletion at the cost "
             "of higher insertions. "
             "Currently only applicable for transducer models.");
|
||||||
po->Register("hotwords-score", &hotwords_score,
|
po->Register("hotwords-score", &hotwords_score,
|
||||||
"The bonus score for each token in context word/phrase. "
|
"The bonus score for each token in context word/phrase. "
|
||||||
"Used only when decoding_method is modified_beam_search");
|
"Used only when decoding_method is modified_beam_search");
|
||||||
@@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
|||||||
os << "max_active_paths=" << max_active_paths << ", ";
|
os << "max_active_paths=" << max_active_paths << ", ";
|
||||||
os << "hotwords_score=" << hotwords_score << ", ";
|
os << "hotwords_score=" << hotwords_score << ", ";
|
||||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||||
os << "decoding_method=\"" << decoding_method << "\")";
|
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||||
|
os << "blank_penalty=" << blank_penalty << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,6 +83,8 @@ struct OnlineRecognizerConfig {
|
|||||||
float hotwords_score = 1.5;
|
float hotwords_score = 1.5;
|
||||||
std::string hotwords_file;
|
std::string hotwords_file;
|
||||||
|
|
||||||
|
float blank_penalty = 0.0;
|
||||||
|
|
||||||
OnlineRecognizerConfig() = default;
|
OnlineRecognizerConfig() = default;
|
||||||
|
|
||||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||||
@@ -92,7 +94,8 @@ struct OnlineRecognizerConfig {
|
|||||||
bool enable_endpoint,
|
bool enable_endpoint,
|
||||||
const std::string &decoding_method,
|
const std::string &decoding_method,
|
||||||
int32_t max_active_paths,
|
int32_t max_active_paths,
|
||||||
const std::string &hotwords_file, float hotwords_score)
|
const std::string &hotwords_file, float hotwords_score,
|
||||||
|
float blank_penalty)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
lm_config(lm_config),
|
lm_config(lm_config),
|
||||||
@@ -101,7 +104,8 @@ struct OnlineRecognizerConfig {
|
|||||||
decoding_method(decoding_method),
|
decoding_method(decoding_method),
|
||||||
max_active_paths(max_active_paths),
|
max_active_paths(max_active_paths),
|
||||||
hotwords_score(hotwords_score),
|
hotwords_score(hotwords_score),
|
||||||
hotwords_file(hotwords_file) {}
|
hotwords_file(hotwords_file),
|
||||||
|
blank_penalty(blank_penalty) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
Ort::Value logit = model_->RunJoiner(
|
Ort::Value logit = model_->RunJoiner(
|
||||||
std::move(cur_encoder_out), View(&decoder_out));
|
std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
const float *p_logit = logit.GetTensorData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
|
||||||
bool emitted = false;
|
bool emitted = false;
|
||||||
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
||||||
auto &r = (*result)[i];
|
auto &r = (*result)[i];
|
||||||
|
if (blank_penalty_ > 0.0) {
|
||||||
|
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||||
|
}
|
||||||
auto y = static_cast<int32_t>(std::distance(
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
static_cast<const float *>(p_logit),
|
static_cast<const float *>(p_logit),
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ namespace sherpa_onnx {
|
|||||||
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
||||||
int32_t unk_id)
|
int32_t unk_id,
|
||||||
: model_(model), unk_id_(unk_id) {}
|
float blank_penalty)
|
||||||
|
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|
||||||
@@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
|||||||
private:
|
private:
|
||||||
OnlineTransducerModel *model_; // Not owned
|
OnlineTransducerModel *model_; // Not owned
|
||||||
int32_t unk_id_;
|
int32_t unk_id_;
|
||||||
|
float blank_penalty_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
if (blank_penalty_ > 0.0) {
|
||||||
|
// assuming blank id is 0
|
||||||
|
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||||
|
}
|
||||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
|
|
||||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||||
|
|||||||
@@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
||||||
OnlineLM *lm,
|
OnlineLM *lm,
|
||||||
int32_t max_active_paths,
|
int32_t max_active_paths,
|
||||||
float lm_scale, int32_t unk_id)
|
float lm_scale, int32_t unk_id,
|
||||||
|
float blank_penalty)
|
||||||
: model_(model),
|
: model_(model),
|
||||||
lm_(lm),
|
lm_(lm),
|
||||||
max_active_paths_(max_active_paths),
|
max_active_paths_(max_active_paths),
|
||||||
lm_scale_(lm_scale),
|
lm_scale_(lm_scale),
|
||||||
unk_id_(unk_id) {}
|
unk_id_(unk_id),
|
||||||
|
blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|
||||||
@@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
int32_t max_active_paths_;
|
int32_t max_active_paths_;
|
||||||
float lm_scale_; // used only when lm_ is not nullptr
|
float lm_scale_; // used only when lm_ is not nullptr
|
||||||
int32_t unk_id_;
|
int32_t unk_id_;
|
||||||
|
float blank_penalty_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||||
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||||
const OnlineLMConfig &, const EndpointConfig &, bool,
|
const OnlineLMConfig &, const EndpointConfig &, bool,
|
||||||
const std::string &, int32_t, const std::string &, float>(),
|
const std::string &, int32_t, const std::string &, float,
|
||||||
|
float>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||||
py::arg("hotwords_score") = 0)
|
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
@@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||||
|
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class OnlineRecognizer(object):
|
|||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
max_active_paths: int = 4,
|
max_active_paths: int = 4,
|
||||||
hotwords_score: float = 1.5,
|
hotwords_score: float = 1.5,
|
||||||
|
blank_penalty: float = 0.0,
|
||||||
hotwords_file: str = "",
|
hotwords_file: str = "",
|
||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
model_type: str = "",
|
model_type: str = "",
|
||||||
@@ -100,6 +101,8 @@ class OnlineRecognizer(object):
|
|||||||
max_active_paths:
|
max_active_paths:
|
||||||
Use only when decoding_method is modified_beam_search. It specifies
|
Use only when decoding_method is modified_beam_search. It specifies
|
||||||
the maximum number of active paths during beam search.
|
the maximum number of active paths during beam search.
|
||||||
|
blank_penalty:
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
hotwords_file:
|
hotwords_file:
|
||||||
The file containing hotwords, one words/phrases per line, and for each
|
The file containing hotwords, one words/phrases per line, and for each
|
||||||
phrase the bpe/cjkchar are separated by a space.
|
phrase the bpe/cjkchar are separated by a space.
|
||||||
@@ -172,6 +175,7 @@ class OnlineRecognizer(object):
|
|||||||
max_active_paths=max_active_paths,
|
max_active_paths=max_active_paths,
|
||||||
hotwords_score=hotwords_score,
|
hotwords_score=hotwords_score,
|
||||||
hotwords_file=hotwords_file,
|
hotwords_file=hotwords_file,
|
||||||
|
blank_penalty=blank_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user