Release GIL to support multithreading in websocket servers. (#451)

This commit is contained in:
Fangjun Kuang
2023-11-27 13:44:03 +08:00
committed by GitHub
parent 8dc08a9b97
commit 87a47d7db4
10 changed files with 87 additions and 47 deletions

View File

@@ -414,7 +414,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--max-batch-size", "--max-batch-size",
type=int, type=int,
default=25, default=3,
help="""Max batch size for computation. Note if there are not enough help="""Max batch size for computation. Note if there are not enough
requests in the queue, it will wait for max_wait_ms time. After that, requests in the queue, it will wait for max_wait_ms time. After that,
even if there are not enough requests, it still sends the even if there are not enough requests, it still sends the
@@ -459,7 +459,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--max-active-connections", "--max-active-connections",
type=int, type=int,
default=500, default=200,
help="""Maximum number of active connections. The server will refuse help="""Maximum number of active connections. The server will refuse
to accept new connections once the current number of active connections to accept new connections once the current number of active connections
equals to this limit. equals to this limit.
@@ -533,6 +533,7 @@ class NonStreamingServer:
self.certificate = certificate self.certificate = certificate
self.http_server = HttpServer(doc_root) self.http_server = HttpServer(doc_root)
self.nn_pool_size = nn_pool_size
self.nn_pool = ThreadPoolExecutor( self.nn_pool = ThreadPoolExecutor(
max_workers=nn_pool_size, max_workers=nn_pool_size,
thread_name_prefix="nn", thread_name_prefix="nn",
@@ -604,7 +605,9 @@ or <a href="/offline_record.html">/offline_record.html</a>
async def run(self, port: int): async def run(self, port: int):
logging.info("started") logging.info("started")
task = asyncio.create_task(self.stream_consumer_task()) tasks = []
for i in range(self.nn_pool_size):
tasks.append(asyncio.create_task(self.stream_consumer_task()))
if self.certificate: if self.certificate:
logging.info(f"Using certificate: {self.certificate}") logging.info(f"Using certificate: {self.certificate}")
@@ -636,7 +639,7 @@ or <a href="/offline_record.html">/offline_record.html</a>
await asyncio.Future() # run forever await asyncio.Future() # run forever
await task # not reachable await asyncio.gather(*tasks) # not reachable
async def recv_audio_samples( async def recv_audio_samples(
self, self,
@@ -722,6 +725,7 @@ or <a href="/offline_record.html">/offline_record.html</a>
batch.append(item) batch.append(item)
except asyncio.QueueEmpty: except asyncio.QueueEmpty:
pass pass
stream_list = [b[0] for b in batch] stream_list = [b[0] for b in batch]
future_list = [b[1] for b in batch] future_list = [b[1] for b in batch]

View File

@@ -296,7 +296,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--max-batch-size", "--max-batch-size",
type=int, type=int,
default=50, default=3,
help="""Max batch size for computation. Note if there are not enough help="""Max batch size for computation. Note if there are not enough
requests in the queue, it will wait for max_wait_ms time. After that, requests in the queue, it will wait for max_wait_ms time. After that,
even if there are not enough requests, it still sends the even if there are not enough requests, it still sends the
@@ -334,7 +334,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--max-active-connections", "--max-active-connections",
type=int, type=int,
default=500, default=200,
help="""Maximum number of active connections. The server will refuse help="""Maximum number of active connections. The server will refuse
to accept new connections once the current number of active connections to accept new connections once the current number of active connections
equals to this limit. equals to this limit.
@@ -478,6 +478,7 @@ class StreamingServer(object):
self.certificate = certificate self.certificate = certificate
self.http_server = HttpServer(doc_root) self.http_server = HttpServer(doc_root)
self.nn_pool_size = nn_pool_size
self.nn_pool = ThreadPoolExecutor( self.nn_pool = ThreadPoolExecutor(
max_workers=nn_pool_size, max_workers=nn_pool_size,
thread_name_prefix="nn", thread_name_prefix="nn",
@@ -591,7 +592,9 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a>
return status, header, response return status, header, response
async def run(self, port: int): async def run(self, port: int):
task = asyncio.create_task(self.stream_consumer_task()) tasks = []
for i in range(self.nn_pool_size):
tasks.append(asyncio.create_task(self.stream_consumer_task()))
if self.certificate: if self.certificate:
logging.info(f"Using certificate: {self.certificate}") logging.info(f"Using certificate: {self.certificate}")
@@ -629,7 +632,7 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a>
await asyncio.Future() # run forever await asyncio.Future() # run forever
await task # not reachable await asyncio.gather(*tasks) # not reachable
async def handle_connection( async def handle_connection(
self, self,

View File

@@ -19,10 +19,12 @@ void PybindCircularBuffer(py::module *m) {
[](PyClass &self, const std::vector<float> &samples) { [](PyClass &self, const std::vector<float> &samples) {
self.Push(samples.data(), samples.size()); self.Push(samples.data(), samples.size());
}, },
py::arg("samples")) py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def("get", &PyClass::Get, py::arg("start_index"), py::arg("n")) .def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"),
.def("pop", &PyClass::Pop, py::arg("n")) py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset) .def("pop", &PyClass::Pop, py::arg("n"),
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("size", &PyClass::Size) .def_property_readonly("size", &PyClass::Size)
.def_property_readonly("head", &PyClass::Head) .def_property_readonly("head", &PyClass::Head)
.def_property_readonly("tail", &PyClass::Tail); .def_property_readonly("tail", &PyClass::Tail);

View File

@@ -41,19 +41,24 @@ void PybindOfflineRecognizer(py::module *m) {
using PyClass = OfflineRecognizer; using PyClass = OfflineRecognizer;
py::class_<PyClass>(*m, "OfflineRecognizer") py::class_<PyClass>(*m, "OfflineRecognizer")
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config")) .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
.def("create_stream", .def(
[](const PyClass &self) { return self.CreateStream(); }) "create_stream",
[](const PyClass &self) { return self.CreateStream(); },
py::call_guard<py::gil_scoped_release>())
.def( .def(
"create_stream", "create_stream",
[](PyClass &self, const std::string &hotwords) { [](PyClass &self, const std::string &hotwords) {
return self.CreateStream(hotwords); return self.CreateStream(hotwords);
}, },
py::arg("hotwords")) py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
.def("decode_stream", &PyClass::DecodeStream) .def("decode_stream", &PyClass::DecodeStream,
.def("decode_streams", py::call_guard<py::gil_scoped_release>())
[](const PyClass &self, std::vector<OfflineStream *> ss) { .def(
self.DecodeStreams(ss.data(), ss.size()); "decode_streams",
}); [](const PyClass &self, std::vector<OfflineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
},
py::call_guard<py::gil_scoped_release>());
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -50,9 +50,20 @@ void PybindOfflineStream(py::module *m) {
.def( .def(
"accept_waveform", "accept_waveform",
[](PyClass &self, float sample_rate, py::array_t<float> waveform) { [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
#if 0
auto report_gil_status = []() {
auto is_gil_held = false;
if (auto tstate = py::detail::get_thread_state_unchecked())
is_gil_held = (tstate == PyGILState_GetThisThreadState());
return is_gil_held ? "GIL held" : "GIL released";
};
std::cout << report_gil_status() << "\n";
#endif
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
}, },
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("result", &PyClass::GetResult); .def_property_readonly("result", &PyClass::GetResult);
} }

View File

@@ -45,7 +45,7 @@ void PybindOfflineTts(py::module *m) {
py::class_<PyClass>(*m, "OfflineTts") py::class_<PyClass>(*m, "OfflineTts")
.def(py::init<const OfflineTtsConfig &>(), py::arg("config")) .def(py::init<const OfflineTtsConfig &>(), py::arg("config"))
.def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0, .def("generate", &PyClass::Generate, py::arg("text"), py::arg("sid") = 0,
py::arg("speed") = 1.0); py::arg("speed") = 1.0, py::call_guard<py::gil_scoped_release>());
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -54,23 +54,31 @@ void PybindOnlineRecognizer(py::module *m) {
using PyClass = OnlineRecognizer; using PyClass = OnlineRecognizer;
py::class_<PyClass>(*m, "OnlineRecognizer") py::class_<PyClass>(*m, "OnlineRecognizer")
.def(py::init<const OnlineRecognizerConfig &>(), py::arg("config")) .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
.def("create_stream", .def(
[](const PyClass &self) { return self.CreateStream(); }) "create_stream",
[](const PyClass &self) { return self.CreateStream(); },
py::call_guard<py::gil_scoped_release>())
.def( .def(
"create_stream", "create_stream",
[](PyClass &self, const std::string &hotwords) { [](PyClass &self, const std::string &hotwords) {
return self.CreateStream(hotwords); return self.CreateStream(hotwords);
}, },
py::arg("hotwords")) py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
.def("is_ready", &PyClass::IsReady) .def("is_ready", &PyClass::IsReady,
.def("decode_stream", &PyClass::DecodeStream) py::call_guard<py::gil_scoped_release>())
.def("decode_streams", .def("decode_stream", &PyClass::DecodeStream,
[](PyClass &self, std::vector<OnlineStream *> ss) { py::call_guard<py::gil_scoped_release>())
self.DecodeStreams(ss.data(), ss.size()); .def(
}) "decode_streams",
.def("get_result", &PyClass::GetResult) [](PyClass &self, std::vector<OnlineStream *> ss) {
.def("is_endpoint", &PyClass::IsEndpoint) self.DecodeStreams(ss.data(), ss.size());
.def("reset", &PyClass::Reset); },
py::call_guard<py::gil_scoped_release>())
.def("get_result", &PyClass::GetResult,
py::call_guard<py::gil_scoped_release>())
.def("is_endpoint", &PyClass::IsEndpoint,
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>());
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -28,8 +28,10 @@ void PybindOnlineStream(py::module *m) {
[](PyClass &self, float sample_rate, py::array_t<float> waveform) { [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
}, },
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
.def("input_finished", &PyClass::InputFinished); py::call_guard<py::gil_scoped_release>())
.def("input_finished", &PyClass::InputFinished,
py::call_guard<py::gil_scoped_release>());
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -13,17 +13,21 @@ namespace sherpa_onnx {
void PybindVadModel(py::module *m) { void PybindVadModel(py::module *m) {
using PyClass = VadModel; using PyClass = VadModel;
py::class_<PyClass>(*m, "VadModel") py::class_<PyClass>(*m, "VadModel")
.def_static("create", &PyClass::Create, py::arg("config")) .def_static("create", &PyClass::Create, py::arg("config"),
.def("reset", &PyClass::Reset) py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def( .def(
"is_speech", "is_speech",
[](PyClass &self, const std::vector<float> &samples) -> bool { [](PyClass &self, const std::vector<float> &samples) -> bool {
return self.IsSpeech(samples.data(), samples.size()); return self.IsSpeech(samples.data(), samples.size());
}, },
py::arg("samples")) py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def("window_size", &PyClass::WindowSize) .def("window_size", &PyClass::WindowSize,
.def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples) py::call_guard<py::gil_scoped_release>())
.def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples); .def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples,
py::call_guard<py::gil_scoped_release>())
.def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples,
py::call_guard<py::gil_scoped_release>());
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -30,11 +30,12 @@ void PybindVoiceActivityDetector(py::module *m) {
[](PyClass &self, const std::vector<float> &samples) { [](PyClass &self, const std::vector<float> &samples) {
self.AcceptWaveform(samples.data(), samples.size()); self.AcceptWaveform(samples.data(), samples.size());
}, },
py::arg("samples")) py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def("empty", &PyClass::Empty) .def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
.def("pop", &PyClass::Pop) .def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
.def("is_speech_detected", &PyClass::IsSpeechDetected) .def("is_speech_detected", &PyClass::IsSpeechDetected,
.def("reset", &PyClass::Reset) py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("front", &PyClass::Front); .def_property_readonly("front", &PyClass::Front);
} }