Minor fixes for rknn (#1925)
This commit is contained in:
@@ -99,7 +99,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
int32_t n_text_ctx = model_->TextCtx();
|
||||
|
||||
std::vector<int32_t> predicted_tokens;
|
||||
for (int32_t i = 0; i < n_text_ctx; ++i) {
|
||||
for (int32_t i = 0; i < n_text_ctx / 2; ++i) {
|
||||
if (max_token_id == model_->EOT()) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -65,6 +66,29 @@ bool OnlineModelConfig::Validate() const {
|
||||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
|
||||
return false;
|
||||
}
|
||||
if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".rknn") ||
|
||||
EndsWith(transducer.decoder, ".rknn") ||
|
||||
EndsWith(transducer.joiner, ".rknn"))) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"--provider is %s, which is not rknn, but you pass rknn model "
|
||||
"filenames. encoder: '%s', decoder: '%s', joiner: '%s'",
|
||||
provider_config.provider.c_str(), transducer.encoder.c_str(),
|
||||
transducer.decoder.c_str(), transducer.joiner.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (provider_config.provider == "rknn") {
|
||||
if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".onnx") ||
|
||||
EndsWith(transducer.decoder, ".onnx") ||
|
||||
EndsWith(transducer.joiner, ".onnx"))) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"--provider is rknn, but you pass onnx model "
|
||||
"filenames. encoder: '%s', decoder: '%s', joiner: %'s'",
|
||||
transducer.encoder.c_str(), transducer.decoder.c_str(),
|
||||
transducer.joiner.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!tokens_buf.empty() && FileExists(tokens)) {
|
||||
|
||||
@@ -463,8 +463,10 @@ class OnlineZipformerTransducerModelRknn::Impl {
|
||||
}
|
||||
auto meta = Parse(custom_string);
|
||||
|
||||
for (const auto &p : meta) {
|
||||
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
|
||||
if (config_.debug) {
|
||||
for (const auto &p : meta) {
|
||||
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (meta.count("encoder_dims")) {
|
||||
|
||||
@@ -90,6 +90,8 @@ as the device_name.
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
fprintf(stderr, "Started! Please speak\n");
|
||||
|
||||
int32_t chunk = 0.1 * alsa.GetActualSampleRate();
|
||||
|
||||
std::string last_text;
|
||||
|
||||
@@ -158,8 +158,11 @@ for a list of pre-trained models to download.
|
||||
const float rtf = s.elapsed_seconds / s.duration;
|
||||
|
||||
os << po.GetArg(i) << "\n";
|
||||
os << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
|
||||
<< ", Real time factor (RTF): " << rtf << "\n";
|
||||
os << "Number of threads: " << config.model_config.num_threads << ", "
|
||||
<< std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds
|
||||
<< ", Audio duration (s): " << s.duration
|
||||
<< ", Real time factor (RTF) = " << s.elapsed_seconds << "/"
|
||||
<< s.duration << " = " << rtf << "\n";
|
||||
const auto r = recognizer.GetResult(s.online_stream.get());
|
||||
os << r.text << "\n";
|
||||
os << r.AsJsonString() << "\n\n";
|
||||
|
||||
@@ -699,4 +699,12 @@ std::string ToString(const std::wstring &s) {
|
||||
return converter.to_bytes(s);
|
||||
}
|
||||
|
||||
bool EndsWith(const std::string &haystack, const std::string &needle) {
|
||||
if (needle.size() > haystack.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return std::equal(needle.rbegin(), needle.rend(), haystack.rbegin());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -145,6 +145,8 @@ std::wstring ToWideString(const std::string &s);
|
||||
|
||||
std::string ToString(const std::wstring &s);
|
||||
|
||||
bool EndsWith(const std::string &haystack, const std::string &needle);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_
|
||||
|
||||
Reference in New Issue
Block a user