diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 08a9c4f0..d8f137f2 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -69,3 +69,11 @@ jobs: export EXE=sherpa-onnx .github/scripts/test-online-transducer.sh + + - name: Test online transducer (C API) + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=decode-file-c-api + + .github/scripts/test-online-transducer.sh diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 7f53317d..b3243b84 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -71,3 +71,11 @@ jobs: export EXE=sherpa-onnx .github/scripts/test-online-transducer.sh + + - name: Test online transducer (C API) + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=decode-file-c-api + + .github/scripts/test-online-transducer.sh diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 42e4f70d..76dbf799 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -78,3 +78,11 @@ jobs: export EXE=sherpa-onnx.exe .github/scripts/test-online-transducer.sh + + - name: Test online transducer (C API) + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=decode-file-c-api.exe + + .github/scripts/test-online-transducer.sh diff --git a/.gitignore b/.gitignore index fd87f07a..dc7d32ef 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,5 @@ a.txt run-bilingual*.sh run-*-zipformer.sh run-zh.sh +decode-file-c-api +run-decode-file-c-api.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bd853f4..ab57e664 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,7 @@ option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON) option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) +option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") @@ -91,3 +92,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) endif() add_subdirectory(sherpa-onnx) + +if(SHERPA_ONNX_ENABLE_C_API) + add_subdirectory(c-api-examples) +endif() diff --git a/c-api-examples/CMakeLists.txt b/c-api-examples/CMakeLists.txt new file mode 100644 index 00000000..02c4cc16 --- /dev/null +++ b/c-api-examples/CMakeLists.txt @@ -0,0 +1,3 @@ +include_directories(${CMAKE_SOURCE_DIR}) +add_executable(decode-file-c-api decode-file-c-api.c) +target_link_libraries(decode-file-c-api sherpa-onnx-c-api) diff --git a/c-api-examples/Makefile b/c-api-examples/Makefile new file mode 100644 index 00000000..6e54e242 --- /dev/null +++ b/c-api-examples/Makefile @@ -0,0 +1,10 @@ + +CFLAGS := -I ../ +LDFLAGS := -L ../build/lib +LDFLAGS += -L ../build/_deps/onnxruntime-src/lib +LDFLAGS += -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core +LDFLAGS += -Wl,-rpath,../build/lib +LDFLAGS += -Wl,-rpath,../build/_deps/onnxruntime-src/lib + +decode-file-c-api: decode-file-c-api.c + $(CC) $(CFLAGS) -o $@ $< $(LDFLAGS) diff --git a/c-api-examples/decode-file-c-api.c b/c-api-examples/decode-file-c-api.c new file mode 100644 index 00000000..ddd67fe9 --- /dev/null +++ b/c-api-examples/decode-file-c-api.c @@ -0,0 +1,113 @@ +// c-api-examples/decode-file-c-api.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include +#include +#include + +#include "sherpa-onnx/c-api/c-api.h" + +const char *kUsage = + "\n" + "Usage:\n " + " ./bin/decode-file-c-api \\\n" + " /path/to/tokens.txt \\\n" + " /path/to/encoder.onnx \\\n" + " /path/to/decoder.onnx \\\n" + " /path/to/joiner.onnx \\\n" + " /path/to/foo.wav [num_threads]\n" + "\n\n" + "Please refer to \n" + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html\n" + "for a list of pre-trained models to download.\n"; + +int32_t main(int32_t argc, char *argv[]) { + if (argc < 6 || argc > 7) { + fprintf(stderr, "%s\n", kUsage); + return -1; + } + SherpaOnnxOnlineRecognizerConfig config; + config.model_config.tokens = argv[1]; + config.model_config.encoder = argv[2]; + config.model_config.decoder = argv[3]; + config.model_config.joiner = argv[4]; + + int32_t num_threads = 4; + if (argc == 7 && atoi(argv[6]) > 0) { + num_threads = atoi(argv[6]); + } + config.model_config.num_threads = num_threads; + config.model_config.debug = 0; + + config.feat_config.sample_rate = 16000; + config.feat_config.feature_dim = 80; + + config.enable_endpoint = 1; + config.rule1_min_trailing_silence = 2.4; + config.rule2_min_trailing_silence = 1.2; + config.rule3_min_utterance_length = 300; + + SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config); + SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer); + + const char *wav_filename = argv[5]; + FILE *fp = fopen(wav_filename, "rb"); + if (!fp) { + fprintf(stderr, "Failed to open %s\n", wav_filename); + return -1; + } + + // Assume the wave header occupies 44 bytes. + fseek(fp, 44, SEEK_SET); + + // simulate streaming + +#define N 3200 // 0.2 s. Sample rate is fixed to 16 kHz + + int16_t buffer[N]; + float samples[N]; + + while (!feof(fp)) { + size_t n = fread((void *)buffer, sizeof(int16_t), N, fp); + if (n > 0) { + for (size_t i = 0; i != n; ++i) { + samples[i] = buffer[i] / 32768.; + } + AcceptWaveform(stream, 16000, samples, n); + while (IsOnlineStreamReady(recognizer, stream)) { + DecodeOnlineStream(recognizer, stream); + } + + SherpaOnnxOnlineRecognizerResult *r = + GetOnlineStreamResult(recognizer, stream); + if (strlen(r->text)) { + fprintf(stderr, "%s\n", r->text); + } + DestroyOnlineRecognizerResult(r); + } + } + fclose(fp); + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + AcceptWaveform(stream, 16000, tail_paddings, 4800); + + InputFinished(stream); + while (IsOnlineStreamReady(recognizer, stream)) { + DecodeOnlineStream(recognizer, stream); + } + + SherpaOnnxOnlineRecognizerResult *r = + GetOnlineStreamResult(recognizer, stream); + if (strlen(r->text)) { + fprintf(stderr, "%s\n", r->text); + } + + DestroyOnlineRecognizerResult(r); + + DestoryOnlineStream(stream); + DestroyOnlineRecognizer(recognizer); + + return 0; +} diff --git a/sherpa-onnx/CMakeLists.txt b/sherpa-onnx/CMakeLists.txt index d641811a..c39763d4 100644 --- a/sherpa-onnx/CMakeLists.txt +++ b/sherpa-onnx/CMakeLists.txt @@ -6,3 +6,7 @@ endif() if(SHERPA_ONNX_ENABLE_JNI) add_subdirectory(jni) endif() + +if(SHERPA_ONNX_ENABLE_C_API) + add_subdirectory(c-api) +endif() diff --git a/sherpa-onnx/c-api/CMakeLists.txt b/sherpa-onnx/c-api/CMakeLists.txt new file mode 100644 index 00000000..95f98d92 --- /dev/null +++ b/sherpa-onnx/c-api/CMakeLists.txt @@ -0,0 +1,10 @@ +include_directories(${CMAKE_SOURCE_DIR}) +add_library(sherpa-onnx-c-api c-api.cc) +target_link_libraries(sherpa-onnx-c-api sherpa-onnx-core) + +install(TARGETS sherpa-onnx-c-api DESTINATION lib) + +install(FILES c-api.h + DESTINATION include/sherpa-onnx/c-api +) + diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc new file mode 100644 index 00000000..2c08a91d --- /dev/null +++ b/sherpa-onnx/c-api/c-api.cc @@ -0,0 +1,126 @@ +// sherpa-onnx/c-api/c-api.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/c-api/c-api.h" + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/online-recognizer.h" + +struct SherpaOnnxOnlineRecognizer { + sherpa_onnx::OnlineRecognizer *impl; +}; + +struct SherpaOnnxOnlineStream { + std::unique_ptr impl; + explicit SherpaOnnxOnlineStream(std::unique_ptr p) + : impl(std::move(p)) {} +}; + +SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( + const SherpaOnnxOnlineRecognizerConfig *config) { + sherpa_onnx::OnlineRecognizerConfig recognizer_config; + + recognizer_config.feat_config.sampling_rate = config->feat_config.sample_rate; + recognizer_config.feat_config.feature_dim = config->feat_config.feature_dim; + + recognizer_config.model_config.encoder_filename = + config->model_config.encoder; + recognizer_config.model_config.decoder_filename = + config->model_config.decoder; + recognizer_config.model_config.joiner_filename = config->model_config.joiner; + recognizer_config.model_config.tokens = config->model_config.tokens; + recognizer_config.model_config.num_threads = config->model_config.num_threads; + recognizer_config.model_config.debug = config->model_config.debug; + + recognizer_config.enable_endpoint = config->enable_endpoint; + + recognizer_config.endpoint_config.rule1.min_trailing_silence = + config->rule1_min_trailing_silence; + + recognizer_config.endpoint_config.rule2.min_trailing_silence = + config->rule2_min_trailing_silence; + + recognizer_config.endpoint_config.rule3.min_utterance_length = + config->rule3_min_utterance_length; + + SherpaOnnxOnlineRecognizer *recognizer = new SherpaOnnxOnlineRecognizer; + recognizer->impl = new sherpa_onnx::OnlineRecognizer(recognizer_config); + + return recognizer; +} + +void DestroyOnlineRecognizer(SherpaOnnxOnlineRecognizer *recognizer) { + delete recognizer->impl; + delete recognizer; +} + +SherpaOnnxOnlineStream *CreateOnlineStream( + const SherpaOnnxOnlineRecognizer *recognizer) { + SherpaOnnxOnlineStream *stream = + new SherpaOnnxOnlineStream(recognizer->impl->CreateStream()); + return stream; +} + +void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; } + +void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate, + const float *samples, int32_t n) { + stream->impl->AcceptWaveform(sample_rate, samples, n); +} + +int32_t IsOnlineStreamReady(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream) { + return recognizer->impl->IsReady(stream->impl.get()); +} + +void DecodeOnlineStream(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream) { + recognizer->impl->DecodeStream(stream->impl.get()); +} + +void DecodeMultipleOnlineStreams(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream **streams, int32_t n) { + std::vector ss(n); + for (int32_t i = 0; i != n; ++n) { + ss[i] = streams[i]->impl.get(); + } + recognizer->impl->DecodeStreams(ss.data(), n); +} + +SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( + SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) { + sherpa_onnx::OnlineRecognizerResult result = + recognizer->impl->GetResult(stream->impl.get()); + const auto &text = result.text; + + auto r = new SherpaOnnxOnlineRecognizerResult; + r->text = new char[text.size() + 1]; + std::copy(text.begin(), text.end(), const_cast(r->text)); + const_cast(r->text)[text.size()] = 0; + + return r; +} + +void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) { + delete[] r->text; + delete r; +} + +void Reset(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream) { + recognizer->impl->Reset(stream->impl.get()); +} + +void InputFinished(SherpaOnnxOnlineStream *stream) { + stream->impl->InputFinished(); +} + +int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream) { + return recognizer->impl->IsEndpoint(stream->impl.get()); +} diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h new file mode 100644 index 00000000..1732d4b3 --- /dev/null +++ b/sherpa-onnx/c-api/c-api.h @@ -0,0 +1,194 @@ +// sherpa-onnx/c-api/c-api.h +// +// Copyright (c) 2023 Xiaomi Corporation + +// C API for sherpa-onnx +// +// Please refer to +// https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c +// for usages. +// + +#ifndef SHERPA_ONNX_C_API_C_API_H_ +#define SHERPA_ONNX_C_API_C_API_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Please refer to +/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +/// to download pre-trained models. That is, you can find encoder-xxx.onnx +/// decoder-xxx.onnx, joiner-xxx.onnx, and tokens.txt for this struct +/// from there. +typedef struct SherpaOnnxOnlineTransducerModelConfig { + const char *encoder; + const char *decoder; + const char *joiner; + const char *tokens; + int32_t num_threads; + int32_t debug; // true to print debug information of the model +} SherpaOnnxOnlineTransducerModelConfig; + +/// It expects 16 kHz 16-bit single channel wave format. +typedef struct SherpaOnnxFeatureConfig { + /// Sample rate of the input data. MUST match the one expected + /// by the model. For instance, it should be 16000 for models provided + /// by us. + int32_t sample_rate; + + /// Feature dimension of the model. + /// For instance, it should be 80 for models provided by us. + int32_t feature_dim; +} SherpaOnnxFeatureConfig; + +typedef struct SherpaOnnxOnlineRecognizerConfig { + SherpaOnnxFeatureConfig feat_config; + SherpaOnnxOnlineTransducerModelConfig model_config; + + /// 0 to disable endpoint detection. + /// A non-zero value to enable endpoint detection. + int32_t enable_endpoint; + + /// An endpoint is detected if trailing silence in seconds is larger than + /// this value even if nothing has been decoded. + /// Used only when enable_endpoint is not 0. + float rule1_min_trailing_silence; + + /// An endpoint is detected if trailing silence in seconds is larger than + /// this value after something that is not blank has been decoded. + /// Used only when enable_endpoint is not 0. + float rule2_min_trailing_silence; + + /// An endpoint is detected if the utterance in seconds is larger than + /// this value. + /// Used only when enable_endpoint is not 0. + float rule3_min_utterance_length; +} SherpaOnnxOnlineRecognizerConfig; + +typedef struct SherpaOnnxOnlineRecognizerResult { + const char *text; + // TODO(fangjun): Add more fields +} SherpaOnnxOnlineRecognizerResult; + +/// Note: OnlineRecognizer here means StreamingRecognizer. +/// It does not need to access the Internet during recognition. +/// Everything is run locally. +typedef struct SherpaOnnxOnlineRecognizer SherpaOnnxOnlineRecognizer; +typedef struct SherpaOnnxOnlineStream SherpaOnnxOnlineStream; + +/// @param config Config for the recongizer. +/// @return Return a pointer to the recognizer. The user has to invoke +// DestroyOnlineRecognizer() to free it to avoid memory leak. +SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( + const SherpaOnnxOnlineRecognizerConfig *config); + +/// Free a pointer returned by CreateOnlineRecognizer() +/// +/// @param p A pointer returned by CreateOnlineRecognizer() +void DestroyOnlineRecognizer(SherpaOnnxOnlineRecognizer *recognizer); + +/// Create an online stream for accepting wave samples. +/// +/// @param recognizer A pointer returned by CreateOnlineRecognizer() +/// @return Return a pointer to an OnlineStream. The user has to invoke +/// DestoryOnlineStream() to free it to avoid memory leak. +SherpaOnnxOnlineStream *CreateOnlineStream( + const SherpaOnnxOnlineRecognizer *recognizer); + +/// Destory an online stream. +/// +/// @param stream A pointer returned by CreateOnlineStream() +void DestoryOnlineStream(SherpaOnnxOnlineStream *stream); + +/// Accept input audio samples and compute the features. +/// The user has to invoke DecodeOnlineStream() to run the neural network and +/// decoding. +/// +/// @param stream A pointer returned by CreateOnlineStream(). +/// @param sample_rate Sampler rate of the input samples. It has to be 16 kHz +/// for models from icefall. +/// @param samples A pointer to a 1-D array containing audio samples. +/// The range of samples has to be normalized to [-1, 1]. +/// @param n Number of elements in the samples array. +void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate, + const float *samples, int32_t n); + +/// Return 1 if there are enough number of feature frames for decoding. +/// Return 0 otherwise. +/// +/// @param recognizer A pointer returned by CreateOnlineRecognizer +/// @param stream A pointer returned by CreateOnlineStream +int32_t IsOnlineStreamReady(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream); + +/// Call this function to run the neural network model and decoding. +// +/// Precondition for this function: IsOnlineStreamReady() MUST return 1. +/// +/// Usage example: +/// +/// while (IsOnlineStreamReady(recognizer, stream)) { +/// DecodeOnlineStream(recognizer, stream); +/// } +/// +void DecodeOnlineStream(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream); + +/// This function is similar to DecodeOnlineStream(). It decodes multiple +/// OnlineStream in parallel. +/// +/// Caution: The caller has to ensure each OnlineStream is ready, i.e., +/// IsOnlineStreamReady() for that stream should return 1. +/// +/// @param recognizer A pointer returned by CreateOnlineRecognizer() +/// @param streams A pointer array containing pointers returned by +/// CreateOnlineRecognizer() +/// @param n Number of elements in the given streams array. +void DecodeMultipleOnlineStreams(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream **streams, int32_t n); + +/// Get the decoding results so far for an OnlineStream. +/// +/// @param recognizer A pointer returned by CreateOnlineRecognizer(). +/// @param stream A pointer returned by CreateOnlineStream(). +/// @return A pointer containing the result. The user has to invoke +/// DestroyOnlineRecognizerResult() to free the returned pointer to +/// avoid memory leak. +SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( + SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream); + +/// Destroy the pointer returned by GetOnlineStreamResult(). +/// +/// @param r A pointer returned by GetOnlineStreamResult() +void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r); + +/// Reset an OnlineStream , which clears the neural network model state +/// and the state for decoding. +/// +/// @param recognizer A pointer returned by CreateOnlineRecognizer(). +/// @param stream A pointer returned by CreateOnlineStream +void Reset(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream); + +/// Signal that no more audio samples would be available. +/// After this call, you cannot call AcceptWaveform() any more. +/// +/// @param stream A pointer returned by CreateOnlineStream() +void InputFinished(SherpaOnnxOnlineStream *stream); + +/// Return 1 if an endpoint has been detected. +/// +/// @param recognizer A pointer returned by CreateOnlineRecognizer() +/// @param stream A pointer returned by CreateOnlineStream() +/// @return Return 1 if an endpoint is detected. Return 0 otherwise. +int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, + SherpaOnnxOnlineStream *stream); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif // SHERPA_ONNX_C_API_C_API_H_