Add timestamps for streaming ASR. (#123)
This commit is contained in:
@@ -8,11 +8,13 @@
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||
@@ -22,16 +24,56 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
using json = nlohmann::json;
|
||||
json j;
|
||||
j["text"] = text;
|
||||
j["tokens"] = tokens;
|
||||
j["start_time"] = start_time;
|
||||
#if 1
|
||||
// This branch chooses number of decimal points to keep in
|
||||
// the return json string
|
||||
std::ostringstream os;
|
||||
os << "[";
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
j["timestamps"] = os.str();
|
||||
#else
|
||||
j["timestamps"] = timestamps;
|
||||
#endif
|
||||
|
||||
j["segment"] = segment;
|
||||
j["is_final"] = is_final;
|
||||
|
||||
return j.dump();
|
||||
}
|
||||
|
||||
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
std::string text;
|
||||
for (auto t : src.tokens) {
|
||||
text += sym_table[t];
|
||||
const SymbolTable &sym_table,
|
||||
int32_t frame_shift_ms,
|
||||
int32_t subsampling_factor) {
|
||||
OnlineRecognizerResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
|
||||
for (auto i : src.tokens) {
|
||||
auto sym = sym_table[i];
|
||||
|
||||
r.text.append(sym);
|
||||
r.tokens.push_back(std::move(sym));
|
||||
}
|
||||
|
||||
OnlineRecognizerResult ans;
|
||||
ans.text = std::move(text);
|
||||
return ans;
|
||||
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||
for (auto t : src.timestamps) {
|
||||
float time = frame_shift_s * t;
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
@@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
|
||||
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
||||
decoder_->StripLeadingBlanks(&decoder_result);
|
||||
|
||||
return Convert(decoder_result, sym_);
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
|
||||
}
|
||||
|
||||
bool IsEndpoint(OnlineStream *s) const {
|
||||
|
||||
Reference in New Issue
Block a user