Flutter demo for real-time speech recognition (#1042)

This commit is contained in:
Fangjun Kuang
2024-06-23 13:29:13 +08:00
committed by GitHub
parent 9dd0e03568
commit 169c9bf627
67 changed files with 4056 additions and 235 deletions

View File

@@ -0,0 +1,40 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'package:flutter/material.dart';
import 'package:url_launcher/url_launcher.dart';
/// Static "Info" tab: project links and community contact info.
class InfoScreen extends StatelessWidget {
  const InfoScreen({super.key});

  @override
  Widget build(BuildContext context) {
    // Vertical gap between the entries below.
    const double height = 20;
    return Container(
      child: Padding(
        padding: const EdgeInsets.all(8.0),
        child: Column(
          crossAxisAlignment: CrossAxisAlignment.start,
          children: <Widget>[
            const Text('Everything is open-sourced.'),
            const SizedBox(height: height),
            InkWell(
              child: const Text('Code: https://github.com/k2-fsa/sherpa-onnx'),
              // Fixed: previously this opened the documentation URL,
              // not the code repository shown in the label.
              // NOTE(review): `launch` is deprecated in url_launcher >= 6.1;
              // consider launchUrl(Uri.parse(...)) when upgrading.
              onTap: () => launch('https://github.com/k2-fsa/sherpa-onnx'),
            ),
            const SizedBox(height: height),
            InkWell(
              child: const Text('Doc: https://k2-fsa.github.io/sherpa/onnx/'),
              onTap: () => launch('https://k2-fsa.github.io/sherpa/onnx/'),
            ),
            const SizedBox(height: height),
            const Text('QQ 群: 744602236'),
            const SizedBox(height: height),
            InkWell(
              child: const Text(
                  '微信群: https://k2-fsa.github.io/sherpa/social-groups.html'),
              onTap: () =>
                  launch('https://k2-fsa.github.io/sherpa/social-groups.html'),
            ),
          ],
        ),
      ),
    );
  }
}

View File

@@ -0,0 +1,69 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'package:flutter/material.dart';
import './streaming_asr.dart';
import './info.dart';
// Application entry point: inflates [MyApp] as the root widget.
void main() {
runApp(const MyApp());
}
/// Root widget of the demo application.
///
/// Configures the Material 3 theme and shows [MyHomePage] as home.
class MyApp extends StatelessWidget {
  const MyApp({super.key});

  @override
  Widget build(BuildContext context) {
    final theme = ThemeData(
      colorScheme: ColorScheme.fromSeed(seedColor: Colors.deepPurple),
      useMaterial3: true,
    );
    return MaterialApp(
      title: 'Next-gen Kaldi flutter demo',
      theme: theme,
      home: const MyHomePage(title: 'Next-gen Kaldi with Flutter'),
    );
  }
}
/// Home page hosting the bottom-navigation tabs (ASR and Info).
class MyHomePage extends StatefulWidget {
  const MyHomePage({super.key, required this.title});

  /// Text shown in the app bar.
  final String title;

  @override
  State<MyHomePage> createState() {
    return _MyHomePageState();
  }
}
class _MyHomePageState extends State<MyHomePage> {
  // Index of the tab currently shown in the body.
  int _selectedIndex = 0;

  // One widget per bottom-navigation destination, in item order.
  final List<Widget> _pages = [
    StreamingAsrScreen(),
    InfoScreen(),
  ];

  // Switches the visible tab when a bottom-navigation item is tapped.
  void _onTabSelected(int index) {
    setState(() {
      _selectedIndex = index;
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text(widget.title),
      ),
      body: _pages[_selectedIndex],
      bottomNavigationBar: BottomNavigationBar(
        currentIndex: _selectedIndex,
        onTap: _onTabSelected,
        items: const [
          BottomNavigationBarItem(
            icon: Icon(Icons.home),
            label: 'Home',
          ),
          BottomNavigationBarItem(
            icon: Icon(Icons.info),
            label: 'Info',
          ),
        ],
      ),
    );
  }
}

View File

@@ -0,0 +1,68 @@
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './utils.dart';
// Remember to change `assets` in ../pubspec.yaml
// and download files to ../assets
/// Returns the streaming-model configuration for the given model [type].
///
/// Each case copies the model files out of the app's assets (see
/// `assets` in ../pubspec.yaml) to the documents directory via
/// [copyAssetFile] and points the config at the copied files.
///
/// Supported types:
///   0 - bilingual zh-en zipformer (int8 encoder)
///   1 - English zipformer2
///   2 - WenetSpeech (zh) zipformer2
///   3 - French zipformer
///
/// Throws an [ArgumentError] for any other [type].
Future<sherpa_onnx.OnlineModelConfig> getOnlineModelConfig(
    {required int type}) async {
  switch (type) {
    case 0:
      final modelDir =
          'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20';
      return sherpa_onnx.OnlineModelConfig(
        transducer: sherpa_onnx.OnlineTransducerModelConfig(
          encoder:
              await copyAssetFile('$modelDir/encoder-epoch-99-avg-1.int8.onnx'),
          decoder: await copyAssetFile('$modelDir/decoder-epoch-99-avg-1.onnx'),
          joiner: await copyAssetFile('$modelDir/joiner-epoch-99-avg-1.onnx'),
        ),
        tokens: await copyAssetFile('$modelDir/tokens.txt'),
        modelType: 'zipformer',
      );
    case 1:
      final modelDir = 'assets/sherpa-onnx-streaming-zipformer-en-2023-06-26';
      return sherpa_onnx.OnlineModelConfig(
        transducer: sherpa_onnx.OnlineTransducerModelConfig(
          encoder: await copyAssetFile(
              '$modelDir/encoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx'),
          decoder: await copyAssetFile(
              '$modelDir/decoder-epoch-99-avg-1-chunk-16-left-128.onnx'),
          joiner: await copyAssetFile(
              '$modelDir/joiner-epoch-99-avg-1-chunk-16-left-128.onnx'),
        ),
        tokens: await copyAssetFile('$modelDir/tokens.txt'),
        modelType: 'zipformer2',
      );
    case 2:
      final modelDir =
          'assets/icefall-asr-zipformer-streaming-wenetspeech-20230615';
      return sherpa_onnx.OnlineModelConfig(
        transducer: sherpa_onnx.OnlineTransducerModelConfig(
          encoder: await copyAssetFile(
              '$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx'),
          decoder: await copyAssetFile(
              '$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx'),
          joiner: await copyAssetFile(
              '$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx'),
        ),
        tokens: await copyAssetFile('$modelDir/data/lang_char/tokens.txt'),
        modelType: 'zipformer2',
      );
    case 3:
      final modelDir = 'assets/sherpa-onnx-streaming-zipformer-fr-2023-04-14';
      return sherpa_onnx.OnlineModelConfig(
        transducer: sherpa_onnx.OnlineTransducerModelConfig(
          encoder: await copyAssetFile(
              '$modelDir/encoder-epoch-29-avg-9-with-averaged-model.int8.onnx'),
          decoder: await copyAssetFile(
              '$modelDir/decoder-epoch-29-avg-9-with-averaged-model.onnx'),
          // Fixed: the filename previously said 'joincoder-...', which does
          // not exist in the model directory (all other cases use 'joiner-').
          joiner: await copyAssetFile(
              '$modelDir/joiner-epoch-29-avg-9-with-averaged-model.onnx'),
        ),
        tokens: await copyAssetFile('$modelDir/tokens.txt'),
        modelType: 'zipformer',
      );
    default:
      throw ArgumentError('Unsupported type: $type');
  }
}

View File

@@ -0,0 +1,241 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'dart:async';
import 'package:flutter/foundation.dart';
import 'package:flutter/material.dart';
import 'package:path/path.dart' as p;
import 'package:path_provider/path_provider.dart';
import 'package:record/record.dart';
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './utils.dart';
import './online_model.dart';
/// Creates the streaming recognizer used by this screen.
///
/// Model type 0 (the bilingual zh-en zipformer) is selected; see
/// getOnlineModelConfig for the list of available types.
Future<sherpa_onnx.OnlineRecognizer> createOnlineRecognizer() async {
  const modelType = 0;
  final modelConfig = await getOnlineModelConfig(type: modelType);
  return sherpa_onnx.OnlineRecognizer(
    sherpa_onnx.OnlineRecognizerConfig(
      model: modelConfig,
      ruleFsts: '',
    ),
  );
}
/// Screen that performs real-time (streaming) speech recognition from
/// the microphone.
class StreamingAsrScreen extends StatefulWidget {
  const StreamingAsrScreen({super.key});

  @override
  State<StreamingAsrScreen> createState() {
    return _StreamingAsrScreenState();
  }
}
class _StreamingAsrScreenState extends State<StreamingAsrScreen> {
  late final TextEditingController _controller;
  late final AudioRecorder _audioRecorder;

  // App-bar title.
  String _title = 'Real-time speech recognition';

  // Finalized transcript lines, most recent first.
  String _last = '';

  // Number of finalized utterances so far; used to prefix each line.
  int _index = 0;

  // Whether the native bindings and recognizer have been created.
  bool _isInitialized = false;

  sherpa_onnx.OnlineRecognizer? _recognizer;
  sherpa_onnx.OnlineStream? _stream;

  // Sample rate the model expects; the recorder is configured to match.
  int _sampleRate = 16000;

  StreamSubscription<RecordState>? _recordSub;
  RecordState _recordState = RecordState.stop;

  @override
  void initState() {
    _audioRecorder = AudioRecorder();
    _controller = TextEditingController();

    _recordSub = _audioRecorder.onStateChanged().listen((recordState) {
      _updateRecordState(recordState);
    });

    super.initState();
  }

  /// Starts recording and feeds PCM16 audio into the streaming recognizer.
  Future<void> _start() async {
    // Lazily initialize on first press of the record button.
    if (!_isInitialized) {
      sherpa_onnx.initBindings();
      _recognizer = await createOnlineRecognizer();
      _stream = _recognizer?.createStream();

      _isInitialized = true;
    }

    try {
      if (await _audioRecorder.hasPermission()) {
        const encoder = AudioEncoder.pcm16bits;

        if (!await _isEncoderSupported(encoder)) {
          return;
        }

        final devs = await _audioRecorder.listInputDevices();
        debugPrint(devs.toString());

        // Record at the model's sample rate (previously hard-coded 16000;
        // now kept in sync with _sampleRate used in acceptWaveform below).
        final config = RecordConfig(
          encoder: encoder,
          sampleRate: _sampleRate,
          numChannels: 1,
        );

        final stream = await _audioRecorder.startStream(config);

        stream.listen(
          (data) {
            final samplesFloat32 =
                convertBytesToFloat32(Uint8List.fromList(data));

            _stream!.acceptWaveform(
                samples: samplesFloat32, sampleRate: _sampleRate);
            while (_recognizer!.isReady(_stream!)) {
              _recognizer!.decode(_stream!);
            }

            final text = _recognizer!.getResult(_stream!).text;
            // Show the in-progress utterance above the finalized ones.
            String textToDisplay = _last;
            if (text != '') {
              if (_last == '') {
                textToDisplay = '$_index: $text';
              } else {
                textToDisplay = '$_index: $text\n$_last';
              }
            }

            // An endpoint finalizes the current utterance.
            if (_recognizer!.isEndpoint(_stream!)) {
              _recognizer!.reset(_stream!);
              if (text != '') {
                _last = textToDisplay;
                _index += 1;
              }
            }

            _controller.value = TextEditingValue(
              text: textToDisplay,
              selection: TextSelection.collapsed(offset: textToDisplay.length),
            );
          },
          onDone: () {
            print('stream stopped.');
          },
        );
      }
    } catch (e) {
      print(e);
    }
  }

  /// Stops recording and resets the recognizer stream for the next session.
  Future<void> _stop() async {
    // Stop the recorder first so no further audio callbacks can touch the
    // stream while (or after) it is freed.
    await _audioRecorder.stop();

    _stream!.free();
    _stream = _recognizer!.createStream();
  }

  Future<void> _pause() => _audioRecorder.pause();

  Future<void> _resume() => _audioRecorder.resume();

  void _updateRecordState(RecordState recordState) {
    setState(() => _recordState = recordState);
  }

  /// Whether [encoder] is supported; logs the supported encoders when not.
  Future<bool> _isEncoderSupported(AudioEncoder encoder) async {
    final isSupported = await _audioRecorder.isEncoderSupported(
      encoder,
    );

    if (!isSupported) {
      debugPrint('${encoder.name} is not supported on this platform.');
      debugPrint('Supported encoders are:');

      for (final e in AudioEncoder.values) {
        if (await _audioRecorder.isEncoderSupported(e)) {
          // Fixed: previously printed '${encoder.name}' (the unsupported
          // encoder) for every entry instead of the supported one.
          debugPrint('- ${e.name}');
        }
      }
    }

    return isSupported;
  }

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text(_title),
        ),
        body: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: [
            const SizedBox(height: 50),
            // Read-only transcript view, updated from the audio callback.
            TextField(
              maxLines: 5,
              controller: _controller,
              readOnly: true,
            ),
            const SizedBox(height: 50),
            Row(
              mainAxisAlignment: MainAxisAlignment.center,
              children: <Widget>[
                _buildRecordStopControl(),
                const SizedBox(width: 20),
                _buildText(),
              ],
            ),
          ],
        ),
      ),
    );
  }

  @override
  void dispose() {
    _recordSub?.cancel();
    _audioRecorder.dispose();
    _stream?.free();
    _recognizer?.free();
    super.dispose();
  }

  /// Circular mic/stop button reflecting the current record state.
  Widget _buildRecordStopControl() {
    late Icon icon;
    late Color color;

    if (_recordState != RecordState.stop) {
      icon = const Icon(Icons.stop, color: Colors.red, size: 30);
      color = Colors.red.withOpacity(0.1);
    } else {
      final theme = Theme.of(context);
      icon = Icon(Icons.mic, color: theme.primaryColor, size: 30);
      color = theme.primaryColor.withOpacity(0.1);
    }

    return ClipOval(
      child: Material(
        color: color,
        child: InkWell(
          child: SizedBox(width: 56, height: 56, child: icon),
          onTap: () {
            (_recordState != RecordState.stop) ? _stop() : _start();
          },
        ),
      ),
    );
  }

  /// Label next to the record button ("Start" when idle, "Stop" otherwise).
  Widget _buildText() {
    if (_recordState == RecordState.stop) {
      return const Text("Start");
    } else {
      return const Text("Stop");
    }
  }
}

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'package:path/path.dart';
import 'package:path_provider/path_provider.dart';
import 'package:flutter/services.dart' show rootBundle;
import 'dart:typed_data';
import "dart:io";
// Copy the asset file from src to dst
/// Copies the asset at [src] into the application documents directory
/// and returns the absolute path of the copy.
///
/// [dst] is the destination file name relative to the documents
/// directory; it defaults to the basename of [src]. The file is
/// overwritten if it already exists.
Future<String> copyAssetFile(String src, [String? dst]) async {
  final Directory directory = await getApplicationDocumentsDirectory();
  // Default to the asset's own file name.
  dst ??= basename(src);
  final target = join(directory.path, dst);
  final data = await rootBundle.load(src);
  final List<int> bytes =
      data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes);
  await File(target).writeAsBytes(bytes);
  return target;
}
/// Converts raw 16-bit signed PCM [bytes] into normalized Float32
/// samples in the range [-1.0, 1.0).
///
/// Samples are read with the given [endian]ness (little-endian by
/// default). A trailing odd byte, if any, is ignored.
Float32List convertBytesToFloat32(Uint8List bytes,
    [Endian endian = Endian.little]) {
  final values = Float32List(bytes.length ~/ 2);

  // sublistView respects bytes.offsetInBytes; ByteData.view(bytes.buffer)
  // would misread samples whenever `bytes` is a view into a larger buffer.
  final data = ByteData.sublistView(bytes);

  // Iterate over sample count (not byte count) so an odd trailing byte
  // cannot trigger an out-of-range getInt16 read.
  for (var i = 0; i < values.length; ++i) {
    // Normalize by 2^15 = 32768 (was 32678 — a typo that skewed every
    // sample's amplitude and allowed values slightly outside [-1, 1]).
    values[i] = data.getInt16(2 * i, endian) / 32768.0;
  }

  return values;
}