add blank_penalty for offline transducer (#542)
This commit is contained in:
@@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def add_blank_penalty_args(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_args(args):
|
def check_args(args):
|
||||||
if not Path(args.tokens).is_file():
|
if not Path(args.tokens).is_file():
|
||||||
@@ -414,6 +427,7 @@ def get_args():
|
|||||||
add_feature_config_args(parser)
|
add_feature_config_args(parser)
|
||||||
add_decoding_args(parser)
|
add_decoding_args(parser)
|
||||||
add_hotwords_args(parser)
|
add_hotwords_args(parser)
|
||||||
|
add_blank_penalty_args(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port",
|
"--port",
|
||||||
@@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
max_active_paths=args.max_active_paths,
|
max_active_paths=args.max_active_paths,
|
||||||
hotwords_file=args.hotwords_file,
|
hotwords_file=args.hotwords_file,
|
||||||
hotwords_score=args.hotwords_score,
|
hotwords_score=args.hotwords_score,
|
||||||
|
blank_penalty=args.blank_penalty,
|
||||||
provider=args.provider,
|
provider=args.provider,
|
||||||
)
|
)
|
||||||
elif args.paraformer:
|
elif args.paraformer:
|
||||||
|
|||||||
@@ -231,6 +231,18 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoding-method",
|
"--decoding-method",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -335,6 +347,7 @@ def main():
|
|||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
hotwords_file=args.hotwords_file,
|
hotwords_file=args.hotwords_file,
|
||||||
hotwords_score=args.hotwords_score,
|
hotwords_score=args.hotwords_score,
|
||||||
|
blank_penalty=args.blank_penalty,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
elif args.paraformer:
|
elif args.paraformer:
|
||||||
|
|||||||
@@ -177,6 +177,18 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoding-method",
|
"--decoding-method",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
|||||||
sample_rate=args.sample_rate,
|
sample_rate=args.sample_rate,
|
||||||
feature_dim=args.feature_dim,
|
feature_dim=args.feature_dim,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
|
blank_penalty=args.blank_penalty,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
)
|
)
|
||||||
elif args.paraformer:
|
elif args.paraformer:
|
||||||
|
|||||||
@@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Subtract a penalty from the blank-symbol logit of every row.
 *
 * `in` is treated as `h` consecutive rows of `w` values each; for each row
 * the entry at column `blank_idx` is decreased by `blank_penalty`.
 * The buffer is modified in place.
 *
 * @param in             Pointer to the first element of the first row.
 * @param w              Number of values per row (row stride).
 * @param h              Number of rows to process.
 * @param blank_idx      Column of the blank symbol; must lie in [0, w).
 * @param blank_penalty  Amount subtracted from the blank entry of each row.
 *
 * The call is a no-op when `in` is null, when `h` is not positive, or when
 * `blank_idx` falls outside [0, w); this prevents out-of-bounds writes.
 */
template <typename T>
void SubtractBlank(T *in, int32_t w, int32_t h,
                   int32_t blank_idx, float blank_penalty) {
  if (in == nullptr || h <= 0 || blank_idx < 0 || blank_idx >= w) {
    return;
  }

  for (int32_t i = 0; i < h; ++i) {
    in[blank_idx] -= blank_penalty;
    in += w;  // advance to the start of the next row
  }
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
||||||
std::vector<int32_t> vec_index(size);
|
std::vector<int32_t> vec_index(size);
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
if (config_.decoding_method == "greedy_search") {
|
if (config_.decoding_method == "greedy_search") {
|
||||||
decoder_ =
|
decoder_ =
|
||||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||||
|
model_.get(), config_.blank_penalty);
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OfflineLM::Create(config.lm_config);
|
lm_ = OfflineLM::Create(config.lm_config);
|
||||||
@@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale);
|
config_.lm_config.scale, config_.blank_penalty);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config_.decoding_method.c_str());
|
config_.decoding_method.c_str());
|
||||||
@@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
config_.model_config)) {
|
config_.model_config)) {
|
||||||
if (config_.decoding_method == "greedy_search") {
|
if (config_.decoding_method == "greedy_search") {
|
||||||
decoder_ =
|
decoder_ =
|
||||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||||
|
model_.get(), config_.blank_penalty);
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||||
@@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
|
|
||||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
config_.lm_config.scale);
|
config_.lm_config.scale, config_.blank_penalty);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config_.decoding_method.c_str());
|
config_.decoding_method.c_str());
|
||||||
|
|||||||
@@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("max-active-paths", &max_active_paths,
|
po->Register("max-active-paths", &max_active_paths,
|
||||||
"Used only when decoding_method is modified_beam_search");
|
"Used only when decoding_method is modified_beam_search");
|
||||||
|
|
||||||
|
po->Register("blank-penalty", &blank_penalty,
|
||||||
|
"The penalty applied on blank symbol during decoding. "
|
||||||
|
"Note: It is a positive value. "
|
||||||
|
"Increasing value will lead to lower deletion at the cost "
|
||||||
|
"of higher insertions. "
|
||||||
|
"Currently only applicable for transducer models.");
|
||||||
|
|
||||||
po->Register(
|
po->Register(
|
||||||
"hotwords-file", &hotwords_file,
|
"hotwords-file", &hotwords_file,
|
||||||
"The file containing hotwords, one words/phrases per line, and for each"
|
"The file containing hotwords, one words/phrases per line, and for each"
|
||||||
@@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
|||||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||||
os << "max_active_paths=" << max_active_paths << ", ";
|
os << "max_active_paths=" << max_active_paths << ", ";
|
||||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||||
os << "hotwords_score=" << hotwords_score << ")";
|
os << "hotwords_score=" << hotwords_score << ", ";
|
||||||
|
os << "blank_penalty=" << blank_penalty << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ struct OfflineRecognizerConfig {
|
|||||||
std::string hotwords_file;
|
std::string hotwords_file;
|
||||||
float hotwords_score = 1.5;
|
float hotwords_score = 1.5;
|
||||||
|
|
||||||
|
float blank_penalty = 0.0;
|
||||||
|
|
||||||
// only greedy_search is implemented
|
// only greedy_search is implemented
|
||||||
// TODO(fangjun): Implement modified_beam_search
|
// TODO(fangjun): Implement modified_beam_search
|
||||||
|
|
||||||
@@ -46,7 +48,8 @@ struct OfflineRecognizerConfig {
|
|||||||
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
|
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
|
||||||
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||||
const std::string &decoding_method, int32_t max_active_paths,
|
const std::string &decoding_method, int32_t max_active_paths,
|
||||||
const std::string &hotwords_file, float hotwords_score)
|
const std::string &hotwords_file, float hotwords_score,
|
||||||
|
float blank_penalty)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
lm_config(lm_config),
|
lm_config(lm_config),
|
||||||
@@ -54,7 +57,8 @@ struct OfflineRecognizerConfig {
|
|||||||
decoding_method(decoding_method),
|
decoding_method(decoding_method),
|
||||||
max_active_paths(max_active_paths),
|
max_active_paths(max_active_paths),
|
||||||
hotwords_file(hotwords_file),
|
hotwords_file(hotwords_file),
|
||||||
hotwords_score(hotwords_score) {}
|
hotwords_score(hotwords_score),
|
||||||
|
blank_penalty(blank_penalty) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
|||||||
start += n;
|
start += n;
|
||||||
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
|
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
|
||||||
std::move(cur_decoder_out));
|
std::move(cur_decoder_out));
|
||||||
const float *p_logit = logit.GetTensorData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
bool emitted = false;
|
bool emitted = false;
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
if (blank_penalty_ > 0.0) {
|
||||||
|
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||||
|
}
|
||||||
auto y = static_cast<int32_t>(std::distance(
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
static_cast<const float *>(p_logit),
|
static_cast<const float *>(p_logit),
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
|
|||||||
@@ -14,8 +14,10 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model)
|
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
||||||
: model_(model) {}
|
float blank_penalty)
|
||||||
|
: model_(model),
|
||||||
|
blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
@@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
OfflineTransducerModel *model_; // Not owned
|
OfflineTransducerModel *model_; // Not owned
|
||||||
|
float blank_penalty_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
if (blank_penalty_ > 0.0) {
|
||||||
|
// assuming blank id is 0
|
||||||
|
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||||
|
}
|
||||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
|
|
||||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
|||||||
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
|
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
|
||||||
OfflineLM *lm,
|
OfflineLM *lm,
|
||||||
int32_t max_active_paths,
|
int32_t max_active_paths,
|
||||||
float lm_scale)
|
float lm_scale,
|
||||||
|
float blank_penalty)
|
||||||
: model_(model),
|
: model_(model),
|
||||||
lm_(lm),
|
lm_(lm),
|
||||||
max_active_paths_(max_active_paths),
|
max_active_paths_(max_active_paths),
|
||||||
lm_scale_(lm_scale) {}
|
lm_scale_(lm_scale),
|
||||||
|
blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
@@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
|
|||||||
|
|
||||||
int32_t max_active_paths_;
|
int32_t max_active_paths_;
|
||||||
float lm_scale_; // used only when lm_ is not nullptr
|
float lm_scale_; // used only when lm_ is not nullptr
|
||||||
|
float blank_penalty_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
|||||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||||
const OfflineCtcFstDecoderConfig &, const std::string &,
|
const OfflineCtcFstDecoderConfig &, const std::string &,
|
||||||
int32_t, const std::string &, float>(),
|
int32_t, const std::string &, float, float>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OfflineLMConfig(),
|
py::arg("lm_config") = OfflineLMConfig(),
|
||||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||||
py::arg("decoding_method") = "greedy_search",
|
py::arg("decoding_method") = "greedy_search",
|
||||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||||
py::arg("hotwords_score") = 1.5)
|
py::arg("hotwords_score") = 1.5,
|
||||||
|
py::arg("blank_penalty") = 0.0)
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
@@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
|||||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||||
|
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class OfflineRecognizer(object):
|
|||||||
max_active_paths: int = 4,
|
max_active_paths: int = 4,
|
||||||
hotwords_file: str = "",
|
hotwords_file: str = "",
|
||||||
hotwords_score: float = 1.5,
|
hotwords_score: float = 1.5,
|
||||||
|
blank_penalty: float = 0.0,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
provider: str = "cpu",
|
provider: str = "cpu",
|
||||||
):
|
):
|
||||||
@@ -81,6 +82,8 @@ class OfflineRecognizer(object):
|
|||||||
max_active_paths:
|
max_active_paths:
|
||||||
Maximum number of active paths to keep. Used only when
|
Maximum number of active paths to keep. Used only when
|
||||||
decoding_method is modified_beam_search.
|
decoding_method is modified_beam_search.
|
||||||
|
blank_penalty:
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
debug:
|
debug:
|
||||||
True to show debug messages.
|
True to show debug messages.
|
||||||
provider:
|
provider:
|
||||||
@@ -117,6 +120,7 @@ class OfflineRecognizer(object):
|
|||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
hotwords_file=hotwords_file,
|
hotwords_file=hotwords_file,
|
||||||
hotwords_score=hotwords_score,
|
hotwords_score=hotwords_score,
|
||||||
|
blank_penalty=blank_penalty,
|
||||||
)
|
)
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
|
|||||||
Reference in New Issue
Block a user