Add timestamps for streaming ASR. (#123)
This commit is contained in:
@@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
|||||||
include(asio)
|
include(asio)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
include(json)
|
||||||
|
|
||||||
add_subdirectory(sherpa-onnx)
|
add_subdirectory(sherpa-onnx)
|
||||||
|
|
||||||
if(SHERPA_ONNX_ENABLE_C_API)
|
if(SHERPA_ONNX_ENABLE_C_API)
|
||||||
|
|||||||
44
cmake/json.cmake
Normal file
44
cmake/json.cmake
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# Make nlohmann/json 3.11.2 available to the build.
#
# The archive is taken from the first pre-downloaded copy found on disk;
# otherwise it is fetched from the network (primary URL first, then the
# mirror). Its include directory is then added via include_directories().
function(download_json)
  include(FetchContent)

  # Checksum that the archive has to match, wherever it comes from.
  set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")

  # Network locations: primary URL and a mirror.
  set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
  set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")

  # No Internet access? Pre-download the archive into one of these places.
  set(local_archives
    $ENV{HOME}/Downloads/json-3.11.2.tar.gz
    ${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
    ${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
    /tmp/json-3.11.2.tar.gz
    /star-fj/fangjun/download/github/json-3.11.2.tar.gz
  )

  # The first local copy that exists replaces both network URLs.
  foreach(archive IN LISTS local_archives)
    if(NOT EXISTS "${archive}")
      continue()
    endif()

    file(TO_CMAKE_PATH "${archive}" json_URL)
    unset(json_URL2)
    break()
  endforeach()

  FetchContent_Declare(json
    URL
      ${json_URL}
      ${json_URL2}
    URL_HASH ${json_HASH}
  )

  # Populate only once, even if this file is included several times.
  FetchContent_GetProperties(json)
  if(NOT json_POPULATED)
    message(STATUS "Downloading json from ${json_URL}")
    FetchContent_Populate(json)
  endif()
  message(STATUS "json is downloaded to ${json_SOURCE_DIR}")

  # Lets sources write: #include "nlohmann/json.hpp"
  include_directories(${json_SOURCE_DIR}/include)
endfunction()

download_json()
|
||||||
@@ -8,11 +8,13 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <iomanip>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "nlohmann/json.hpp"
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||||
@@ -22,16 +24,56 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||||
|
using json = nlohmann::json;
|
||||||
|
json j;
|
||||||
|
j["text"] = text;
|
||||||
|
j["tokens"] = tokens;
|
||||||
|
j["start_time"] = start_time;
|
||||||
|
#if 1
|
||||||
|
// This branch chooses number of decimal points to keep in
|
||||||
|
// the return json string
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "[";
|
||||||
|
std::string sep = "";
|
||||||
|
for (auto t : timestamps) {
|
||||||
|
os << sep << std::fixed << std::setprecision(2) << t;
|
||||||
|
sep = ", ";
|
||||||
|
}
|
||||||
|
os << "]";
|
||||||
|
j["timestamps"] = os.str();
|
||||||
|
#else
|
||||||
|
j["timestamps"] = timestamps;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
j["segment"] = segment;
|
||||||
|
j["is_final"] = is_final;
|
||||||
|
|
||||||
|
return j.dump();
|
||||||
|
}
|
||||||
|
|
||||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||||
const SymbolTable &sym_table) {
|
const SymbolTable &sym_table,
|
||||||
std::string text;
|
int32_t frame_shift_ms,
|
||||||
for (auto t : src.tokens) {
|
int32_t subsampling_factor) {
|
||||||
text += sym_table[t];
|
OnlineRecognizerResult r;
|
||||||
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
r.timestamps.reserve(src.tokens.size());
|
||||||
|
|
||||||
|
for (auto i : src.tokens) {
|
||||||
|
auto sym = sym_table[i];
|
||||||
|
|
||||||
|
r.text.append(sym);
|
||||||
|
r.tokens.push_back(std::move(sym));
|
||||||
}
|
}
|
||||||
|
|
||||||
OnlineRecognizerResult ans;
|
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||||
ans.text = std::move(text);
|
for (auto t : src.timestamps) {
|
||||||
return ans;
|
float time = frame_shift_s * t;
|
||||||
|
r.timestamps.push_back(time);
|
||||||
|
}
|
||||||
|
|
||||||
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||||
@@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
|
|||||||
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
||||||
decoder_->StripLeadingBlanks(&decoder_result);
|
decoder_->StripLeadingBlanks(&decoder_result);
|
||||||
|
|
||||||
return Convert(decoder_result, sym_);
|
// TODO(fangjun): Remember to change these constants if needed
|
||||||
|
int32_t frame_shift_ms = 10;
|
||||||
|
int32_t subsampling_factor = 4;
|
||||||
|
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsEndpoint(OnlineStream *s) const {
|
bool IsEndpoint(OnlineStream *s) const {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
#include "android/asset_manager.h"
|
#include "android/asset_manager.h"
|
||||||
@@ -22,10 +23,45 @@
|
|||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
struct OnlineRecognizerResult {
|
struct OnlineRecognizerResult {
|
||||||
|
/// Recognition results.
|
||||||
|
/// For English, it consists of space separated words.
|
||||||
|
/// For Chinese, it consists of Chinese words without spaces.
|
||||||
|
/// Example 1: "hello world"
|
||||||
|
/// Example 2: "你好世界"
|
||||||
std::string text;
|
std::string text;
|
||||||
|
|
||||||
// TODO(fangjun): Add a method to return a json string
|
/// Decoded results at the token level.
|
||||||
std::string ToString() const { return text; }
|
/// For instance, for BPE-based models it consists of a list of BPE tokens.
|
||||||
|
std::vector<std::string> tokens;
|
||||||
|
|
||||||
|
/// timestamps.size() == tokens.size()
|
||||||
|
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||||
|
std::vector<float> timestamps;
|
||||||
|
|
||||||
|
/// ID of this segment
|
||||||
|
/// When an endpoint is detected, it is incremented
|
||||||
|
int32_t segment = 0;
|
||||||
|
|
||||||
|
/// Starting frame of this segment.
|
||||||
|
/// When an endpoint is detected, it will change
|
||||||
|
float start_time = 0;
|
||||||
|
|
||||||
|
/// True if this is the last segment.
|
||||||
|
bool is_final = false;
|
||||||
|
|
||||||
|
/** Return a json string.
|
||||||
|
*
|
||||||
|
* The returned string contains:
|
||||||
|
* {
|
||||||
|
* "text": "The recognition result",
|
||||||
|
* "tokens": [x, x, x],
|
||||||
|
* "timestamps": [x, x, x],
|
||||||
|
* "segment": x,
|
||||||
|
* "start_time": x,
|
||||||
|
* "is_final": true|false
|
||||||
|
* }
|
||||||
|
*/
|
||||||
|
std::string AsJsonString() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct OnlineRecognizerConfig {
|
struct OnlineRecognizerConfig {
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
|||||||
|
|
||||||
hyps = other.hyps;
|
hyps = other.hyps;
|
||||||
|
|
||||||
|
frame_offset = other.frame_offset;
|
||||||
|
timestamps = other.timestamps;
|
||||||
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
|||||||
decoder_out = std::move(other.decoder_out);
|
decoder_out = std::move(other.decoder_out);
|
||||||
hyps = std::move(other.hyps);
|
hyps = std::move(other.hyps);
|
||||||
|
|
||||||
|
frame_offset = other.frame_offset;
|
||||||
|
timestamps = std::move(other.timestamps);
|
||||||
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,18 @@
|
|||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
struct OnlineTransducerDecoderResult {
|
struct OnlineTransducerDecoderResult {
|
||||||
|
/// Number of frames after subsampling we have decoded so far
|
||||||
|
int32_t frame_offset = 0;
|
||||||
|
|
||||||
/// The decoded token IDs so far
|
/// The decoded token IDs so far
|
||||||
std::vector<int64_t> tokens;
|
std::vector<int64_t> tokens;
|
||||||
|
|
||||||
/// number of trailing blank frames decoded so far
|
/// number of trailing blank frames decoded so far
|
||||||
int32_t num_trailing_blanks = 0;
|
int32_t num_trailing_blanks = 0;
|
||||||
|
|
||||||
|
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||||
|
std::vector<int32_t> timestamps;
|
||||||
|
|
||||||
// Cache decoder_out for endpointing
|
// Cache decoder_out for endpointing
|
||||||
Ort::Value decoder_out;
|
Ort::Value decoder_out;
|
||||||
|
|
||||||
|
|||||||
@@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
|
|
||||||
bool emitted = false;
|
bool emitted = false;
|
||||||
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
||||||
|
auto &r = (*result)[i];
|
||||||
auto y = static_cast<int32_t>(std::distance(
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
static_cast<const float *>(p_logit),
|
static_cast<const float *>(p_logit),
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
static_cast<const float *>(p_logit) + vocab_size)));
|
static_cast<const float *>(p_logit) + vocab_size)));
|
||||||
if (y != 0) {
|
if (y != 0) {
|
||||||
emitted = true;
|
emitted = true;
|
||||||
(*result)[i].tokens.push_back(y);
|
r.tokens.push_back(y);
|
||||||
(*result)[i].num_trailing_blanks = 0;
|
r.timestamps.push_back(t + r.frame_offset);
|
||||||
|
r.num_trailing_blanks = 0;
|
||||||
} else {
|
} else {
|
||||||
++(*result)[i].num_trailing_blanks;
|
++r.num_trailing_blanks;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (emitted) {
|
if (emitted) {
|
||||||
@@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
|
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
|
||||||
|
|
||||||
|
// Update frame_offset
|
||||||
|
for (auto &r : *result) {
|
||||||
|
r.frame_offset += num_frames;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
|||||||
|
|
||||||
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
|
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
|
||||||
r->tokens = std::move(tokens);
|
r->tokens = std::move(tokens);
|
||||||
|
r->timestamps = std::move(hyp.timestamps);
|
||||||
r->num_trailing_blanks = hyp.num_trailing_blanks;
|
r->num_trailing_blanks = hyp.num_trailing_blanks;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
|
||||||
for (int32_t b = 0; b < batch_size; ++b) {
|
for (int32_t b = 0; b < batch_size; ++b) {
|
||||||
|
int32_t frame_offset = (*result)[b].frame_offset;
|
||||||
int32_t start = hyps_num_split[b];
|
int32_t start = hyps_num_split[b];
|
||||||
int32_t end = hyps_num_split[b + 1];
|
int32_t end = hyps_num_split[b + 1];
|
||||||
LogSoftmax(p_logit, vocab_size, (end - start));
|
LogSoftmax(p_logit, vocab_size, (end - start));
|
||||||
@@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
Hypothesis new_hyp = prev[hyp_index];
|
Hypothesis new_hyp = prev[hyp_index];
|
||||||
if (new_token != 0) {
|
if (new_token != 0) {
|
||||||
new_hyp.ys.push_back(new_token);
|
new_hyp.ys.push_back(new_token);
|
||||||
|
new_hyp.timestamps.push_back(t + frame_offset);
|
||||||
new_hyp.num_trailing_blanks = 0;
|
new_hyp.num_trailing_blanks = 0;
|
||||||
} else {
|
} else {
|
||||||
++new_hyp.num_trailing_blanks;
|
++new_hyp.num_trailing_blanks;
|
||||||
@@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
for (int32_t b = 0; b != batch_size; ++b) {
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
auto &hyps = cur[b];
|
auto &hyps = cur[b];
|
||||||
auto best_hyp = hyps.GetMostProbable(true);
|
auto best_hyp = hyps.GetMostProbable(true);
|
||||||
|
auto &r = (*result)[b];
|
||||||
|
|
||||||
(*result)[b].hyps = std::move(hyps);
|
r.hyps = std::move(hyps);
|
||||||
(*result)[b].tokens = std::move(best_hyp.ys);
|
r.tokens = std::move(best_hyp.ys);
|
||||||
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
|
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||||
|
r.frame_offset += num_frames;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() {
|
|||||||
auto result = recognizer_->GetResult(c->s.get());
|
auto result = recognizer_->GetResult(c->s.get());
|
||||||
|
|
||||||
asio::post(server_->GetConnectionContext(),
|
asio::post(server_->GetConnectionContext(),
|
||||||
[this, hdl = c->hdl, str = result.ToString()]() {
|
[this, hdl = c->hdl, str = result.AsJsonString()]() {
|
||||||
server_->Send(hdl, str);
|
server_->Send(hdl, str);
|
||||||
});
|
});
|
||||||
active_.erase(c->hdl);
|
active_.erase(c->hdl);
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ for a list of pre-trained models to download.
|
|||||||
recognizer.DecodeStream(s.get());
|
recognizer.DecodeStream(s.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string text = recognizer.GetResult(s.get()).text;
|
std::string text = recognizer.GetResult(s.get()).AsJsonString();
|
||||||
|
|
||||||
fprintf(stderr, "Done!\n");
|
fprintf(stderr, "Done!\n");
|
||||||
|
|
||||||
|
|||||||
@@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
|
|||||||
sherpa_onnx::OnlineStream *s =
|
sherpa_onnx::OnlineStream *s =
|
||||||
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
|
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
|
||||||
sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
|
sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
|
||||||
return env->NewStringUTF(result.ToString().c_str());
|
return env->NewStringUTF(result.text.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_EXTERN_C
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
|||||||
Reference in New Issue
Block a user