Support resampling (#77)
This commit is contained in:
@@ -78,8 +78,6 @@ def get_args():
|
||||
|
||||
|
||||
def main():
|
||||
sample_rate = 16000
|
||||
|
||||
args = get_args()
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
@@ -95,12 +93,16 @@ def main():
|
||||
decoder=args.decoder,
|
||||
joiner=args.joiner,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=sample_rate,
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
decoding_method=args.decoding_method,
|
||||
)
|
||||
with wave.open(args.wave_filename) as f:
|
||||
assert f.getframerate() == sample_rate, f.getframerate()
|
||||
# If the wave file has a different sampling rate from the one
|
||||
# expected by the model (16 kHz in our case), we will do
|
||||
# resampling inside sherpa-onnx
|
||||
wave_file_sample_rate = f.getframerate()
|
||||
|
||||
assert f.getnchannels() == 1, f.getnchannels()
|
||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||
num_samples = f.getnframes()
|
||||
@@ -110,17 +112,17 @@ def main():
|
||||
|
||||
samples_float32 = samples_float32 / 32768
|
||||
|
||||
duration = len(samples_float32) / sample_rate
|
||||
duration = len(samples_float32) / wave_file_sample_rate
|
||||
|
||||
start_time = time.time()
|
||||
print("Started!")
|
||||
|
||||
stream = recognizer.create_stream()
|
||||
|
||||
stream.accept_waveform(sample_rate, samples_float32)
|
||||
stream.accept_waveform(wave_file_sample_rate, samples_float32)
|
||||
|
||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||
stream.accept_waveform(sample_rate, tail_paddings)
|
||||
tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32)
|
||||
stream.accept_waveform(wave_file_sample_rate, tail_paddings)
|
||||
|
||||
stream.input_finished()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user