Add Flutter text to speech demo (#1087)

This commit is contained in:
Fangjun Kuang
2024-07-08 11:23:11 +08:00
committed by GitHub
parent 1fe12c5107
commit e832d356c7
133 changed files with 6686 additions and 143 deletions

View File

@@ -0,0 +1,40 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'package:flutter/material.dart';
import 'package:url_launcher/url_launcher.dart';
/// A static "about" tab listing the project's links and community groups.
///
/// Each tappable row opens its URL in an external browser via
/// `url_launcher`.
class InfoScreen extends StatelessWidget {
  const InfoScreen({super.key});

  @override
  Widget build(BuildContext context) {
    // Vertical spacing between consecutive entries.
    const double height = 20;
    return Container(
      child: Padding(
        padding: const EdgeInsets.all(8.0),
        child: Column(
          crossAxisAlignment: CrossAxisAlignment.start,
          children: <Widget>[
            Text('Everything is open-sourced.'),
            SizedBox(height: height),
            InkWell(
              child: Text('Code: https://github.com/k2-fsa/sherpa-onnx'),
              // Bug fix: this previously launched the documentation URL
              // instead of the code repository shown in the label.
              // NOTE: `launch` is deprecated in url_launcher >= 6.1;
              // consider migrating to `launchUrl(Uri.parse(...))`.
              onTap: () => launch('https://github.com/k2-fsa/sherpa-onnx'),
            ),
            SizedBox(height: height),
            InkWell(
              child: Text('Doc: https://k2-fsa.github.io/sherpa/onnx/'),
              onTap: () => launch('https://k2-fsa.github.io/sherpa/onnx/'),
            ),
            SizedBox(height: height),
            Text('QQ 群: 744602236'),
            SizedBox(height: height),
            InkWell(
              child: Text(
                  '微信群: https://k2-fsa.github.io/sherpa/social-groups.html'),
              onTap: () =>
                  launch('https://k2-fsa.github.io/sherpa/social-groups.html'),
            ),
          ],
        ),
      ),
    );
  }
}

View File

@@ -0,0 +1,69 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'package:flutter/material.dart';
import './tts.dart';
import './info.dart';
/// Application entry point.
void main() => runApp(const MyApp());
/// Root widget of the demo app.
///
/// Configures the Material 3 theme and installs [MyHomePage] as the
/// home screen.
class MyApp extends StatelessWidget {
  const MyApp({super.key});

  @override
  Widget build(BuildContext context) {
    // Seed-based color scheme keeps the palette consistent across widgets.
    final theme = ThemeData(
      colorScheme: ColorScheme.fromSeed(seedColor: Colors.deepPurple),
      useMaterial3: true,
    );

    return MaterialApp(
      title: 'Next-gen Kaldi flutter demo',
      theme: theme,
      home: const MyHomePage(title: 'Next-gen Kaldi with Flutter'),
    );
  }
}
/// Stateful home page hosting the app's tabbed navigation.
class MyHomePage extends StatefulWidget {
  const MyHomePage({super.key, required this.title});

  /// Text shown in the app bar.
  final String title;

  @override
  State<MyHomePage> createState() {
    return _MyHomePageState();
  }
}
/// Hosts the two top-level tabs (TTS demo and info page) behind a
/// [BottomNavigationBar].
class _MyHomePageState extends State<MyHomePage> {
  // Index of the currently selected bottom-navigation tab.
  int _currentIndex = 0;

  // Tab bodies are created once per State and reused, so switching tabs
  // does not recreate the widgets themselves.
  final List<Widget> _tabs = [
    // Idiom: TtsScreen has a const constructor, so it can be canonicalized.
    const TtsScreen(),
    InfoScreen(),
  ];

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text(widget.title),
      ),
      body: _tabs[_currentIndex],
      bottomNavigationBar: BottomNavigationBar(
        currentIndex: _currentIndex,
        onTap: (int index) {
          // Rebuild showing the newly selected tab.
          setState(() {
            _currentIndex = index;
          });
        },
        // Idiom: the items are immutable; a const list avoids rebuilding
        // them on every frame.
        items: const [
          BottomNavigationBarItem(
            icon: Icon(Icons.home),
            label: 'Home',
          ),
          BottomNavigationBarItem(
            icon: Icon(Icons.info),
            label: 'Info',
          ),
        ],
      ),
    );
  }
}

View File

@@ -0,0 +1,149 @@
// Copyright (c) 2024 Xiaomi Corporation
import "dart:io";
import 'package:flutter/services.dart';
import 'package:path_provider/path_provider.dart';
import 'package:path/path.dart' as p;
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './utils.dart';
/// Creates and configures an offline TTS engine from the model files
/// bundled as Flutter assets.
///
/// Copies every asset to the application documents directory first,
/// then builds a [sherpa_onnx.OfflineTts] from the model selected below.
///
/// Throws an [Exception] if no model has been selected (i.e. `modelName`
/// is still empty).
Future<sherpa_onnx.OfflineTts> createOfflineTts() async {
  // sherpa_onnx requires that model files are in the local disk, so we
  // need to copy all asset files to disk.
  await copyAllAssetFiles();

  sherpa_onnx.initBindings();

  // Such a design is to make it easier to build flutter APPs with
  // github actions for a variety of tts models
  //
  // See https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/flutter/generate-tts.py
  // for details

  // Directory (relative to the app documents dir) holding the model files.
  String modelDir = '';
  // Filename of the .onnx model inside modelDir.
  String modelName = '';
  // Comma-separated rule FST paths for text normalization (optional).
  String ruleFsts = '';
  // Comma-separated rule FAR paths (optional).
  String ruleFars = '';
  // Lexicon filename (optional; whether it is needed depends on the model).
  String lexicon = '';
  // espeak-ng data directory (optional; used by piper models — see Example 2).
  String dataDir = '';
  // Dictionary directory (optional; used by some Chinese models — see
  // Example 4; presumably a jieba dict — verify against the model docs).
  String dictDir = '';

  // You can select an example below and change it according to match your
  // selected tts model

  // ============================================================
  // Your change starts here
  // ============================================================

  // Example 1:
  // modelDir = 'vits-vctk';
  // modelName = 'vits-vctk.onnx';
  // lexicon = 'lexicon.txt';

  // Example 2:
  // https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
  // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
  // modelDir = 'vits-piper-en_US-amy-low';
  // modelName = 'en_US-amy-low.onnx';
  // dataDir = 'vits-piper-en_US-amy-low/espeak-ng-data';

  // Example 3:
  // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2
  // modelDir = 'vits-icefall-zh-aishell3';
  // modelName = 'model.onnx';
  // ruleFsts = 'vits-icefall-zh-aishell3/phone.fst,vits-icefall-zh-aishell3/date.fst,vits-icefall-zh-aishell3/number.fst,vits-icefall-zh-aishell3/new_heteronym.fst';
  // ruleFars = 'vits-icefall-zh-aishell3/rule.far';
  // lexicon = 'lexicon.txt';

  // Example 4:
  // https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/vits.html#csukuangfj-vits-zh-hf-fanchen-c-chinese-187-speakers
  // modelDir = 'vits-zh-hf-fanchen-C';
  // modelName = 'vits-zh-hf-fanchen-C.onnx';
  // lexicon = 'lexicon.txt';
  // dictDir = 'vits-zh-hf-fanchen-C/dict';

  // Example 5:
  // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-coqui-de-css10.tar.bz2
  // modelDir = 'vits-coqui-de-css10';
  // modelName = 'model.onnx';

  // Example 6
  // https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
  // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-libritts_r-medium.tar.bz2
  // modelDir = 'vits-piper-en_US-libritts_r-medium';
  // modelName = 'en_US-libritts_r-medium.onnx';
  // dataDir = 'vits-piper-en_US-libritts_r-medium/espeak-ng-data';

  // ============================================================
  // Please don't change the remaining part of this function
  // ============================================================
  if (modelName == '') {
    throw Exception(
        'You are supposed to select a model by changing the code before you run the app');
  }

  final Directory directory = await getApplicationDocumentsDirectory();
  // Resolve the model path under the documents directory, where the
  // assets were copied by copyAllAssetFiles().
  modelName = p.join(directory.path, modelDir, modelName);

  if (ruleFsts != '') {
    // NOTE(review): the Example 3 ruleFsts entries already contain the
    // model directory prefix, yet each entry is joined with modelDir
    // again below — this looks like it could double the directory
    // component; verify against the on-disk layout produced by
    // copyAllAssetFiles().
    final all = ruleFsts.split(',');
    var tmp = <String>[];
    for (final f in all) {
      tmp.add(p.join(directory.path, modelDir, f));
    }
    ruleFsts = tmp.join(',');
  }

  if (ruleFars != '') {
    // Unlike ruleFsts above, FAR entries are joined WITHOUT modelDir, so
    // they are expected to already include it (as in Example 3).
    final all = ruleFars.split(',');
    var tmp = <String>[];
    for (final f in all) {
      tmp.add(p.join(directory.path, f));
    }
    ruleFars = tmp.join(',');
  }

  if (lexicon != '') {
    // NOTE(review): joined without modelDir, i.e. the lexicon is assumed
    // to sit directly under the documents directory — TODO confirm this
    // matches where the asset is actually copied.
    lexicon = p.join(directory.path, lexicon);
  }

  if (dataDir != '') {
    dataDir = p.join(directory.path, dataDir);
  }

  if (dictDir != '') {
    dictDir = p.join(directory.path, dictDir);
  }

  // tokens.txt is always expected next to the model file.
  final tokens = p.join(directory.path, modelDir, 'tokens.txt');

  final vits = sherpa_onnx.OfflineTtsVitsModelConfig(
    model: modelName,
    lexicon: lexicon,
    tokens: tokens,
    dataDir: dataDir,
    dictDir: dictDir,
  );

  final modelConfig = sherpa_onnx.OfflineTtsModelConfig(
    vits: vits,
    numThreads: 2,
    debug: true,
    provider: 'cpu',
  );

  final config = sherpa_onnx.OfflineTtsConfig(
    model: modelConfig,
    ruleFsts: ruleFsts,
    ruleFars: ruleFars,
    // NOTE(review): "maxNumSenetences" appears to be spelled this way in
    // this version of the sherpa_onnx Dart API — do not "fix" the name
    // here without checking the binding.
    maxNumSenetences: 1,
  );

  // print(config);

  final tts = sherpa_onnx.OfflineTts(config);
  print('tts created successfully');

  return tts;
}

View File

@@ -0,0 +1,246 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'dart:async';
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:flutter/material.dart';
import 'package:audioplayers/audioplayers.dart';
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
import './model.dart';
import './utils.dart';
/// Widget wrapper for the text-to-speech tab; all behavior lives in
/// [_TtsScreenState].
class TtsScreen extends StatefulWidget {
  const TtsScreen({super.key});

  @override
  State<TtsScreen> createState() {
    return _TtsScreenState();
  }
}
/// Interactive text-to-speech tab.
///
/// Lazily creates the sherpa-onnx offline TTS engine on the first
/// "Generate" press, synthesizes the entered text to a wave file on disk,
/// and plays it back with `audioplayers`.
class _TtsScreenState extends State<TtsScreen> {
  late final TextEditingController _controllerTextInput;
  late final TextEditingController _controllerSid;
  late final TextEditingController _controllerHint;
  late final AudioPlayer _player;

  static const String _title = 'Text to speech';

  // Absolute path of the most recently generated wave file; empty until
  // the first successful generation.
  String _lastFilename = '';

  // Whether the TTS engine has been created (done lazily in [_init]).
  bool _isInitialized = false;

  // Largest valid speaker ID of the loaded model (numSpeakers - 1).
  int _maxSpeakerID = 0;

  // Speech speed multiplier controlled by the slider.
  double _speed = 1.0;

  sherpa_onnx.OfflineTts? _tts;

  @override
  void initState() {
    _controllerTextInput = TextEditingController();
    _controllerHint = TextEditingController();
    _controllerSid = TextEditingController(text: '0');
    // Bug fix: create the player eagerly. It was previously assigned only
    // inside _init(), so tapping "Play" or "Stop" before "Generate" hit an
    // uninitialized `late final` field and crashed.
    _player = AudioPlayer();
    super.initState();
  }

  /// Creates the TTS engine on first use.
  ///
  /// Model loading is slow, so it is deferred until the user actually
  /// requests synthesis.
  Future<void> _init() async {
    if (!_isInitialized) {
      sherpa_onnx.initBindings();
      _tts?.free();
      _tts = await createOfflineTts();
      _isInitialized = true;
    }
  }

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      home: Scaffold(
        appBar: AppBar(
          title: Text(_title),
        ),
        body: Padding(
          padding: EdgeInsets.all(10),
          child: Column(
            // mainAxisAlignment: MainAxisAlignment.center,
            children: <Widget>[
              // Speaker ID input: digits only; range hint is refreshed
              // after the model is loaded.
              TextField(
                  decoration: InputDecoration(
                    labelText: "Speaker ID (0-$_maxSpeakerID)",
                    hintText: 'Please input your speaker ID',
                  ),
                  keyboardType: TextInputType.number,
                  maxLines: 1,
                  controller: _controllerSid,
                  inputFormatters: <TextInputFormatter>[
                    FilteringTextInputFormatter.digitsOnly
                  ]),
              Slider(
                label: "Speech speed ${_speed.toStringAsPrecision(2)}",
                min: 0.5,
                max: 3.0,
                divisions: 25,
                value: _speed,
                onChanged: (value) {
                  setState(() {
                    _speed = value;
                  });
                },
              ),
              const SizedBox(height: 5),
              TextField(
                decoration: InputDecoration(
                  border: OutlineInputBorder(),
                  hintText: 'Please enter your text here',
                ),
                maxLines: 5,
                controller: _controllerTextInput,
              ),
              const SizedBox(height: 5),
              Row(
                  mainAxisAlignment: MainAxisAlignment.center,
                  children: <Widget>[
                    OutlinedButton(
                      child: Text("Generate"),
                      onPressed: () async {
                        await _init();
                        await _player.stop();
                        setState(() {
                          // numSpeakers -> largest valid speaker ID.
                          _maxSpeakerID = _tts?.numSpeakers ?? 0;
                          if (_maxSpeakerID > 0) {
                            _maxSpeakerID -= 1;
                          }
                        });
                        if (_tts == null) {
                          _controllerHint.value = TextEditingValue(
                            text: 'Failed to initialize tts',
                          );
                          return;
                        }

                        _controllerHint.value = TextEditingValue(
                          text: '',
                        );

                        final text = _controllerTextInput.text.trim();
                        if (text == '') {
                          _controllerHint.value = TextEditingValue(
                            text: 'Please first input your text to generate',
                          );
                          return;
                        }

                        // Fall back to speaker 0 on unparsable input.
                        final sid =
                            int.tryParse(_controllerSid.text.trim()) ?? 0;

                        final stopwatch = Stopwatch();
                        stopwatch.start();

                        final audio =
                            _tts!.generate(text: text, sid: sid, speed: _speed);

                        final suffix =
                            '-sid-$sid-speed-${_speed.toStringAsPrecision(2)}';
                        final filename = await generateWaveFilename(suffix);

                        final ok = sherpa_onnx.writeWave(
                          filename: filename,
                          samples: audio.samples,
                          sampleRate: audio.sampleRate,
                        );

                        if (ok) {
                          stopwatch.stop();
                          double elapsed =
                              stopwatch.elapsed.inMilliseconds.toDouble();

                          double waveDuration =
                              audio.samples.length.toDouble() /
                                  audio.sampleRate.toDouble();

                          // Report where the file went plus timing stats
                          // (RTF = synthesis time / audio duration).
                          _controllerHint.value = TextEditingValue(
                            text: 'Saved to\n$filename\n'
                                'Elapsed: ${(elapsed / 1000).toStringAsPrecision(4)} s\n'
                                'Wave duration: ${waveDuration.toStringAsPrecision(4)} s\n'
                                'RTF: ${(elapsed / 1000).toStringAsPrecision(4)}/${waveDuration.toStringAsPrecision(4)} '
                                '= ${(elapsed / 1000 / waveDuration).toStringAsPrecision(3)} ',
                          );

                          _lastFilename = filename;
                          await _player.play(DeviceFileSource(_lastFilename));
                        } else {
                          _controllerHint.value = TextEditingValue(
                            text: 'Failed to save generated audio',
                          );
                        }
                      },
                    ),
                    const SizedBox(width: 5),
                    OutlinedButton(
                      child: Text("Clear"),
                      onPressed: () {
                        _controllerTextInput.value = TextEditingValue(
                          text: '',
                        );
                        _controllerHint.value = TextEditingValue(
                          text: '',
                        );
                      },
                    ),
                    const SizedBox(width: 5),
                    OutlinedButton(
                      child: Text("Play"),
                      onPressed: () async {
                        if (_lastFilename == '') {
                          _controllerHint.value = TextEditingValue(
                            text: 'No generated wave file found',
                          );
                          return;
                        }
                        await _player.stop();
                        await _player.play(DeviceFileSource(_lastFilename));
                        _controllerHint.value = TextEditingValue(
                          text: 'Playing\n$_lastFilename',
                        );
                      },
                    ),
                    const SizedBox(width: 5),
                    OutlinedButton(
                      child: Text("Stop"),
                      onPressed: () async {
                        await _player.stop();
                        _controllerHint.value = TextEditingValue(
                          text: '',
                        );
                      },
                    ),
                  ]),
              const SizedBox(height: 5),
              TextField(
                decoration: InputDecoration(
                  border: OutlineInputBorder(),
                  hintText: 'Logs will be shown here.\n'
                      'The first run is slower due to model initialization.',
                ),
                maxLines: 6,
                controller: _controllerHint,
                readOnly: true,
              ),
            ],
          ),
        ),
      ),
    );
  }

  @override
  void dispose() {
    // Bug fix: previously only the native TTS engine was freed; the
    // three controllers and the audio player leaked.
    _tts?.free();
    _player.dispose();
    _controllerTextInput.dispose();
    _controllerSid.dispose();
    _controllerHint.dispose();
    super.dispose();
  }
}

View File

@@ -0,0 +1,56 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'dart:io';
import 'dart:typed_data';
import 'package:flutter/services.dart';
import 'package:path/path.dart' as p;
import 'package:path_provider/path_provider.dart';
/// Builds an absolute path for a new wave file inside the application
/// documents directory, named from the current local timestamp
/// (`YYYY-MM-DD-HH-MM-SS`) followed by [suffix] and a `.wav` extension.
Future<String> generateWaveFilename([String suffix = '']) async {
  final Directory directory = await getApplicationDocumentsDirectory();
  final now = DateTime.now();
  // Zero-pad each time component to two digits.
  String pad(int v) => v.toString().padLeft(2, '0');
  final stamp = [
    now.year.toString(),
    pad(now.month),
    pad(now.day),
    pad(now.hour),
    pad(now.minute),
    pad(now.second),
  ].join('-');
  return p.join(directory.path, '$stamp$suffix.wav');
}
// https://stackoverflow.com/questions/68862225/flutter-how-to-get-all-files-from-assets-folder-in-one-list
/// Returns the keys of every asset bundled with the application.
Future<List<String>> getAllAssetFiles() async {
  final manifest = await AssetManifest.loadFromAssetBundle(rootBundle);
  return manifest.listAssets();
}
/// Removes the first [n] path components from [src]
/// (e.g. `assets/foo/bar` -> `foo/bar` with the default `n = 1`).
String stripLeadingDirectory(String src, {int n = 1}) {
  final components = p.split(src);
  return p.joinAll(components.sublist(n));
}
/// Copies every bundled asset to local disk, dropping the leading asset
/// directory from each destination path.
Future<void> copyAllAssetFiles() async {
  for (final asset in await getAllAssetFiles()) {
    await copyAssetFile(asset, stripLeadingDirectory(asset));
  }
}
/// Copies the asset [src] to [dst] under the application documents
/// directory and returns the absolute target path.
///
/// [dst] defaults to the basename of [src]. If the target file already
/// exists (e.g. from a previous run), the copy is skipped.
Future<String> copyAssetFile(String src, [String? dst]) async {
  final Directory directory = await getApplicationDocumentsDirectory();
  // Idiom: null-aware assignment replaces the explicit null check.
  dst ??= p.basename(src);
  final target = p.join(directory.path, dst);
  // Idiom: construct the File once (and drop the deprecated `new`).
  final file = File(target);
  if (!await file.exists()) {
    final data = await rootBundle.load(src);
    final List<int> bytes =
        data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes);
    // create(recursive: true) also makes any missing parent directories.
    await (await file.create(recursive: true)).writeAsBytes(bytes);
  }
  return target;
}