Add timestamps for streaming ASR. (#123)

This commit is contained in:
Fangjun Kuang
2023-04-19 16:02:37 +08:00
committed by GitHub
parent 4b5d2887cb
commit ad05f52666
11 changed files with 170 additions and 19 deletions

View File

@@ -8,11 +8,13 @@
#include <assert.h>
#include <algorithm>
#include <iomanip>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "nlohmann/json.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
@@ -22,16 +24,56 @@
namespace sherpa_onnx {
std::string OnlineRecognizerResult::AsJsonString() const {
using json = nlohmann::json;
json j;
j["text"] = text;
j["tokens"] = tokens;
j["start_time"] = start_time;
#if 1
// This branch chooses number of decimal points to keep in
// the return json string
std::ostringstream os;
os << "[";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
sep = ", ";
}
os << "]";
j["timestamps"] = os.str();
#else
j["timestamps"] = timestamps;
#endif
j["segment"] = segment;
j["is_final"] = is_final;
return j.dump();
}
// Builds an OnlineRecognizerResult from a transducer decoder result.
//
// NOTE(review): this span sits inside a rendered diff hunk and appears to
// interleave the removed two-parameter version of Convert() with the added
// four-parameter version (there are two parameter-list tails and two
// `return` statements). The lines are kept verbatim; the comments below mark
// which side each group presumably belongs to -- TODO confirm against the
// committed file.
//
// Four-parameter version, as far as the visible lines show:
//   @param src                 Decoder output; `src.tokens` and
//                              `src.timestamps` are read.
//   @param sym_table           Indexed with each entry of `src.tokens` to
//                              obtain the symbol for that token.
//   @param frame_shift_ms      Presumably the feature frame shift in
//                              milliseconds (it is divided by 1000 below).
//   @param subsampling_factor  Multiplier applied to the frame shift;
//                              presumably the model's subsampling factor.
//   @return A result whose `text` is the concatenation of all symbols,
//           whose `tokens` holds the individual symbols, and whose
//           `timestamps` holds `frame_shift_s * t` for each `t` in
//           `src.timestamps`.
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
// Presumably the removed (old) side: only the concatenated text was built.
const SymbolTable &sym_table) {
std::string text;
for (auto t : src.tokens) {
text += sym_table[t];
// Presumably the added (new) side starts here: two extra parameters are used
// to turn the decoder's timestamps into times.
const SymbolTable &sym_table,
int32_t frame_shift_ms,
int32_t subsampling_factor) {
OnlineRecognizerResult r;
// Both vectors are reserved using the token count.
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
for (auto i : src.tokens) {
auto sym = sym_table[i];
// Append to the running text first; `sym` is moved from on the next line.
r.text.append(sym);
r.tokens.push_back(std::move(sym));
}
// Presumably the removed (old) side: tail of the previous implementation.
OnlineRecognizerResult ans;
ans.text = std::move(text);
return ans;
// Presumably the added (new) side: the per-unit duration is
// frame_shift_ms / 1000 (seconds, judging by the variable name) scaled by
// subsampling_factor.
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
// Scale each decoder timestamp by the per-unit duration computed above.
float time = frame_shift_s * t;
r.timestamps.push_back(time);
}
return r;
}
void OnlineRecognizerConfig::Register(ParseOptions *po) {
@@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);
return Convert(decoder_result, sym_);
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
}
bool IsEndpoint(OnlineStream *s) const {