Add vad clear api for better performance (#366)

* Add vad clear api for better performance
* rename to make naming consistent and remove macro
* Fix linker error
* Fix Vad.kt
This commit is contained in:
@@ -161,9 +161,9 @@ class MainActivity : AppCompatActivity() {
            val samples = FloatArray(ret) { buffer[it] / 32768.0f }

            vad.acceptWaveform(samples)
            while(!vad.empty()) {vad.pop();}

            val isSpeechDetected = vad.isSpeechDetected()
            vad.clear()

            runOnUiThread {
                onVad(isSpeechDetected)

@@ -46,6 +46,8 @@ class Vad(
    // [start: Int, samples: FloatArray]
    fun front() = front(ptr)

    fun clear() = clear(ptr)

    fun isSpeechDetected(): Boolean = isSpeechDetected(ptr)

    fun reset() = reset(ptr)
@@ -64,6 +66,7 @@ class Vad(
    private external fun acceptWaveform(ptr: Long, samples: FloatArray)
    private external fun empty(ptr: Long): Boolean
    private external fun pop(ptr: Long)
    private external fun clear(ptr: Long)
    private external fun front(ptr: Long): Array<Any>
    private external fun isSpeechDetected(ptr: Long): Boolean
    private external fun reset(ptr: Long)

@@ -493,12 +493,17 @@ int32_t SherpaOnnxVoiceActivityDetectorDetected(
  return p->impl->IsSpeechDetected();
}

SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorPop(
void SherpaOnnxVoiceActivityDetectorPop(
    SherpaOnnxVoiceActivityDetector *p) {
  p->impl->Pop();
}

SHERPA_ONNX_API const SherpaOnnxSpeechSegment *
void SherpaOnnxVoiceActivityDetectorClear(
    SherpaOnnxVoiceActivityDetector *p) {
  p->impl->Clear();
}

const SherpaOnnxSpeechSegment *
SherpaOnnxVoiceActivityDetectorFront(SherpaOnnxVoiceActivityDetector *p) {
  const sherpa_onnx::SpeechSegment &segment = p->impl->Front();


@@ -580,6 +580,10 @@ SherpaOnnxVoiceActivityDetectorDetected(SherpaOnnxVoiceActivityDetector *p);
SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorPop(
    SherpaOnnxVoiceActivityDetector *p);

// Clear current speech segments.
SHERPA_ONNX_API void SherpaOnnxVoiceActivityDetectorClear(
    SherpaOnnxVoiceActivityDetector *p);

// Return the first speech segment.
// The user has to use SherpaOnnxDestroySpeechSegment() to free the returned
// pointer to avoid memory leak.

@@ -76,6 +76,8 @@ class VoiceActivityDetector::Impl {

  void Pop() { segments_.pop(); }

  void Clear() { std::queue<SpeechSegment>().swap(segments_); }

  const SpeechSegment &Front() const { return segments_.front(); }

  void Reset() {
@@ -121,6 +123,8 @@ bool VoiceActivityDetector::Empty() const { return impl_->Empty(); }

void VoiceActivityDetector::Pop() { impl_->Pop(); }

void VoiceActivityDetector::Clear() { impl_->Clear(); }

const SpeechSegment &VoiceActivityDetector::Front() const {
  return impl_->Front();
}

@@ -36,6 +36,7 @@ class VoiceActivityDetector {
  void AcceptWaveform(const float *samples, int32_t n);
  bool Empty() const;
  void Pop();
  void Clear();
  const SpeechSegment &Front() const;

  bool IsSpeechDetected() const;

@@ -124,6 +124,8 @@ class SherpaOnnxVad {

  void Pop() { vad_.Pop(); }

  void Clear() { vad_.Clear();}

  const SpeechSegment &Front() const { return vad_.Front(); }

  bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); }
@@ -556,6 +558,14 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
  model->Pop();
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
  model->Clear();
}

// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {

@@ -551,7 +551,7 @@ class SherpaOnnxVoiceActivityDetectorWrapper {
    return SherpaOnnxVoiceActivityDetectorEmpty(vad) == 1
  }

  func isDetected() -> Bool {
  func isSpeechDetected() -> Bool {
    return SherpaOnnxVoiceActivityDetectorDetected(vad) == 1
  }

@@ -559,6 +559,10 @@ class SherpaOnnxVoiceActivityDetectorWrapper {
    SherpaOnnxVoiceActivityDetectorPop(vad)
  }

  func clear() {
    SherpaOnnxVoiceActivityDetectorClear(vad)
  }

  func front() -> SherpaOnnxSpeechSegmentWrapper {
    let p: UnsafePointer<SherpaOnnxSpeechSegment>? = SherpaOnnxVoiceActivityDetectorFront(vad)
    return SherpaOnnxSpeechSegmentWrapper(p: p)

@@ -174,32 +174,31 @@ func run() {

  var segments: [SpeechSegment] = []

  while array.count > windowSize {
    // todo(fangjun): avoid extra copies here
    vad.acceptWaveform(samples: [Float](array[0..<windowSize]))
    array = [Float](array[windowSize..<array.count])

    while !vad.isEmpty() {
      let s = vad.front()
      vad.pop()
      let result = recognizer.decode(samples: s.samples)

      segments.append(
        SpeechSegment(
          start: Float(s.start) / Float(sampleRate),
          duration: Float(s.samples.count) / Float(sampleRate),
          text: result.text))

      print(segments.last!)

  }
  for offset in stride(from: 0, to: array.count, by: windowSize) {
    let end = min(offset + windowSize, array.count)
    vad.acceptWaveform(samples: [Float](array[offset ..< end]))
  }

  let srt = zip(segments.indices, segments).map { (index, element) in
  var index: Int = 0
  while !vad.isEmpty() {
    let s = vad.front()
    vad.pop()
    let result = recognizer.decode(samples: s.samples)

    segments.append(
      SpeechSegment(
        start: Float(s.start) / Float(sampleRate),
        duration: Float(s.samples.count) / Float(sampleRate),
        text: result.text))

    print(segments.last!)
  }

  let srt: String = zip(segments.indices, segments).map { (index, element) in
    return "\(index+1)\n\(element)"
  }.joined(separator: "\n\n")

  let srtFilename = filePath.stringByDeletingPathExtension + ".srt"
  let srtFilename: String = filePath.stringByDeletingPathExtension + ".srt"
  do {
    try srt.write(to: srtFilename.fileURL, atomically: true, encoding: .utf8)
  } catch {

Reference in New Issue
Block a user