Add timestamps for streaming ASR. (#123)

This commit is contained in:
Fangjun Kuang
2023-04-19 16:02:37 +08:00
committed by GitHub
parent 4b5d2887cb
commit ad05f52666
11 changed files with 170 additions and 19 deletions

View File

@@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
include(asio) include(asio)
endif() endif()
include(json)
add_subdirectory(sherpa-onnx) add_subdirectory(sherpa-onnx)
if(SHERPA_ONNX_ENABLE_C_API) if(SHERPA_ONNX_ENABLE_C_API)

44
cmake/json.cmake Normal file
View File

@@ -0,0 +1,44 @@
function(download_json)
  include(FetchContent)

  # Primary and mirror download locations for nlohmann/json v3.11.2.
  set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
  set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")
  set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")

  # If you don't have access to the Internet,
  # please pre-download json
  set(possible_file_locations
    $ENV{HOME}/Downloads/json-3.11.2.tar.gz
    ${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
    ${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
    /tmp/json-3.11.2.tar.gz
    /star-fj/fangjun/download/github/json-3.11.2.tar.gz
  )

  # Prefer a local pre-downloaded archive over the network; the first
  # existing candidate wins and the mirror URL is cleared.
  foreach(location IN LISTS possible_file_locations)
    if(EXISTS ${location})
      set(json_URL "${location}")
      file(TO_CMAKE_PATH "${json_URL}" json_URL)
      set(json_URL2)
      break()
    endif()
  endforeach()

  FetchContent_Declare(json
    URL
      ${json_URL}
      ${json_URL2}
    URL_HASH ${json_HASH}
  )

  FetchContent_GetProperties(json)
  if(NOT json_POPULATED)
    message(STATUS "Downloading json from ${json_URL}")
    FetchContent_Populate(json)
  endif()
  message(STATUS "json is downloaded to ${json_SOURCE_DIR}")

  include_directories(${json_SOURCE_DIR}/include)
  # Use #include "nlohmann/json.hpp"
endfunction()

download_json()

View File

@@ -8,11 +8,13 @@
#include <assert.h> #include <assert.h>
#include <algorithm> #include <algorithm>
#include <iomanip>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "nlohmann/json.hpp"
#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
@@ -22,16 +24,56 @@
namespace sherpa_onnx { namespace sherpa_onnx {
std::string OnlineRecognizerResult::AsJsonString() const {
  // Serialize this recognition result as a compact JSON object with the
  // fields: text, tokens, timestamps, segment, start_time, is_final.
  using json = nlohmann::json;
  json j;
  j["text"] = text;
  j["tokens"] = tokens;
  j["start_time"] = start_time;
#if 1
  // This branch chooses number of decimal points to keep in
  // the return json string: each timestamp is formatted with two
  // decimal places, so "timestamps" is a pre-formatted string here.
  std::ostringstream os;
  os << std::fixed << std::setprecision(2);
  os << "[";
  bool first = true;
  for (auto t : timestamps) {
    if (!first) {
      os << ", ";
    }
    os << t;
    first = false;
  }
  os << "]";
  j["timestamps"] = os.str();
#else
  // Alternative: emit timestamps as a native JSON array of floats
  // (full precision, controlled by the json library).
  j["timestamps"] = timestamps;
#endif
  j["segment"] = segment;
  j["is_final"] = is_final;
  return j.dump();
}
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table) { const SymbolTable &sym_table,
std::string text; int32_t frame_shift_ms,
for (auto t : src.tokens) { int32_t subsampling_factor) {
text += sym_table[t]; OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
for (auto i : src.tokens) {
auto sym = sym_table[i];
r.text.append(sym);
r.tokens.push_back(std::move(sym));
} }
OnlineRecognizerResult ans; float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
ans.text = std::move(text); for (auto t : src.timestamps) {
return ans; float time = frame_shift_s * t;
r.timestamps.push_back(time);
}
return r;
} }
void OnlineRecognizerConfig::Register(ParseOptions *po) { void OnlineRecognizerConfig::Register(ParseOptions *po) {
@@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
OnlineTransducerDecoderResult decoder_result = s->GetResult(); OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result); decoder_->StripLeadingBlanks(&decoder_result);
return Convert(decoder_result, sym_); // TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
} }
bool IsEndpoint(OnlineStream *s) const { bool IsEndpoint(OnlineStream *s) const {

View File

@@ -7,6 +7,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
#include "android/asset_manager.h" #include "android/asset_manager.h"
@@ -22,10 +23,45 @@
namespace sherpa_onnx { namespace sherpa_onnx {
struct OnlineRecognizerResult { struct OnlineRecognizerResult {
/// Recognition results.
/// For English, it consists of space separated words.
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
std::string text; std::string text;
// TODO(fangjun): Add a method to return a json string /// Decoded results at the token level.
std::string ToString() const { return text; } /// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
/// ID of this segment
/// When an endpoint is detected, it is incremented
int32_t segment = 0;
/// Starting frame of this segment.
/// When an endpoint is detected, it will change
float start_time = 0;
/// True if this is the last segment.
bool is_final = false;
/** Return a json string.
*
* The returned string contains:
* {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
* }
*/
std::string AsJsonString() const;
}; };
struct OnlineRecognizerConfig { struct OnlineRecognizerConfig {

View File

@@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
hyps = other.hyps; hyps = other.hyps;
frame_offset = other.frame_offset;
timestamps = other.timestamps;
return *this; return *this;
} }
@@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
decoder_out = std::move(other.decoder_out); decoder_out = std::move(other.decoder_out);
hyps = std::move(other.hyps); hyps = std::move(other.hyps);
frame_offset = other.frame_offset;
timestamps = std::move(other.timestamps);
return *this; return *this;
} }

View File

@@ -13,12 +13,18 @@
namespace sherpa_onnx { namespace sherpa_onnx {
struct OnlineTransducerDecoderResult { struct OnlineTransducerDecoderResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
/// The decoded token IDs so far /// The decoded token IDs so far
std::vector<int64_t> tokens; std::vector<int64_t> tokens;
/// number of trailing blank frames decoded so far /// number of trailing blank frames decoded so far
int32_t num_trailing_blanks = 0; int32_t num_trailing_blanks = 0;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std::vector<int32_t> timestamps;
// Cache decoder_out for endpointing // Cache decoder_out for endpointing
Ort::Value decoder_out; Ort::Value decoder_out;

View File

@@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
bool emitted = false; bool emitted = false;
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
auto &r = (*result)[i];
auto y = static_cast<int32_t>(std::distance( auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit), static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit), std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size))); static_cast<const float *>(p_logit) + vocab_size)));
if (y != 0) { if (y != 0) {
emitted = true; emitted = true;
(*result)[i].tokens.push_back(y); r.tokens.push_back(y);
(*result)[i].num_trailing_blanks = 0; r.timestamps.push_back(t + r.frame_offset);
r.num_trailing_blanks = 0;
} else { } else {
++(*result)[i].num_trailing_blanks; ++r.num_trailing_blanks;
} }
} }
if (emitted) { if (emitted) {
@@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
} }
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result); UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
// Update frame_offset
for (auto &r : *result) {
r.frame_offset += num_frames;
}
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
r->tokens = std::move(tokens); r->tokens = std::move(tokens);
r->timestamps = std::move(hyp.timestamps);
r->num_trailing_blanks = hyp.num_trailing_blanks; r->num_trailing_blanks = hyp.num_trailing_blanks;
} }
@@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
float *p_logit = logit.GetTensorMutableData<float>(); float *p_logit = logit.GetTensorMutableData<float>();
for (int32_t b = 0; b < batch_size; ++b) { for (int32_t b = 0; b < batch_size; ++b) {
int32_t frame_offset = (*result)[b].frame_offset;
int32_t start = hyps_num_split[b]; int32_t start = hyps_num_split[b];
int32_t end = hyps_num_split[b + 1]; int32_t end = hyps_num_split[b + 1];
LogSoftmax(p_logit, vocab_size, (end - start)); LogSoftmax(p_logit, vocab_size, (end - start));
@@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Hypothesis new_hyp = prev[hyp_index]; Hypothesis new_hyp = prev[hyp_index];
if (new_token != 0) { if (new_token != 0) {
new_hyp.ys.push_back(new_token); new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0; new_hyp.num_trailing_blanks = 0;
} else { } else {
++new_hyp.num_trailing_blanks; ++new_hyp.num_trailing_blanks;
@@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
for (int32_t b = 0; b != batch_size; ++b) { for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b]; auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true); auto best_hyp = hyps.GetMostProbable(true);
auto &r = (*result)[b];
(*result)[b].hyps = std::move(hyps); r.hyps = std::move(hyps);
(*result)[b].tokens = std::move(best_hyp.ys); r.tokens = std::move(best_hyp.ys);
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks; r.num_trailing_blanks = best_hyp.num_trailing_blanks;
r.frame_offset += num_frames;
} }
} }

View File

@@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() {
auto result = recognizer_->GetResult(c->s.get()); auto result = recognizer_->GetResult(c->s.get());
asio::post(server_->GetConnectionContext(), asio::post(server_->GetConnectionContext(),
[this, hdl = c->hdl, str = result.ToString()]() { [this, hdl = c->hdl, str = result.AsJsonString()]() {
server_->Send(hdl, str); server_->Send(hdl, str);
}); });
active_.erase(c->hdl); active_.erase(c->hdl);

View File

@@ -102,7 +102,7 @@ for a list of pre-trained models to download.
recognizer.DecodeStream(s.get()); recognizer.DecodeStream(s.get());
} }
std::string text = recognizer.GetResult(s.get()).text; std::string text = recognizer.GetResult(s.get()).AsJsonString();
fprintf(stderr, "Done!\n"); fprintf(stderr, "Done!\n");

View File

@@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
sherpa_onnx::OnlineStream *s = sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr); reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s); sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
return env->NewStringUTF(result.ToString().c_str()); return env->NewStringUTF(result.text.c_str());
} }
SHERPA_ONNX_EXTERN_C SHERPA_ONNX_EXTERN_C