Add isolate_tts demo (#1529)

This commit is contained in:
Aero
2024-11-13 00:04:57 +08:00
committed by GitHub
parent a16c9aff8b
commit 3f777b3fe3
5 changed files with 364 additions and 117 deletions

View File

@@ -0,0 +1,246 @@
import 'dart:async';
import 'dart:io';
import 'dart:isolate';

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:media_kit/media_kit.dart';
import 'package:path/path.dart' as p;
import 'package:path_provider/path_provider.dart';
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;

import 'utils.dart';
/// Start-up message handed to the worker isolate through `Isolate.spawn`.
///
/// The previous unused type parameter was dropped (every use site in this
/// file is raw) and both fields are final, since the message is written once
/// and then only read by the receiving isolate.
class _IsolateTask {
  /// Port on which the worker isolate sends back its own [SendPort].
  final SendPort sendPort;

  /// Token the worker passes to
  /// `BackgroundIsolateBinaryMessenger.ensureInitialized`; when null that
  /// call is skipped.
  final RootIsolateToken? rootIsolateToken;

  _IsolateTask(this.sendPort, this.rootIsolateToken);
}
/// A request message exchanged between the main isolate and the TTS isolate.
class _PortModel {
  _PortModel({
    required this.method,
    this.sendPort,
    this.data,
  });

  /// Name of the operation to perform; the receiver switches on this value.
  final String method;

  /// Optional port supplied by the sender together with the request.
  final SendPort? sendPort;

  /// Untyped payload for [method]. For `'generate'` the sender in this file
  /// passes a map with the keys `text`, `sid` and `speed`.
  dynamic data;
}
/// Bundles the handles the main isolate keeps for the spawned TTS isolate.
class _TtsManager {
  _TtsManager({
    required this.receivePort,
    required this.isolate,
    required this.isolatePort,
  });

  /// Communication port owned by the main isolate.
  final ReceivePort receivePort;

  /// Handle of the spawned isolate.
  final Isolate isolate;

  /// Port used to send requests into the spawned isolate.
  final SendPort isolatePort;
}
/// Runs sherpa-onnx speech synthesis and media_kit playback on a background
/// isolate so that synthesis does not run on the UI isolate.
///
/// Call [init] to spawn the worker, then [generate] to synthesize and play a
/// sentence. [generate] also starts the worker on demand.
class IsolateTts {
  /// Main-isolate side: the pending or completed worker start-up. Null until
  /// the first start attempt, and reset to null when an attempt fails so a
  /// later call can retry.
  static Future<_TtsManager>? _starting;

  /// Worker-isolate side: the synthesizer, or null when it has not been (or
  /// could not be) created. Dart isolates do not share static state, so this
  /// is only ever assigned inside [_isolateEntry].
  static sherpa_onnx.OfflineTts? _tts;

  /// Worker-isolate side: the audio player created in [_isolateEntry].
  static late Player _player;

  /// Spawns the worker isolate and completes once it has reported its port.
  ///
  /// Safe to call repeatedly: only one worker is spawned, and later calls
  /// share the same start-up future.
  static Future<void> init() async {
    await _ensureStarted();
  }

  /// Returns the manager for the running worker, spawning it if necessary.
  static Future<_TtsManager> _ensureStarted() async {
    final starting = _starting ??= _spawnWorker();
    try {
      return await starting;
    } catch (_) {
      // Forget the failed attempt so the next call can try again.
      if (identical(_starting, starting)) {
        _starting = null;
      }
      rethrow;
    }
  }

  /// Spawns the worker isolate and waits for the [SendPort] it sends back.
  static Future<_TtsManager> _spawnWorker() async {
    final port = ReceivePort();
    try {
      final isolate = await Isolate.spawn(
        _isolateEntry,
        _IsolateTask(port.sendPort, RootIsolateToken.instance),
        errorsAreFatal: false,
      );
      final handshake = Completer<_TtsManager>();
      // The port stays open after the handshake; it is kept in the manager.
      port.listen((msg) {
        if (msg is SendPort && !handshake.isCompleted) {
          handshake.complete(
            _TtsManager(receivePort: port, isolate: isolate, isolatePort: msg),
          );
        }
      });
      return await handshake.future;
    } catch (_) {
      port.close();
      rethrow;
    }
  }

  /// Entry point of the worker isolate.
  ///
  /// Sends its request port back first, so the main isolate never waits on a
  /// worker whose initialization failed, then initializes playback and the
  /// synthesizer, and finally starts serving [_PortModel] requests. Requests
  /// that arrive during initialization stay queued on the port until the
  /// listener is attached.
  static Future<void> _isolateEntry(_IsolateTask task) async {
    final requests = ReceivePort();
    task.sendPort.send(requests.sendPort);

    try {
      final token = task.rootIsolateToken;
      if (token != null) {
        BackgroundIsolateBinaryMessenger.ensureInitialized(token);
      }
      MediaKit.ensureInitialized();
      _player = Player();
      sherpa_onnx.initBindings();
      _tts = sherpa_onnx.OfflineTts(await _buildConfig());
    } catch (error, stackTrace) {
      // Keep serving requests so that callers get an error reply instead of
      // waiting forever; _tts stays null and _handleGenerate reports it.
      print('Failed to initialize the TTS isolate: $error\n$stackTrace');
    }

    requests.listen((msg) async {
      if (msg is! _PortModel) {
        return;
      }
      // Reply protocol: null means success, a String describes the failure.
      String? failure;
      try {
        switch (msg.method) {
          case 'generate':
            await _handleGenerate(msg.data);
            break;
          default:
            failure = 'Unknown method: ${msg.method}';
        }
      } catch (error, stackTrace) {
        failure = '$error';
        print('TTS request "${msg.method}" failed: $error\n$stackTrace');
      }
      msg.sendPort?.send(failure);
    });
  }

  /// Builds the synthesizer configuration for the selected model.
  ///
  /// All model files are resolved relative to the application documents
  /// directory. Throws an [Exception] when no model has been selected.
  static Future<sherpa_onnx.OfflineTtsConfig> _buildConfig() async {
    String modelDir = '';
    String modelName = '';
    String ruleFsts = '';
    String ruleFars = '';
    String lexicon = '';
    String dataDir = '';
    String dictDir = '';

    // Example 7
    // https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
    // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-melo-tts-zh_en.tar.bz2
    modelDir = 'vits-melo-tts-zh_en';
    modelName = 'model.onnx';
    lexicon = 'lexicon.txt';
    dictDir = 'vits-melo-tts-zh_en/dict';

    if (modelName.isEmpty) {
      throw Exception('You are supposed to select a model by changing the code before you run the app');
    }

    final Directory directory = await getApplicationDocumentsDirectory();
    final root = directory.path;

    modelName = p.join(root, modelDir, modelName);
    ruleFsts = _joinUnderRoot(root, ruleFsts);
    ruleFars = _joinUnderRoot(root, ruleFars);
    if (lexicon.isNotEmpty) {
      lexicon = p.join(root, modelDir, lexicon);
    }
    if (dataDir.isNotEmpty) {
      dataDir = p.join(root, dataDir);
    }
    if (dictDir.isNotEmpty) {
      dictDir = p.join(root, dictDir);
    }

    final vits = sherpa_onnx.OfflineTtsVitsModelConfig(
      model: modelName,
      lexicon: lexicon,
      tokens: p.join(root, modelDir, 'tokens.txt'),
      dataDir: dataDir,
      dictDir: dictDir,
    );
    final modelConfig = sherpa_onnx.OfflineTtsModelConfig(
      vits: vits,
      numThreads: 2,
      debug: true,
      provider: 'cpu',
    );
    return sherpa_onnx.OfflineTtsConfig(
      model: modelConfig,
      ruleFsts: ruleFsts,
      ruleFars: ruleFars,
      maxNumSenetences: 1,
    );
  }

  /// Prefixes every entry of a comma-separated path list with [root].
  ///
  /// An empty list is returned unchanged.
  static String _joinUnderRoot(String root, String commaSeparated) {
    if (commaSeparated.isEmpty) {
      return '';
    }
    return commaSeparated.split(',').map((f) => p.join(root, f)).join(',');
  }

  /// Synthesizes one request inside the worker, writes it to a wave file and
  /// starts playing that file.
  ///
  /// [data] is the payload built by [generate]: a map with the keys `text`,
  /// `sid` and `speed`. Throws when the synthesizer is unavailable or the
  /// wave file cannot be written.
  static Future<void> _handleGenerate(dynamic data) async {
    final tts = _tts;
    if (tts == null) {
      throw StateError('The TTS engine is not initialized');
    }
    final text = data['text'] as String;
    final sid = data['sid'] as int;
    final speed = (data['speed'] as num).toDouble();

    final stopwatch = Stopwatch()..start();
    final audio = tts.generate(text: text, sid: sid, speed: speed);

    // The speed part of the name must come from `speed`; it was previously
    // built from `sid` by mistake.
    final suffix = '-sid-$sid-speed-${speed.toStringAsPrecision(2)}';
    final filename = await generateWaveFilename(suffix);

    final ok = sherpa_onnx.writeWave(
      filename: filename,
      samples: audio.samples,
      sampleRate: audio.sampleRate,
    );
    if (!ok) {
      throw Exception('Failed to save generated audio');
    }
    stopwatch.stop();

    final elapsed = stopwatch.elapsed.inMilliseconds / 1000;
    final waveDuration = audio.samples.length / audio.sampleRate;
    print('Saved to\n$filename\n'
        'Elapsed: ${elapsed.toStringAsPrecision(4)} s\n'
        'Wave duration: ${waveDuration.toStringAsPrecision(4)} s\n'
        'RTF: ${elapsed.toStringAsPrecision(4)}/${waveDuration.toStringAsPrecision(4)} '
        '= ${(elapsed / waveDuration).toStringAsPrecision(3)} ');

    await _player.open(Media('file:///$filename'));
    await _player.play();
  }

  /// Asks the worker isolate to synthesize [text] and play the result.
  ///
  /// [sid] is the speaker id and [speed] the speech rate passed to the
  /// synthesizer. Starts the worker first if [init] has not completed yet.
  /// The returned future completes when the worker has replied, i.e. after it
  /// has written the wave file and started playback. Throws an [Exception]
  /// carrying the worker's error description when the request fails.
  static Future<void> generate({required String text, int sid = 0, double speed = 1.0}) async {
    final manager = await _ensureStarted();
    final replyPort = ReceivePort();
    try {
      manager.isolatePort.send(_PortModel(
        method: 'generate',
        data: {'text': text, 'sid': sid, 'speed': speed},
        sendPort: replyPort.sendPort,
      ));
      final failure = await replyPort.first;
      if (failure != null) {
        throw Exception('TTS generation failed: $failure');
      }
    } finally {
      // Also releases the port when sending or waiting throws.
      replyPort.close();
    }
  }
}
/// 这里是页面
/// The page that hosts the isolate-based TTS demo.
class IsolateTtsView extends StatefulWidget {
  const IsolateTtsView({super.key});

  @override
  State<IsolateTtsView> createState() {
    return _IsolateTtsViewState();
  }
}
/// State for [IsolateTtsView]: starts the TTS isolate when the page is
/// created and shows a single button that triggers synthesis.
class _IsolateTtsViewState extends State<IsolateTtsView> {
  @override
  void initState() {
    super.initState();
    // initState cannot await, so start-up runs in the background; a failure
    // is logged instead of being dropped with the future.
    IsolateTts.init().catchError((Object error) {
      debugPrint('Failed to start the TTS isolate: $error');
    });
  }

  /// Requests synthesis of the demo sentence and reports any failure.
  Future<void> _speak() async {
    try {
      await IsolateTts.generate(text: '这是已退出的 isolate TTS');
    } catch (error) {
      debugPrint('Isolate TTS failed: $error');
      // The page may have been removed while the request was in flight.
      if (!mounted) {
        return;
      }
      ScaffoldMessenger.maybeOf(context)?.showSnackBar(
        SnackBar(content: Text('Isolate TTS failed: $error')),
      );
    }
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      body: Center(
        child: ElevatedButton(
          onPressed: _speak,
          child: const Text('Isolate TTS'),
        ),
      ),
    );
  }
}

View File

@@ -1,8 +1,9 @@
// Copyright (c) 2024 Xiaomi Corporation // Copyright (c) 2024 Xiaomi Corporation
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import './tts.dart';
import './info.dart'; import './info.dart';
import './tts.dart';
import 'isolate_tts.dart';
void main() { void main() {
runApp(const MyApp()); runApp(const MyApp());
@@ -38,6 +39,7 @@ class _MyHomePageState extends State<MyHomePage> {
final List<Widget> _tabs = [ final List<Widget> _tabs = [
TtsScreen(), TtsScreen(),
InfoScreen(), InfoScreen(),
IsolateTtsView(),
]; ];
@override @override
Widget build(BuildContext context) { Widget build(BuildContext context) {
@@ -62,6 +64,10 @@ class _MyHomePageState extends State<MyHomePage> {
icon: Icon(Icons.info), icon: Icon(Icons.info),
label: 'Info', label: 'Info',
), ),
BottomNavigationBarItem(
icon: Icon(Icons.multiline_chart),
label: 'isolate',
),
], ],
), ),
); );

View File

@@ -79,17 +79,16 @@ Future<sherpa_onnx.OfflineTts> createOfflineTts() async {
// Example 7 // Example 7
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models // https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-melo-tts-zh_en.tar.bz2 // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-melo-tts-zh_en.tar.bz2
// modelDir = 'vits-melo-tts-zh_en'; modelDir = 'vits-melo-tts-zh_en';
// modelName = 'model.onnx'; modelName = 'model.onnx';
// lexicon = 'lexicon.txt'; lexicon = 'lexicon.txt';
// dictDir = 'vits-melo-tts-zh_en/dict'; dictDir = 'vits-melo-tts-zh_en/dict';
// ============================================================ // ============================================================
// Please don't change the remaining part of this function // Please don't change the remaining part of this function
// ============================================================ // ============================================================
if (modelName == '') { if (modelName == '') {
throw Exception( throw Exception('You are supposed to select a model by changing the code before you run the app');
'You are supposed to select a model by changing the code before you run the app');
} }
final Directory directory = await getApplicationDocumentsDirectory(); final Directory directory = await getApplicationDocumentsDirectory();

View File

@@ -77,9 +77,7 @@ class _TtsScreenState extends State<TtsScreen> {
onTapOutside: (PointerDownEvent event) { onTapOutside: (PointerDownEvent event) {
FocusManager.instance.primaryFocus?.unfocus(); FocusManager.instance.primaryFocus?.unfocus();
}, },
inputFormatters: <TextInputFormatter>[ inputFormatters: <TextInputFormatter>[FilteringTextInputFormatter.digitsOnly]),
FilteringTextInputFormatter.digitsOnly
]),
Slider( Slider(
// decoration: InputDecoration( // decoration: InputDecoration(
// labelText: "speech speed", // labelText: "speech speed",
@@ -108,125 +106,117 @@ class _TtsScreenState extends State<TtsScreen> {
}, },
), ),
const SizedBox(height: 5), const SizedBox(height: 5),
Row( Row(mainAxisAlignment: MainAxisAlignment.center, children: <Widget>[
mainAxisAlignment: MainAxisAlignment.center, OutlinedButton(
children: <Widget>[ child: Text("Generate"),
OutlinedButton( onPressed: () async {
child: Text("Generate"), await _init();
onPressed: () async { await _player?.stop();
await _init();
await _player?.stop();
setState(() { setState(() {
_maxSpeakerID = _tts?.numSpeakers ?? 0; _maxSpeakerID = _tts?.numSpeakers ?? 0;
if (_maxSpeakerID > 0) { if (_maxSpeakerID > 0) {
_maxSpeakerID -= 1; _maxSpeakerID -= 1;
} }
}); });
if (_tts == null) { if (_tts == null) {
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: 'Failed to initialize tts', text: 'Failed to initialize tts',
); );
return; return;
} }
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: '', text: '',
); );
final text = _controller_text_input.text.trim(); final text = _controller_text_input.text.trim();
if (text == '') { if (text == '') {
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: 'Please first input your text to generate', text: 'Please first input your text to generate',
); );
return; return;
} }
final sid = final sid = int.tryParse(_controller_sid.text.trim()) ?? 0;
int.tryParse(_controller_sid.text.trim()) ?? 0;
final stopwatch = Stopwatch(); final stopwatch = Stopwatch();
stopwatch.start(); stopwatch.start();
final audio = final audio = _tts!.generate(text: text, sid: sid, speed: _speed);
_tts!.generate(text: text, sid: sid, speed: _speed); final suffix = '-sid-$sid-speed-${_speed.toStringAsPrecision(2)}';
final suffix = final filename = await generateWaveFilename(suffix);
'-sid-$sid-speed-${_speed.toStringAsPrecision(2)}';
final filename = await generateWaveFilename(suffix);
final ok = sherpa_onnx.writeWave( final ok = sherpa_onnx.writeWave(
filename: filename, filename: filename,
samples: audio.samples, samples: audio.samples,
sampleRate: audio.sampleRate, sampleRate: audio.sampleRate,
); );
if (ok) { if (ok) {
stopwatch.stop(); stopwatch.stop();
double elapsed = double elapsed = stopwatch.elapsed.inMilliseconds.toDouble();
stopwatch.elapsed.inMilliseconds.toDouble();
double waveDuration = double waveDuration = audio.samples.length.toDouble() / audio.sampleRate.toDouble();
audio.samples.length.toDouble() /
audio.sampleRate.toDouble();
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: 'Saved to\n$filename\n' text: 'Saved to\n$filename\n'
'Elapsed: ${(elapsed / 1000).toStringAsPrecision(4)} s\n' 'Elapsed: ${(elapsed / 1000).toStringAsPrecision(4)} s\n'
'Wave duration: ${waveDuration.toStringAsPrecision(4)} s\n' 'Wave duration: ${waveDuration.toStringAsPrecision(4)} s\n'
'RTF: ${(elapsed / 1000).toStringAsPrecision(4)}/${waveDuration.toStringAsPrecision(4)} ' 'RTF: ${(elapsed / 1000).toStringAsPrecision(4)}/${waveDuration.toStringAsPrecision(4)} '
'= ${(elapsed / 1000 / waveDuration).toStringAsPrecision(3)} ', '= ${(elapsed / 1000 / waveDuration).toStringAsPrecision(3)} ',
); );
_lastFilename = filename; _lastFilename = filename;
await _player?.play(DeviceFileSource(_lastFilename)); await _player?.play(DeviceFileSource(_lastFilename));
} else { } else {
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: 'Failed to save generated audio', text: 'Failed to save generated audio',
); );
} }
}, },
), ),
const SizedBox(width: 5), const SizedBox(width: 5),
OutlinedButton( OutlinedButton(
child: Text("Clear"), child: Text("Clear"),
onPressed: () { onPressed: () {
_controller_text_input.value = TextEditingValue( _controller_text_input.value = TextEditingValue(
text: '', text: '',
); );
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: '', text: '',
); );
}, },
), ),
const SizedBox(width: 5), const SizedBox(width: 5),
OutlinedButton( OutlinedButton(
child: Text("Play"), child: Text("Play"),
onPressed: () async { onPressed: () async {
if (_lastFilename == '') { if (_lastFilename == '') {
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: 'No generated wave file found', text: 'No generated wave file found',
); );
return; return;
} }
await _player?.stop(); await _player?.stop();
await _player?.play(DeviceFileSource(_lastFilename)); await _player?.play(DeviceFileSource(_lastFilename));
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: 'Playing\n$_lastFilename', text: 'Playing\n$_lastFilename',
); );
}, },
), ),
const SizedBox(width: 5), const SizedBox(width: 5),
OutlinedButton( OutlinedButton(
child: Text("Stop"), child: Text("Stop"),
onPressed: () async { onPressed: () async {
await _player?.stop(); await _player?.stop();
_controller_hint.value = TextEditingValue( _controller_hint.value = TextEditingValue(
text: '', text: '',
); );
}, },
), ),
]), ]),
const SizedBox(height: 5), const SizedBox(height: 5),
TextField( TextField(
decoration: InputDecoration( decoration: InputDecoration(

View File

@@ -24,6 +24,12 @@ dependencies:
url_launcher: 6.2.6 url_launcher: 6.2.6
url_launcher_linux: 3.1.0 url_launcher_linux: 3.1.0
audioplayers: ^5.0.0 audioplayers: ^5.0.0
media_kit:
media_kit_libs_video:
flutter: flutter:
uses-material-design: true uses-material-design: true
assets:
- assets/vits-melo-tts-zh_en/
- assets/vits-melo-tts-zh_en/dict/