Add Dart API for streaming ASR (#933)

This commit is contained in:
Fangjun Kuang
2024-05-30 12:21:09 +08:00
committed by GitHub
parent 909148fe42
commit 49d66ec358
15 changed files with 887 additions and 16 deletions

View File

@@ -2,9 +2,8 @@
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import 'package:flutter/material.dart';
import "./speaker_identification_test.dart";
import "./vad_test.dart";
import './home.dart';
import './vad.dart';
import './streaming_asr.dart';
import './info.dart';
void main() {
@@ -20,7 +19,7 @@ class MyApp extends StatelessWidget {
theme: ThemeData(
primarySwatch: Colors.blue,
),
home: const MyHomePage(title: 'Next-gen Kaldi: VAD demo'),
home: const MyHomePage(title: 'Next-gen Kaldi Demo'),
);
}
}
@@ -35,7 +34,8 @@ class MyHomePage extends StatefulWidget {
class _MyHomePageState extends State<MyHomePage> {
int _currentIndex = 0;
final List<Widget> _tabs = [
HomeScreen(),
StreamingAsrScreen(),
VadScreen(),
InfoScreen(),
];
@override
@@ -52,10 +52,15 @@ class _MyHomePageState extends State<MyHomePage> {
_currentIndex = index;
});
},
// https://www.xiconeditor.com/
items: [
BottomNavigationBarItem(
icon: Icon(Icons.home),
label: 'Home',
icon: new Image.asset("assets/streaming-asr.ico"),
label: '',
),
BottomNavigationBarItem(
icon: new Image.asset("assets/vad.ico"),
label: '',
),
BottomNavigationBarItem(
icon: Icon(Icons.info),

View File

@@ -0,0 +1,259 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'dart:async';
import 'package:flutter/foundation.dart';
import 'package:flutter/material.dart';
import 'package:path/path.dart' as p;
import 'package:path_provider/path_provider.dart';
import 'package:record/record.dart';
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './utils.dart';
import './streaming_transducer_asr_test.dart'; // TODO(fangjun): remove it
/// Creates a [sherpa_onnx.OnlineRecognizer] backed by the bundled
/// bilingual (zh-en) streaming Zipformer transducer model.
///
/// The model files ship as Flutter assets; each one is copied to a real
/// filesystem path first because the native sherpa-onnx library loads
/// models by path, not through the asset bundle.
Future<sherpa_onnx.OnlineRecognizer> createOnlineRecognizer() async {
  // Single source of truth for the bundled model directory.
  const modelDir =
      'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20';

  final encoder = await copyAssetFile(
      src: '$modelDir/encoder-epoch-99-avg-1.int8.onnx', dst: 'encoder.onnx');
  final decoder = await copyAssetFile(
      src: '$modelDir/decoder-epoch-99-avg-1.onnx', dst: 'decoder.onnx');
  final joiner = await copyAssetFile(
      src: '$modelDir/joiner-epoch-99-avg-1.int8.onnx', dst: 'joiner.onnx');
  final tokens =
      await copyAssetFile(src: '$modelDir/tokens.txt', dst: 'tokens.txt');

  final modelConfig = sherpa_onnx.OnlineModelConfig(
    transducer: sherpa_onnx.OnlineTransducerModelConfig(
      encoder: encoder,
      decoder: decoder,
      joiner: joiner,
    ),
    tokens: tokens,
    modelType: 'zipformer',
  );

  final config = sherpa_onnx.OnlineRecognizerConfig(model: modelConfig);
  return sherpa_onnx.OnlineRecognizer(config);
}
/// Tab page hosting the real-time (streaming) speech-recognition demo.
///
/// All of the interesting behavior lives in [_StreamingAsrScreenState];
/// this widget is only the stateful shell.
class StreamingAsrScreen extends StatefulWidget {
  const StreamingAsrScreen({super.key});

  @override
  State<StreamingAsrScreen> createState() => _StreamingAsrScreenState();
}
/// State for [StreamingAsrScreen].
///
/// Captures microphone audio as a PCM16 stream via `package:record`,
/// feeds it to a sherpa-onnx online recognizer, and displays the running
/// transcript in a read-only text field. Finalized utterances (detected
/// via the recognizer's endpoint logic) accumulate newest-first.
class _StreamingAsrScreenState extends State<StreamingAsrScreen> {
  late final TextEditingController _controller;
  late final AudioRecorder _audioRecorder;

  // Heading shown above the transcript; never changes.
  final String _title = 'Real-time speech recognition';

  // Transcript of already-finalized utterances, newest first.
  String _last = '';

  // Index of the utterance currently being recognized; incremented at
  // each detected endpoint.
  int _index = 0;

  // Whether the native bindings / recognizer have been set up. Lazily
  // initialized on the first press of the record button so the (slow)
  // model load does not block app startup.
  bool _isInitialized = false;

  sherpa_onnx.OnlineRecognizer? _recognizer;
  sherpa_onnx.OnlineStream? _stream;

  // Capture rate in Hz; must match what RecordConfig requests and what
  // acceptWaveform is told.
  final int _sampleRate = 16000;

  StreamSubscription<RecordState>? _recordSub;
  RecordState _recordState = RecordState.stop;

  @override
  void initState() {
    _audioRecorder = AudioRecorder();
    _controller = TextEditingController();
    // Keep the mic/stop button in sync with the recorder's actual state.
    _recordSub = _audioRecorder.onStateChanged().listen((recordState) {
      _updateRecordState(recordState);
    });
    super.initState();
  }

  /// Starts (or lazily initializes and then starts) streaming recognition.
  Future<void> _start() async {
    if (!_isInitialized) {
      sherpa_onnx.initBindings();
      _recognizer = await createOnlineRecognizer();
      _stream = _recognizer?.createStream();
      _isInitialized = true;
    }

    try {
      if (await _audioRecorder.hasPermission()) {
        const encoder = AudioEncoder.pcm16bits;

        if (!await _isEncoderSupported(encoder)) {
          return;
        }

        final devs = await _audioRecorder.listInputDevices();
        debugPrint(devs.toString());

        // Use _sampleRate here instead of a second hard-coded literal so
        // the capture rate and acceptWaveform rate cannot drift apart.
        final config = RecordConfig(
          encoder: encoder,
          sampleRate: _sampleRate,
          numChannels: 1,
        );

        final stream = await _audioRecorder.startStream(config);

        stream.listen(
          (data) {
            // Raw PCM16 bytes -> float32 samples in [-1, 1].
            final samplesFloat32 =
                convertBytesToFloat32(Uint8List.fromList(data));

            _stream!.acceptWaveform(
                samples: samplesFloat32, sampleRate: _sampleRate);
            while (_recognizer!.isReady(_stream!)) {
              _recognizer!.decode(_stream!);
            }

            final text = _recognizer!.getResult(_stream!).text;

            // Show the in-progress utterance on top of the finalized text.
            String textToDisplay = _last;
            if (text != '') {
              if (_last == '') {
                textToDisplay = '$_index: $text';
              } else {
                textToDisplay = '$_index: $text\n$_last';
              }
            }

            // On endpoint, reset the stream and commit the utterance.
            if (_recognizer!.isEndpoint(_stream!)) {
              _recognizer!.reset(_stream!);
              if (text != '') {
                _last = textToDisplay;
                _index += 1;
              }
            }

            print('text: $textToDisplay');

            _controller.value = TextEditingValue(
              text: textToDisplay,
              selection: TextSelection.collapsed(offset: textToDisplay.length),
            );
          },
          onDone: () {
            print('stream stopped.');
          },
        );
      }
    } catch (e) {
      print(e);
    }
  }

  /// Stops recording and replaces the recognizer stream with a fresh one
  /// so the next session starts from a clean decoding state.
  Future<void> _stop() async {
    _stream!.free();
    _stream = _recognizer!.createStream();

    await _audioRecorder.stop();
  }

  Future<void> _pause() => _audioRecorder.pause();

  Future<void> _resume() => _audioRecorder.resume();

  void _updateRecordState(RecordState recordState) {
    setState(() => _recordState = recordState);
  }

  /// Returns whether [encoder] is supported, logging the supported
  /// alternatives when it is not.
  Future<bool> _isEncoderSupported(AudioEncoder encoder) async {
    final isSupported = await _audioRecorder.isEncoderSupported(
      encoder,
    );

    if (!isSupported) {
      debugPrint('${encoder.name} is not supported on this platform.');
      debugPrint('Supported encoders are:');

      for (final e in AudioEncoder.values) {
        if (await _audioRecorder.isEncoderSupported(e)) {
          // Fix: log the supported encoder `e`, not the queried `encoder`,
          // which previously printed the same unsupported name repeatedly.
          debugPrint('- ${e.name}');
        }
      }
    }

    return isSupported;
  }

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        body: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            Text(_title),
            const SizedBox(height: 50),
            TextField(
              maxLines: 5,
              controller: _controller,
              readOnly: true,
            ),
            const SizedBox(height: 50),
            Row(
              mainAxisAlignment: MainAxisAlignment.center,
              children: <Widget>[
                _buildRecordStopControl(),
                const SizedBox(width: 20),
                _buildText(),
              ],
            ),
          ],
        ),
      ),
    );
  }

  @override
  void dispose() {
    _recordSub?.cancel();
    _audioRecorder.dispose();
    // Fix: dispose the TextEditingController created in initState; it was
    // previously leaked.
    _controller.dispose();
    _stream?.free();
    _recognizer?.free();
    super.dispose();
  }

  /// Builds the round mic/stop button reflecting the current record state.
  Widget _buildRecordStopControl() {
    late Icon icon;
    late Color color;

    if (_recordState != RecordState.stop) {
      icon = const Icon(Icons.stop, color: Colors.red, size: 30);
      color = Colors.red.withOpacity(0.1);
    } else {
      final theme = Theme.of(context);
      icon = Icon(Icons.mic, color: theme.primaryColor, size: 30);
      color = theme.primaryColor.withOpacity(0.1);
    }

    return ClipOval(
      child: Material(
        color: color,
        child: InkWell(
          child: SizedBox(width: 56, height: 56, child: icon),
          onTap: () {
            (_recordState != RecordState.stop) ? _stop() : _start();
          },
        ),
      ),
    );
  }

  /// Label next to the record button: "Start" when idle, "Stop" otherwise.
  Widget _buildText() {
    if (_recordState == RecordState.stop) {
      return const Text("Start");
    } else {
      return const Text("Stop");
    }
  }
}

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'package:path/path.dart';
import 'package:path_provider/path_provider.dart';
import 'package:flutter/services.dart' show rootBundle;
import 'dart:typed_data';
import "dart:io";
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './utils.dart';
/// Smoke test for the streaming transducer ASR Dart API.
///
/// Copies the bundled zh-en Zipformer model and a sample wave out of the
/// asset bundle, decodes the wave with an online recognizer, and prints
/// the result. Intended for manual verification only (see the TODO at the
/// import site).
Future<void> testStreamingTransducerAsr() async {
  // Single source of truth for the bundled model directory.
  const modelDir =
      'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20';

  // Assets must be materialized on disk: the native library reads files
  // by path, not through the Flutter asset bundle.
  final encoder = await copyAssetFile(
      src: '$modelDir/encoder-epoch-99-avg-1.int8.onnx', dst: 'encoder.onnx');
  final decoder = await copyAssetFile(
      src: '$modelDir/decoder-epoch-99-avg-1.onnx', dst: 'decoder.onnx');
  final joiner = await copyAssetFile(
      src: '$modelDir/joiner-epoch-99-avg-1.int8.onnx', dst: 'joiner.onnx');
  final tokens =
      await copyAssetFile(src: '$modelDir/tokens.txt', dst: 'tokens.txt');
  final testWave = await copyAssetFile(
      src: '$modelDir/test_wavs/0.wav', dst: 'test.wav');

  final transducer = sherpa_onnx.OnlineTransducerModelConfig(
    encoder: encoder,
    decoder: decoder,
    joiner: joiner,
  );

  final modelConfig = sherpa_onnx.OnlineModelConfig(
    transducer: transducer,
    tokens: tokens,
    modelType: 'zipformer',
  );
  final config = sherpa_onnx.OnlineRecognizerConfig(model: modelConfig);
  print(config);

  final recognizer = sherpa_onnx.OnlineRecognizer(config);

  final waveData = sherpa_onnx.readWave(testWave);

  final stream = recognizer.createStream();
  stream.acceptWaveform(
      samples: waveData.samples, sampleRate: waveData.sampleRate);

  // Drain all frames that are ready before reading the final result.
  while (recognizer.isReady(stream)) {
    recognizer.decode(stream);
  }

  final result = recognizer.getResult(stream);
  print('result is: $result');
  print('recognizer: ${recognizer.ptr}');

  // Native resources are not garbage collected; free them explicitly.
  stream.free();
  recognizer.free();
}

View File

@@ -11,14 +11,14 @@ import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './utils.dart';
class HomeScreen extends StatefulWidget {
const HomeScreen({super.key});
class VadScreen extends StatefulWidget {
const VadScreen({super.key});
@override
State<HomeScreen> createState() => _HomeScreenState();
State<VadScreen> createState() => _VadScreenState();
}
class _HomeScreenState extends State<HomeScreen> {
class _VadScreenState extends State<VadScreen> {
late final AudioRecorder _audioRecorder;
bool _printed = false;