Set is_final and start_time for online websocket server. (#342)
* Set is_final and start_time for online websocket server. * Convert timestamps to a json array
This commit is contained in:
@@ -174,8 +174,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
|
||||
include(asio)
|
||||
endif()
|
||||
|
||||
include(json)
|
||||
|
||||
add_subdirectory(sherpa-onnx)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_C_API)
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
function(download_json)
|
||||
include(FetchContent)
|
||||
|
||||
set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
|
||||
set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")
|
||||
set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")
|
||||
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download json
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/json-3.11.2.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
|
||||
${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
|
||||
/tmp/json-3.11.2.tar.gz
|
||||
/star-fj/fangjun/download/github/json-3.11.2.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
if(EXISTS ${f})
|
||||
set(json_URL "${f}")
|
||||
file(TO_CMAKE_PATH "${json_URL}" json_URL)
|
||||
message(STATUS "Found local downloaded json: ${json_URL}")
|
||||
set(json_URL2)
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
FetchContent_Declare(json
|
||||
URL
|
||||
${json_URL}
|
||||
${json_URL2}
|
||||
URL_HASH ${json_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(json)
|
||||
if(NOT json_POPULATED)
|
||||
message(STATUS "Downloading json from ${json_URL}")
|
||||
FetchContent_Populate(json)
|
||||
endif()
|
||||
message(STATUS "json is downloaded to ${json_SOURCE_DIR}")
|
||||
include_directories(${json_SOURCE_DIR}/include)
|
||||
# Use #include "nlohmann/json.hpp"
|
||||
endfunction()
|
||||
|
||||
download_json()
|
||||
@@ -28,197 +28,6 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::string FixInvalidUtf8(const std::string &s) {
|
||||
int32_t s_size = s.size();
|
||||
|
||||
std::string ans;
|
||||
ans.reserve(s_size);
|
||||
|
||||
for (int32_t i = 0; i < s_size;) {
|
||||
uint8_t c = s[i];
|
||||
if (c < 0x80) {
|
||||
// valid
|
||||
ans.append(1, c);
|
||||
++i;
|
||||
continue;
|
||||
} else if ((c >= 0xc0) && (c < 0xe0)) {
|
||||
// beginning of two bytes
|
||||
if ((i + 1) > (s_size - 1)) {
|
||||
// no subsequent byte. invalid!
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
// valid 2-byte utf-8
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
i += 2;
|
||||
continue;
|
||||
} else if ((c >= 0xe0) && (c < 0xf0)) {
|
||||
// beginning of 3 bytes
|
||||
if ((i + 2) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
i += 3;
|
||||
continue;
|
||||
} else if ((c >= 0xf0) && (c < 0xf8)) {
|
||||
// 4 bytes
|
||||
if ((i + 3) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
i += 4;
|
||||
continue;
|
||||
} else if ((c >= 0xf8) && (c < 0xfc)) {
|
||||
// 5 bytes
|
||||
if ((i + 4) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next4 = s[i + 4];
|
||||
if (!(next4 >= 0x80 && next4 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
ans.append(1, next4);
|
||||
i += 5;
|
||||
continue;
|
||||
} else if ((c >= 0xfc) && (c < 0xfe)) {
|
||||
// 6 bytes
|
||||
if ((i + 5) > (s_size - 1)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next = s[i + 1];
|
||||
if (!(next >= 0x80 && next < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next2 = s[i + 2];
|
||||
if (!(next2 >= 0x80 && next2 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next3 = s[i + 3];
|
||||
if (!(next3 >= 0x80 && next3 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next4 = s[i + 4];
|
||||
if (!(next4 >= 0x80 && next4 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
uint8_t next5 = s[i + 5];
|
||||
if (!(next5 >= 0x80 && next5 < 0xc0)) {
|
||||
// invalid
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
ans.append(1, c);
|
||||
ans.append(1, next);
|
||||
ans.append(1, next2);
|
||||
ans.append(1, next3);
|
||||
ans.append(1, next4);
|
||||
ans.append(1, next5);
|
||||
i += 6;
|
||||
continue;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
@@ -235,19 +44,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
|
||||
r.tokens.push_back(s);
|
||||
}
|
||||
|
||||
// TODO(fangjun): Fix the following error in offline-stream.cc
|
||||
//
|
||||
// j["text"] = text;
|
||||
|
||||
// libc++abi: terminating with uncaught exception of type
|
||||
// nlohmann::json_abi_v3_11_2::detail::type_error:
|
||||
// [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
|
||||
|
||||
#if 0
|
||||
r.text = FixInvalidUtf8(text);
|
||||
#else
|
||||
r.text = text;
|
||||
#endif
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -267,14 +267,14 @@ std::string OfflineRecognitionResult::AsJsonString() const {
|
||||
<< "timestamps"
|
||||
<< "\""
|
||||
<< ": ";
|
||||
os << "\"[";
|
||||
os << "[";
|
||||
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]\", ";
|
||||
os << "], ";
|
||||
|
||||
os << "\""
|
||||
<< "tokens"
|
||||
|
||||
@@ -28,9 +28,10 @@ namespace sherpa_onnx {
|
||||
|
||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table,
|
||||
int32_t frame_shift_ms,
|
||||
float frame_shift_ms,
|
||||
int32_t subsampling_factor,
|
||||
int32_t segment) {
|
||||
int32_t segment,
|
||||
int32_t frames_since_start) {
|
||||
OnlineRecognizerResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
@@ -49,6 +50,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
}
|
||||
|
||||
r.segment = segment;
|
||||
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
||||
|
||||
return r;
|
||||
}
|
||||
@@ -216,7 +218,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetCurrentSegment());
|
||||
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
|
||||
}
|
||||
|
||||
bool IsEndpoint(OnlineStream *s) const override {
|
||||
|
||||
@@ -14,37 +14,61 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
using json = nlohmann::json;
|
||||
json j;
|
||||
j["text"] = text;
|
||||
j["tokens"] = tokens;
|
||||
j["start_time"] = start_time;
|
||||
#if 1
|
||||
// This branch chooses number of decimal points to keep in
|
||||
// the return json string
|
||||
std::ostringstream os;
|
||||
os << "{";
|
||||
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
|
||||
os << "\"segment\":" << segment << ", ";
|
||||
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
|
||||
os << "\"text\""
|
||||
<< ": ";
|
||||
os << "\"" << text << "\""
|
||||
<< ", ";
|
||||
|
||||
os << "\""
|
||||
<< "timestamps"
|
||||
<< "\""
|
||||
<< ": ";
|
||||
os << "[";
|
||||
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
os << "\""
|
||||
<< "tokens"
|
||||
<< "\""
|
||||
<< ":";
|
||||
os << "[";
|
||||
|
||||
sep = "";
|
||||
auto oldFlags = os.flags();
|
||||
for (const auto &t : tokens) {
|
||||
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
|
||||
os << sep << "\""
|
||||
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
|
||||
<< ">"
|
||||
<< "\"";
|
||||
os.flags(oldFlags);
|
||||
} else {
|
||||
os << sep << "\"" << t << "\"";
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
j["timestamps"] = os.str();
|
||||
#else
|
||||
j["timestamps"] = timestamps;
|
||||
#endif
|
||||
os << "}";
|
||||
|
||||
j["segment"] = segment;
|
||||
j["is_final"] = is_final;
|
||||
|
||||
return j.dump();
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
|
||||
@@ -44,11 +44,11 @@ struct OnlineRecognizerResult {
|
||||
/// When an endpoint is detected, it is incremented
|
||||
int32_t segment = 0;
|
||||
|
||||
/// Starting frame of this segment.
|
||||
/// Starting time of this segment.
|
||||
/// When an endpoint is detected, it will change
|
||||
float start_time = 0;
|
||||
|
||||
/// True if this is the last segment.
|
||||
/// True if the end of this segment is reached
|
||||
bool is_final = false;
|
||||
|
||||
/** Return a json string.
|
||||
|
||||
@@ -43,6 +43,8 @@ class OnlineStream::Impl {
|
||||
|
||||
int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
|
||||
|
||||
int32_t GetNumFramesSinceStart() const { return start_frame_index_; }
|
||||
|
||||
int32_t &GetCurrentSegment() { return segment_; }
|
||||
|
||||
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
|
||||
@@ -126,6 +128,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() {
|
||||
return impl_->GetNumProcessedFrames();
|
||||
}
|
||||
|
||||
int32_t OnlineStream::GetNumFramesSinceStart() const {
|
||||
return impl_->GetNumFramesSinceStart();
|
||||
}
|
||||
|
||||
int32_t &OnlineStream::GetCurrentSegment() {
|
||||
return impl_->GetCurrentSegment();
|
||||
}
|
||||
|
||||
@@ -66,7 +66,9 @@ class OnlineStream {
|
||||
// Initially, it is 0. It is always less than NumFramesReady().
|
||||
//
|
||||
// The returned reference is valid as long as this object is alive.
|
||||
int32_t &GetNumProcessedFrames();
|
||||
int32_t &GetNumProcessedFrames(); // It's reset after calling Reset()
|
||||
|
||||
int32_t GetNumFramesSinceStart() const;
|
||||
|
||||
int32_t &GetCurrentSegment();
|
||||
|
||||
|
||||
@@ -195,9 +195,14 @@ void OnlineWebsocketDecoder::Decode() {
|
||||
for (auto c : c_vec) {
|
||||
auto result = recognizer_->GetResult(c->s.get());
|
||||
if (recognizer_->IsEndpoint(c->s.get())) {
|
||||
result.is_final = true;
|
||||
recognizer_->Reset(c->s.get());
|
||||
}
|
||||
|
||||
if (!recognizer_->IsReady(c->s.get()) && c->eof) {
|
||||
result.is_final = true;
|
||||
}
|
||||
|
||||
asio::post(server_->GetConnectionContext(),
|
||||
[this, hdl = c->hdl, str = result.AsJsonString()]() {
|
||||
server_->Send(hdl, str);
|
||||
|
||||
Reference in New Issue
Block a user