Add timestamps for offline paraformer (#310)
This commit is contained in:
@@ -353,11 +353,22 @@ SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
|
||||
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
|
||||
const_cast<char *>(r->text)[text.size()] = 0;
|
||||
|
||||
if (!result.timestamps.empty()) {
|
||||
r->timestamps = new float[result.timestamps.size()];
|
||||
std::copy(result.timestamps.begin(), result.timestamps.end(),
|
||||
r->timestamps);
|
||||
r->count = result.timestamps.size();
|
||||
} else {
|
||||
r->timestamps = nullptr;
|
||||
r->count = 0;
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
void DestroyOfflineRecognizerResult(
|
||||
const SherpaOnnxOfflineRecognizerResult *r) {
|
||||
delete[] r->text;
|
||||
delete[] r->timestamps;
|
||||
delete r;
|
||||
}
|
||||
|
||||
@@ -408,6 +408,14 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
|
||||
const char *text;
|
||||
|
||||
// Pointer to continuous memory which holds timestamps
|
||||
//
|
||||
// It is NULL if the model does not support timestamps
|
||||
float *timestamps;
|
||||
|
||||
// number of entries in timestamps
|
||||
int32_t count;
|
||||
// TODO(fangjun): Add more fields
|
||||
} SherpaOnnxOfflineRecognizerResult;
|
||||
|
||||
|
||||
@@ -14,6 +14,11 @@ namespace sherpa_onnx {
|
||||
struct OfflineParaformerDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
// it contains the start time of each token in seconds
|
||||
//
|
||||
// len(timestamps) == len(tokens)
|
||||
std::vector<float> timestamps;
|
||||
};
|
||||
|
||||
class OfflineParaformerDecoder {
|
||||
@@ -28,7 +33,8 @@ class OfflineParaformerDecoder {
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineParaformerDecoderResult> Decode(
|
||||
Ort::Value log_probs, Ort::Value token_num) = 0;
|
||||
Ort::Value log_probs, Ort::Value token_num,
|
||||
Ort::Value us_cif_peak = Ort::Value(nullptr)) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -5,13 +5,18 @@
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineParaformerDecoderResult>
|
||||
OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
|
||||
Ort::Value /*token_num*/) {
|
||||
OfflineParaformerGreedySearchDecoder::Decode(
|
||||
Ort::Value log_probs, Ort::Value /*token_num*/,
|
||||
Ort::Value us_cif_peak /*=Ort::Value(nullptr)*/
|
||||
) {
|
||||
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = shape[0];
|
||||
int32_t num_tokens = shape[1];
|
||||
@@ -25,12 +30,43 @@ OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs,
|
||||
for (int32_t k = 0; k != num_tokens; ++k) {
|
||||
auto max_idx = static_cast<int64_t>(
|
||||
std::distance(p, std::max_element(p, p + vocab_size)));
|
||||
if (max_idx == eos_id_) break;
|
||||
if (max_idx == eos_id_) {
|
||||
break;
|
||||
}
|
||||
|
||||
results[i].tokens.push_back(max_idx);
|
||||
|
||||
p += vocab_size;
|
||||
}
|
||||
|
||||
if (us_cif_peak) {
|
||||
int32_t dim = us_cif_peak.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
|
||||
const auto *peak = us_cif_peak.GetTensorData<float>() + i * dim;
|
||||
std::vector<float> timestamps;
|
||||
timestamps.reserve(results[i].tokens.size());
|
||||
|
||||
// 10.0: frameshift is 10 milliseconds
|
||||
// 6: LfrWindowSize
|
||||
// 3: us_cif_peak is upsampled by a factor of 3
|
||||
// 1000: milliseconds to seconds
|
||||
float scale = 10.0 * 6 / 3 / 1000;
|
||||
|
||||
for (int32_t k = 0; k != dim; ++k) {
|
||||
if (peak[k] > 1 - 1e-4) {
|
||||
timestamps.push_back(k * scale);
|
||||
}
|
||||
}
|
||||
timestamps.pop_back();
|
||||
|
||||
if (timestamps.size() == results[i].tokens.size()) {
|
||||
results[i].timestamps = std::move(timestamps);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
|
||||
static_cast<int32_t>(results[i].tokens.size()),
|
||||
static_cast<int32_t>(timestamps.size()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
|
||||
@@ -17,7 +17,8 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
|
||||
: eos_id_(eos_id) {}
|
||||
|
||||
std::vector<OfflineParaformerDecoderResult> Decode(
|
||||
Ort::Value log_probs, Ort::Value /*token_num*/) override;
|
||||
Ort::Value log_probs, Ort::Value token_num,
|
||||
Ort::Value us_cif_peak = Ort::Value(nullptr)) override;
|
||||
|
||||
private:
|
||||
int32_t eos_id_;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
@@ -36,16 +37,13 @@ class OfflineParaformerModel::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return {std::move(out[0]), std::move(out[1])};
|
||||
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
@@ -119,7 +117,7 @@ OfflineParaformerModel::OfflineParaformerModel(AAssetManager *mgr,
|
||||
|
||||
OfflineParaformerModel::~OfflineParaformerModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
|
||||
std::vector<Ort::Value> OfflineParaformerModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -35,13 +34,17 @@ class OfflineParaformerModel {
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int32_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* @return Return a vector containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size)
|
||||
* - token_num: A 1-D tensor of shape (N, T') containing number
|
||||
* of valid tokens in each utterance. Its dtype is int64_t.
|
||||
* If it is a model supporting timestamps, then there are additional two
|
||||
* outputs:
|
||||
* - us_alphas
|
||||
* - us_cif_peak
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length);
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length);
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
|
||||
@@ -31,6 +31,7 @@ static OfflineRecognitionResult Convert(
|
||||
const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps = src.timestamps;
|
||||
|
||||
std::string text;
|
||||
|
||||
@@ -184,7 +185,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
// i.e., -23.025850929940457f
|
||||
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> t{nullptr, nullptr};
|
||||
std::vector<Ort::Value> t;
|
||||
try {
|
||||
t = model_->Forward(std::move(x), std::move(x_length));
|
||||
} catch (const Ort::Exception &ex) {
|
||||
@@ -193,7 +194,13 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
return;
|
||||
}
|
||||
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
std::vector<OfflineParaformerDecoderResult> results;
|
||||
if (t.size() == 2) {
|
||||
results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
|
||||
} else {
|
||||
results =
|
||||
decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3]));
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
|
||||
Reference in New Issue
Block a user