diff --git a/CMakeLists.txt b/CMakeLists.txt index d84047bf..e825af00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) include(asio) endif() +include(json) + add_subdirectory(sherpa-onnx) if(SHERPA_ONNX_ENABLE_C_API) diff --git a/cmake/json.cmake b/cmake/json.cmake new file mode 100644 index 00000000..3ec935b3 --- /dev/null +++ b/cmake/json.cmake @@ -0,0 +1,44 @@ +function(download_json) + include(FetchContent) + + set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz") + set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz") + set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273") + + # If you don't have access to the Internet, + # please pre-download json + set(possible_file_locations + $ENV{HOME}/Downloads/json-3.11.2.tar.gz + ${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz + ${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz + /tmp/json-3.11.2.tar.gz + /star-fj/fangjun/download/github/json-3.11.2.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(json_URL "${f}") + file(TO_CMAKE_PATH "${json_URL}" json_URL) + set(json_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(json + URL + ${json_URL} + ${json_URL2} + URL_HASH ${json_HASH} + ) + + FetchContent_GetProperties(json) + if(NOT json_POPULATED) + message(STATUS "Downloading json from ${json_URL}") + FetchContent_Populate(json) + endif() + message(STATUS "json is downloaded to ${json_SOURCE_DIR}") + include_directories(${json_SOURCE_DIR}/include) + # Use #include "nlohmann/json.hpp" +endfunction() + +download_json() diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 3ef911ee..cc15487b 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -8,11 +8,13 @@ #include #include +#include #include #include #include #include +#include "nlohmann/json.hpp" #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" @@ -22,16 +24,56 @@ namespace sherpa_onnx { +std::string OnlineRecognizerResult::AsJsonString() const { + using json = nlohmann::json; + json j; + j["text"] = text; + j["tokens"] = tokens; + j["start_time"] = start_time; +#if 1 + // This branch chooses number of decimal points to keep in + // the return json string + std::ostringstream os; + os << "["; + std::string sep = ""; + for (auto t : timestamps) { + os << sep << std::fixed << std::setprecision(2) << t; + sep = ", "; + } + os << "]"; + j["timestamps"] = os.str(); +#else + j["timestamps"] = timestamps; +#endif + + j["segment"] = segment; + j["is_final"] = is_final; + + return j.dump(); +} + static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, - const SymbolTable &sym_table) { - std::string text; - for (auto t : src.tokens) { - text += sym_table[t]; + const SymbolTable &sym_table, + int32_t frame_shift_ms, + int32_t subsampling_factor) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + + for (auto i : src.tokens) { + auto sym = sym_table[i]; + + r.text.append(sym); + r.tokens.push_back(std::move(sym)); } - OnlineRecognizerResult ans; - ans.text = std::move(text); - return ans; + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + return r; } void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -169,7 +211,10 @@ class OnlineRecognizer::Impl { OnlineTransducerDecoderResult decoder_result = s->GetResult(); decoder_->StripLeadingBlanks(&decoder_result); - return Convert(decoder_result, sym_); + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); } bool IsEndpoint(OnlineStream *s) const { diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 521e2f1a..6a59dc3e 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -7,6 +7,7 @@ #include #include +#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" @@ -22,10 +23,45 @@ namespace sherpa_onnx { struct OnlineRecognizerResult { + /// Recognition results. + /// For English, it consists of space separated words. + /// For Chinese, it consists of Chinese words without spaces. + /// Example 1: "hello world" + /// Example 2: "你好世界" std::string text; - // TODO(fangjun): Add a method to return a json string - std::string ToString() const { return text; } + /// Decoded results at the token level. + /// For instance, for BPE-based models it consists of a list of BPE tokens. + std::vector tokens; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + std::vector timestamps; + + /// ID of this segment + /// When an endpoint is detected, it is incremented + int32_t segment = 0; + + /// Starting frame of this segment. + /// When an endpoint is detected, it will change + float start_time = 0; + + /// True if this is the last segment. + bool is_final = false; + + /** Return a json string. + * + * The returned string contains: + * { + * "text": "The recognition result", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "segment": x, + * "start_time": x, + * "is_final": true|false + * } + */ + std::string AsJsonString() const; }; struct OnlineRecognizerConfig { diff --git a/sherpa-onnx/csrc/online-transducer-decoder.cc b/sherpa-onnx/csrc/online-transducer-decoder.cc index 102b358d..7a1e5a43 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-decoder.cc @@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( hyps = other.hyps; + frame_offset = other.frame_offset; + timestamps = other.timestamps; + return *this; } @@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( decoder_out = std::move(other.decoder_out); hyps = std::move(other.hyps); + frame_offset = other.frame_offset; + timestamps = std::move(other.timestamps); + return *this; } diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 592c206c..dcfa363b 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -13,12 +13,18 @@ namespace sherpa_onnx { struct OnlineTransducerDecoderResult { + /// Number of frames after subsampling we have decoded so far + int32_t frame_offset = 0; + /// The decoded token IDs so far std::vector tokens; /// number of trailing blank frames decoded so far int32_t num_trailing_blanks = 0; + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + std::vector timestamps; + // Cache decoder_out for endpointing Ort::Value decoder_out; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index b4b191d8..0df46d32 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode( bool emitted = false; for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { + auto &r = (*result)[i]; auto y = static_cast(std::distance( static_cast(p_logit), std::max_element(static_cast(p_logit), static_cast(p_logit) + vocab_size))); if (y != 0) { emitted = true; - (*result)[i].tokens.push_back(y); - (*result)[i].num_trailing_blanks = 0; + r.tokens.push_back(y); + r.timestamps.push_back(t + r.frame_offset); + r.num_trailing_blanks = 0; } else { - ++(*result)[i].num_trailing_blanks; + ++r.num_trailing_blanks; } } if (emitted) { @@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( } UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result); + + // Update frame_offset + for (auto &r : *result) { + r.frame_offset += num_frames; + } } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index c7c00b5b..bc6f8553 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( std::vector tokens(hyp.ys.begin() + context_size, hyp.ys.end()); r->tokens = std::move(tokens); + r->timestamps = std::move(hyp.timestamps); r->num_trailing_blanks = hyp.num_trailing_blanks; } @@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( float *p_logit = logit.GetTensorMutableData(); for (int32_t b = 0; b < batch_size; ++b) { + int32_t frame_offset = (*result)[b].frame_offset; int32_t start = hyps_num_split[b]; int32_t end = hyps_num_split[b + 1]; LogSoftmax(p_logit, vocab_size, (end - start)); @@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( Hypothesis new_hyp = prev[hyp_index]; if (new_token != 0) { new_hyp.ys.push_back(new_token); + new_hyp.timestamps.push_back(t + frame_offset); new_hyp.num_trailing_blanks = 0; } else { ++new_hyp.num_trailing_blanks; @@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( for (int32_t b = 0; b != batch_size; ++b) { auto &hyps = cur[b]; auto best_hyp = hyps.GetMostProbable(true); + auto &r = (*result)[b]; - (*result)[b].hyps = std::move(hyps); - (*result)[b].tokens = std::move(best_hyp.ys); - (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks; + r.hyps = std::move(hyps); + r.tokens = std::move(best_hyp.ys); + r.num_trailing_blanks = best_hyp.num_trailing_blanks; + r.frame_offset += num_frames; } } diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.cc b/sherpa-onnx/csrc/online-websocket-server-impl.cc index 7b267785..a62bef25 100644 --- a/sherpa-onnx/csrc/online-websocket-server-impl.cc +++ b/sherpa-onnx/csrc/online-websocket-server-impl.cc @@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() { auto result = recognizer_->GetResult(c->s.get()); asio::post(server_->GetConnectionContext(), - [this, hdl = c->hdl, str = result.ToString()]() { + [this, hdl = c->hdl, str = result.AsJsonString()]() { server_->Send(hdl, str); }); active_.erase(c->hdl); diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 12a04744..04fcbeab 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -102,7 +102,7 @@ for a list of pre-trained models to download. recognizer.DecodeStream(s.get()); } - std::string text = recognizer.GetResult(s.get()).text; + std::string text = recognizer.GetResult(s.get()).AsJsonString(); fprintf(stderr, "Done!\n"); diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 8a861e8f..bf0ccdca 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult( sherpa_onnx::OnlineStream *s = reinterpret_cast(s_ptr); sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s); - return env->NewStringUTF(result.ToString().c_str()); + return env->NewStringUTF(result.text.c_str()); } SHERPA_ONNX_EXTERN_C