Add timestamps for streaming ASR. (#123)
This commit is contained in:
@@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
||||
include(asio)
|
||||
endif()
|
||||
|
||||
include(json)
|
||||
|
||||
add_subdirectory(sherpa-onnx)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_C_API)
|
||||
|
||||
44
cmake/json.cmake
Normal file
44
cmake/json.cmake
Normal file
@@ -0,0 +1,44 @@
|
||||
function(download_json)
|
||||
include(FetchContent)
|
||||
|
||||
set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
|
||||
set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")
|
||||
set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download json
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/json-3.11.2.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
|
||||
${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
|
||||
/tmp/json-3.11.2.tar.gz
|
||||
/star-fj/fangjun/download/github/json-3.11.2.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(json_URL "${f}")
|
||||
file(TO_CMAKE_PATH "${json_URL}" json_URL)
|
||||
set(json_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
FetchContent_Declare(json
|
||||
URL
|
||||
${json_URL}
|
||||
${json_URL2}
|
||||
URL_HASH ${json_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(json)
|
||||
if(NOT json_POPULATED)
|
||||
message(STATUS "Downloading json from ${json_URL}")
|
||||
FetchContent_Populate(json)
|
||||
endif()
|
||||
message(STATUS "json is downloaded to ${json_SOURCE_DIR}")
|
||||
include_directories(${json_SOURCE_DIR}/include)
|
||||
# Use #include "nlohmann/json.hpp"
|
||||
endfunction()
|
||||
|
||||
download_json()
|
||||
@@ -8,11 +8,13 @@
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||
@@ -22,16 +24,56 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
using json = nlohmann::json;
|
||||
json j;
|
||||
j["text"] = text;
|
||||
j["tokens"] = tokens;
|
||||
j["start_time"] = start_time;
|
||||
#if 1
|
||||
// This branch chooses number of decimal points to keep in
|
||||
// the return json string
|
||||
std::ostringstream os;
|
||||
os << "[";
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
j["timestamps"] = os.str();
|
||||
#else
|
||||
j["timestamps"] = timestamps;
|
||||
#endif
|
||||
|
||||
j["segment"] = segment;
|
||||
j["is_final"] = is_final;
|
||||
|
||||
return j.dump();
|
||||
}
|
||||
|
||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
std::string text;
|
||||
for (auto t : src.tokens) {
|
||||
text += sym_table[t];
|
||||
const SymbolTable &sym_table,
|
||||
int32_t frame_shift_ms,
|
||||
int32_t subsampling_factor) {
|
||||
OnlineRecognizerResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
|
||||
for (auto i : src.tokens) {
|
||||
auto sym = sym_table[i];
|
||||
|
||||
r.text.append(sym);
|
||||
r.tokens.push_back(std::move(sym));
|
||||
}
|
||||
|
||||
OnlineRecognizerResult ans;
|
||||
ans.text = std::move(text);
|
||||
return ans;
|
||||
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||
for (auto t : src.timestamps) {
|
||||
float time = frame_shift_s * t;
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
@@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
|
||||
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
||||
decoder_->StripLeadingBlanks(&decoder_result);
|
||||
|
||||
return Convert(decoder_result, sym_);
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
|
||||
}
|
||||
|
||||
bool IsEndpoint(OnlineStream *s) const {
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
@@ -22,10 +23,45 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineRecognizerResult {
|
||||
/// Recognition results.
|
||||
/// For English, it consists of space separated words.
|
||||
/// For Chinese, it consists of Chinese words without spaces.
|
||||
/// Example 1: "hello world"
|
||||
/// Example 2: "你好世界"
|
||||
std::string text;
|
||||
|
||||
// TODO(fangjun): Add a method to return a json string
|
||||
std::string ToString() const { return text; }
|
||||
/// Decoded results at the token level.
|
||||
/// For instance, for BPE-based models it consists of a list of BPE tokens.
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
/// timestamps.size() == tokens.size()
|
||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||
std::vector<float> timestamps;
|
||||
|
||||
/// ID of this segment
|
||||
/// When an endpoint is detected, it is incremented
|
||||
int32_t segment = 0;
|
||||
|
||||
/// Starting frame of this segment.
|
||||
/// When an endpoint is detected, it will change
|
||||
float start_time = 0;
|
||||
|
||||
/// True if this is the last segment.
|
||||
bool is_final = false;
|
||||
|
||||
/** Return a json string.
|
||||
*
|
||||
* The returned string contains:
|
||||
* {
|
||||
* "text": "The recognition result",
|
||||
* "tokens": [x, x, x],
|
||||
* "timestamps": [x, x, x],
|
||||
* "segment": x,
|
||||
* "start_time": x,
|
||||
* "is_final": true|false
|
||||
* }
|
||||
*/
|
||||
std::string AsJsonString() const;
|
||||
};
|
||||
|
||||
struct OnlineRecognizerConfig {
|
||||
|
||||
@@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
||||
|
||||
hyps = other.hyps;
|
||||
|
||||
frame_offset = other.frame_offset;
|
||||
timestamps = other.timestamps;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
@@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
|
||||
decoder_out = std::move(other.decoder_out);
|
||||
hyps = std::move(other.hyps);
|
||||
|
||||
frame_offset = other.frame_offset;
|
||||
timestamps = std::move(other.timestamps);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,18 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineTransducerDecoderResult {
|
||||
/// Number of frames after subsampling we have decoded so far
|
||||
int32_t frame_offset = 0;
|
||||
|
||||
/// The decoded token IDs so far
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
/// number of trailing blank frames decoded so far
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
// Cache decoder_out for endpointing
|
||||
Ort::Value decoder_out;
|
||||
|
||||
|
||||
@@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
|
||||
bool emitted = false;
|
||||
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
||||
auto &r = (*result)[i];
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
static_cast<const float *>(p_logit) + vocab_size)));
|
||||
if (y != 0) {
|
||||
emitted = true;
|
||||
(*result)[i].tokens.push_back(y);
|
||||
(*result)[i].num_trailing_blanks = 0;
|
||||
r.tokens.push_back(y);
|
||||
r.timestamps.push_back(t + r.frame_offset);
|
||||
r.num_trailing_blanks = 0;
|
||||
} else {
|
||||
++(*result)[i].num_trailing_blanks;
|
||||
++r.num_trailing_blanks;
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
@@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
}
|
||||
|
||||
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
|
||||
|
||||
// Update frame_offset
|
||||
for (auto &r : *result) {
|
||||
r.frame_offset += num_frames;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
||||
|
||||
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
|
||||
r->tokens = std::move(tokens);
|
||||
r->timestamps = std::move(hyp.timestamps);
|
||||
r->num_trailing_blanks = hyp.num_trailing_blanks;
|
||||
}
|
||||
|
||||
@@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
|
||||
for (int32_t b = 0; b < batch_size; ++b) {
|
||||
int32_t frame_offset = (*result)[b].frame_offset;
|
||||
int32_t start = hyps_num_split[b];
|
||||
int32_t end = hyps_num_split[b + 1];
|
||||
LogSoftmax(p_logit, vocab_size, (end - start));
|
||||
@@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
if (new_token != 0) {
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t + frame_offset);
|
||||
new_hyp.num_trailing_blanks = 0;
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
@@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(true);
|
||||
auto &r = (*result)[b];
|
||||
|
||||
(*result)[b].hyps = std::move(hyps);
|
||||
(*result)[b].tokens = std::move(best_hyp.ys);
|
||||
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||
r.hyps = std::move(hyps);
|
||||
r.tokens = std::move(best_hyp.ys);
|
||||
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||
r.frame_offset += num_frames;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() {
|
||||
auto result = recognizer_->GetResult(c->s.get());
|
||||
|
||||
asio::post(server_->GetConnectionContext(),
|
||||
[this, hdl = c->hdl, str = result.ToString()]() {
|
||||
[this, hdl = c->hdl, str = result.AsJsonString()]() {
|
||||
server_->Send(hdl, str);
|
||||
});
|
||||
active_.erase(c->hdl);
|
||||
|
||||
@@ -102,7 +102,7 @@ for a list of pre-trained models to download.
|
||||
recognizer.DecodeStream(s.get());
|
||||
}
|
||||
|
||||
std::string text = recognizer.GetResult(s.get()).text;
|
||||
std::string text = recognizer.GetResult(s.get()).AsJsonString();
|
||||
|
||||
fprintf(stderr, "Done!\n");
|
||||
|
||||
|
||||
@@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
|
||||
sherpa_onnx::OnlineStream *s =
|
||||
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
|
||||
sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
|
||||
return env->NewStringUTF(result.ToString().c_str());
|
||||
return env->NewStringUTF(result.text.c_str());
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
|
||||
Reference in New Issue
Block a user