1372 lines
74 KiB
Python
1372 lines
74 KiB
Python
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch Encodec model."""
|
|
|
|
import copy
|
|
import inspect
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from datasets import Audio, load_dataset
|
|
from parameterized import parameterized
|
|
|
|
from transformers import AutoProcessor, EncodecConfig
|
|
from transformers.testing_utils import (
|
|
is_torch_available,
|
|
require_torch,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import EncodecFeatureExtractor, EncodecModel
|
|
|
|
|
|
def prepare_inputs_dict(
    config,
    input_ids=None,
    input_values=None,
    decoder_input_ids=None,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """Assemble the keyword-argument dict for a model call.

    The encoder entry is `input_ids` when it is provided, otherwise `input_values`
    (even if that is `None`). `decoder_input_ids` is appended only when provided.
    `config` and the mask arguments are accepted but not used by this helper.
    """
    if input_ids is None:
        inputs = {"input_values": input_values}
    else:
        inputs = {"input_ids": input_ids}

    if decoder_input_ids is not None:
        inputs["decoder_input_ids"] = decoder_input_ids

    return inputs
|
|
|
|
|
|
@require_torch
class EncodecModelTester:
    """Factory for a small `EncodecConfig` and matching random inputs used by the model tests."""

    def __init__(
        self,
        parent,
        # `batch_size` needs to be an even number if the model has some outputs with batch dim != 0.
        batch_size=12,
        num_channels=2,
        is_training=False,
        intermediate_size=40,
        hidden_size=32,
        num_filters=8,
        num_residual_layers=1,
        upsampling_ratios=None,
        num_lstm_layers=1,
        codebook_size=64,
    ):
        """Store the test hyper-parameters.

        Args:
            parent: The `unittest.TestCase` whose assertion methods are used for checks.
            upsampling_ratios: Ratios forwarded to `EncodecConfig`; defaults to `[8, 4]`.
                A `None` sentinel is used so that instances never share one default list.

        The remaining arguments are stored as attributes and, apart from `is_training`,
        used to size the random inputs or forwarded to `EncodecConfig` by `get_config`.
        """
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.num_residual_layers = num_residual_layers
        self.upsampling_ratios = [8, 4] if upsampling_ratios is None else upsampling_ratios
        self.num_lstm_layers = num_lstm_layers
        self.codebook_size = codebook_size

    def prepare_config_and_inputs(self):
        """Return `(config, inputs_dict)` with random `input_values` of shape
        `(batch_size, num_channels, intermediate_size)`."""
        input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
        config = self.get_config()
        inputs_dict = {"input_values": input_values}
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        """Return the same `(config, inputs_dict)` pair as `prepare_config_and_inputs`."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def prepare_config_and_inputs_for_model_class(self, model_class):
        """Return `(config, inputs_dict)` extended with random `audio_codes` and `audio_scales`.

        `model_class` is accepted but not used.
        """
        config, inputs_dict = self.prepare_config_and_inputs()
        inputs_dict["audio_codes"] = ids_tensor([1, self.batch_size, 1, self.num_channels], self.codebook_size).type(
            torch.int32
        )
        inputs_dict["audio_scales"] = [None]

        return config, inputs_dict

    def prepare_config_and_inputs_for_normalization(self):
        """Return `(config, inputs_dict)` with `config.normalize = True` and the random
        inputs passed through `EncodecFeatureExtractor` (padded, as "pt" tensors on `torch_device`)."""
        input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
        config = self.get_config()
        config.normalize = True

        processor = EncodecFeatureExtractor(feature_size=config.audio_channels, sampling_rate=config.sampling_rate)
        # The feature extractor is fed nested Python lists rather than a tensor.
        input_values = input_values.tolist()
        inputs_dict = processor(
            input_values, sampling_rate=config.sampling_rate, padding=True, return_tensors="pt"
        ).to(torch_device)

        return config, inputs_dict

    def get_config(self):
        """Build the `EncodecConfig` from the stored hyper-parameters."""
        return EncodecConfig(
            audio_channels=self.num_channels,
            chunk_in_sec=None,
            hidden_size=self.hidden_size,
            num_filters=self.num_filters,
            num_residual_layers=self.num_residual_layers,
            upsampling_ratios=self.upsampling_ratios,
            num_lstm_layers=self.num_lstm_layers,
            codebook_size=self.codebook_size,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        """Run a forward pass in eval mode and check that `audio_values` has the input shape."""
        model = EncodecModel(config=config).to(torch_device).eval()
        result = model(**inputs_dict)
        self.parent.assertEqual(
            result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size)
        )
|
|
|
|
|
|
@require_torch
class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model test suite applied to `EncodecModel`.

    Tests that depend on attentions, hidden states or input embeddings are skipped with an
    explicit reason; several common tests are overridden below to fit the model's inputs/outputs.
    """

    all_model_classes = (EncodecModel,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False
    test_resize_embeddings = False
    pipeline_model_mapping = {"feature-extraction": EncodecModel} if is_torch_available() else {}

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Prepare inputs via the mixin, then drop the flags this model does not accept."""
        # model does not have attention and does not support returning hidden states
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
        inputs_dict.pop("output_attentions", None)
        inputs_dict.pop("output_hidden_states", None)
        return inputs_dict

    def setUp(self):
        self.model_tester = EncodecModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=EncodecConfig, hidden_size=37, common_properties=[], has_text_modality=False
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_forward_signature(self):
        """The first positional parameters of `forward` must keep their names and order."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        expected_arg_names = ["input_values", "padding_mask", "bandwidth"]
        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    @unittest.skip(reason="The EncodecModel is not transformers based, thus it does not have `inputs_embeds` logics")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="The EncodecModel is not transformers based, thus it does not have `inputs_embeds` logics")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
    )
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
    )
    def test_torchscript_output_attentions(self):
        pass

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `hidden_states` logic"
    )
    def test_torchscript_output_hidden_state(self):
        pass

    def _create_and_check_torchscript(self, config, inputs_dict):
        """Trace each model with TorchScript, round-trip it through disk and compare state.

        The traced model is saved and reloaded; its state dict (ignoring extra keys that are
        treated as non-persistent buffers) must match the eager model's keys and values, and
        every such extra buffer must equal one of the eager model's buffers.
        """
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to False")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class)

            main_input_name = model_class.main_input_name

            try:
                main_input = inputs[main_input_name]
                model(main_input)
                traced_model = torch.jit.trace(model, main_input)
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            # Keys only present in the reloaded traced model are treated as non-persistent buffers.
            non_persistent_buffers = {
                key: value for key, value in loaded_model_state_dict.items() if key not in model_state_dict
            }
            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            # Each extra buffer must match a distinct eager-model buffer; a matched buffer is
            # removed so that it cannot be matched twice. (This check used to be run twice in a
            # row on identical data; a single pass is sufficient.)
            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                if layer_name in loaded_model_state_dict:
                    p2 = loaded_model_state_dict[layer_name]
                    if p1.data.ne(p2.data).sum() > 0:
                        models_equal = False

            self.assertTrue(models_equal)

            # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
            # (Even with this call, there are still memory leak by ~0.04MB)
            self.clear_torch_jit_class_registry()

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
    )
    def test_attention_outputs(self):
        pass

    def test_feed_forward_chunking(self):
        """Outputs with chunking (`chunk_length_s=2`, `overlap=0`) must stay close to the unchunked ones."""
        (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            # Same seed before each instantiation so both models get identical weights.
            torch.manual_seed(0)
            config = copy.deepcopy(original_config)
            config.chunk_length_s = None
            config.overlap = None
            config.sampling_rate = 20

            model = model_class(config)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class)
            # Make the input 10x longer along the last dimension.
            inputs["input_values"] = inputs["input_values"].repeat(1, 1, 10)

            hidden_states_no_chunk = model(**inputs)[1]

            torch.manual_seed(0)
            config.chunk_length_s = 2
            config.overlap = 0
            config.sampling_rate = 20

            model = model_class(config)
            model.to(torch_device)
            model.eval()

            hidden_states_with_chunk = model(**inputs)[1]
            torch.testing.assert_close(hidden_states_no_chunk, hidden_states_with_chunk, rtol=1e-1, atol=1e-2)

    @unittest.skip(
        reason="The EncodecModel is not transformers based, thus it does not have the usual `hidden_states` logic"
    )
    def test_hidden_states_output(self):
        pass

    def test_determinism(self):
        """Two forward passes on the same inputs must agree to within 1e-5."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_determinism(first, second):
            # outputs are not tensors but list (since each sequence don't have the same frame_length)
            out_1 = first.cpu().numpy()
            out_2 = second.cpu().numpy()
            out_1 = out_1[~np.isnan(out_1)]
            out_2 = out_2[~np.isnan(out_2)]
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
                second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            if isinstance(first, tuple) and isinstance(second, tuple):
                for tensor1, tensor2 in zip(first, second):
                    check_determinism(tensor1, tensor2)
            else:
                check_determinism(first, second)

    def test_model_outputs_equivalence(self):
        """`return_dict=False` and `return_dict=True` must yield the same (nested) tensors."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # NaN is the only value that differs from itself; zeroing happens in place.
            t[t != t] = 0
            return t

        def assert_nested_tensors_close(a, b):
            if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
                assert len(a) == len(b), f"Length mismatch: {len(a)} vs {len(b)}"
                for x, y in zip(a, b):
                    assert_nested_tensors_close(x, y)
            elif torch.is_tensor(a) and torch.is_tensor(b):
                # Record the NaN flags before the in-place zeroing below, otherwise the
                # failure message would always report "nan: False".
                a_has_nan = torch.isnan(a).any()
                b_has_nan = torch.isnan(b).any()
                a_clean = set_nan_tensor_to_zero(a)
                b_clean = set_nan_tensor_to_zero(b)
                assert torch.allclose(a_clean, b_clean, atol=1e-5), (
                    "Tuple and dict output are not equal. Difference:"
                    f" Max diff: {torch.max(torch.abs(a_clean - b_clean))}. "
                    f"Tuple has nan: {a_has_nan} and inf: {torch.isinf(a).any()}. "
                    f"Dict has nan: {b_has_nan} and inf: {torch.isinf(b).any()}."
                )
            else:
                raise ValueError(f"Mismatch between {a} vs {b}")

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            additional_kwargs = {} if additional_kwargs is None else additional_kwargs
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)

                self.assertTrue(isinstance(tuple_output, tuple))
                self.assertTrue(isinstance(dict_output, dict))
                # cast dict_output.values() to list as it is a odict_values object
                assert_nested_tensors_close(tuple_output, list(dict_output.values()))

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

    def test_initialization(self):
        """Check parameter means after building the model from a zero-init config.

        Parameters whose name contains "conv" only need a mean within [-1, 1]; parameters whose
        name contains "lstm" are not checked; every other trainable parameter must have a mean
        of exactly 0.0 or 1.0 (after rounding to 9 decimals).
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        uniform_init_parms = ["conv"]
        ignore_init = ["lstm"]

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    elif not any(x in name for x in ignore_init):
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def test_identity_shortcut(self):
        """Forward pass must still work with `use_conv_shortcut` disabled."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        config.use_conv_shortcut = False
        self.model_tester.create_and_check_model_forward(config, inputs_dict)

    def test_model_forward_with_normalization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_normalization()
        self.model_tester.create_and_check_model_forward(config, inputs_dict)
|
|
|
|
|
|
def normalize(arr):
    """Return `arr` divided by `np.linalg.norm(arr)` (the norm is taken with default arguments)."""
    return arr / np.linalg.norm(arr)
|
|
|
|
|
|
def compute_rmse(arr1, arr2):
    """Root-mean-square error between two tensors after truncation and normalization.

    Both inputs are moved to CPU, converted to NumPy and squeezed. They are then cut along the
    last axis to the shorter of the two original last-dimension sizes, each passed through
    `normalize`, and the RMSE of the element-wise difference is returned.
    """
    first = arr1.cpu().numpy().squeeze()
    second = arr2.cpu().numpy().squeeze()

    # Length is taken from the un-squeezed inputs, as in the original implementation.
    common_length = min(arr1.shape[-1], arr2.shape[-1])

    first = normalize(first[..., :common_length])
    second = normalize(second[..., :common_length])

    squared_error = (first - second) ** 2
    return np.sqrt(squared_error.mean())
|
|
|
|
|
|
"""
|
|
Integration tests for the Encodec model.
|
|
|
|
Code for expected output can be found below:
|
|
- test_integration: https://gist.github.com/ebezzam/2a34e249e729881130d1f5a42229d31f#file-test_encodec-py
|
|
- test_batch: https://gist.github.com/ebezzam/2a34e249e729881130d1f5a42229d31f#file-test_encodec_batch-py
|
|
|
|
"""
|
|
|
|
# fmt: off
|
|
# first key is model_id from hub, second key is bandwidth
|
|
# -- test_integration
|
|
# Code rows for the highest bandwidth ("24.0") of each checkpoint. For every lower bandwidth the
# expected table is exactly the first N rows of this one, so the per-bandwidth tensors below are
# built by slicing instead of repeating the rows.
_ENCODER_CODE_ROWS_24KHZ = [
    [62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
    [1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
    [786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
    [741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
    [528, 446, 198, 190, 446, 622, 646, 448, 646, 448],
    [1011, 140, 185, 986, 683, 986, 435, 41, 140, 939],
    [896, 772, 562, 772, 485, 528, 896, 853, 562, 772],
    [899, 975, 468, 468, 468, 701, 1013, 828, 518, 899],
    [827, 807, 938, 320, 699, 470, 909, 628, 301, 827],
    [963, 801, 630, 477, 717, 354, 205, 359, 874, 744],
    [1000, 1000, 388, 1000, 408, 740, 568, 364, 709, 843],
    [413, 835, 382, 840, 742, 1019, 375, 962, 835, 742],
    [971, 410, 998, 485, 798, 410, 351, 485, 485, 920],
    [848, 694, 662, 784, 848, 427, 1022, 848, 920, 694],
    [420, 911, 889, 911, 993, 776, 948, 477, 911, 911],
    [587, 755, 834, 962, 860, 425, 982, 982, 425, 461],
    [270, 160, 26, 131, 597, 506, 670, 637, 248, 160],
    [15, 215, 134, 69, 215, 155, 1012, 1009, 260, 417],
    [580, 561, 686, 896, 497, 637, 580, 245, 896, 264],
    [511, 239, 560, 691, 571, 627, 571, 571, 258, 619],
    [591, 942, 591, 251, 250, 250, 857, 486, 295, 295],
    [565, 546, 654, 301, 301, 623, 639, 568, 565, 282],
    [539, 317, 639, 539, 651, 539, 538, 640, 615, 615],
    [637, 556, 637, 582, 640, 515, 515, 632, 254, 613],
    [305, 643, 500, 550, 522, 500, 550, 561, 522, 305],
    [954, 456, 584, 755, 505, 782, 661, 671, 497, 505],
    [577, 464, 637, 647, 552, 552, 624, 647, 624, 647],
    [728, 748, 931, 608, 538, 1015, 294, 294, 666, 538],
    [602, 535, 666, 665, 655, 979, 574, 535, 571, 781],
    [321, 620, 557, 566, 511, 910, 672, 623, 853, 674],
    [621, 556, 947, 474, 610, 752, 1002, 597, 474, 474],
    [605, 948, 657, 588, 485, 633, 459, 968, 939, 325],
]
_ENCODER_CODE_ROWS_48KHZ = [
    [214, 214, 214, 214, 214, 118, 214, 214, 214, 214],
    [989, 989, 611, 77, 77, 989, 976, 976, 976, 77],
    [977, 1009, 538, 925, 925, 977, 1022, 1022, 1022, 925],
    [376, 1012, 1023, 725, 725, 1023, 376, 962, 376, 847],
    [979, 1012, 323, 695, 1018, 1023, 979, 1023, 979, 650],
    [945, 762, 528, 865, 824, 945, 945, 945, 957, 957],
    [904, 973, 1014, 681, 582, 1014, 1014, 1014, 1014, 681],
    [229, 392, 796, 392, 977, 1017, 250, 1017, 250, 1017],
    [902, 436, 935, 1011, 1023, 1023, 1023, 154, 1023, 392],
    [982, 878, 961, 832, 629, 431, 919, 629, 919, 792],
    [727, 727, 401, 727, 979, 587, 727, 487, 413, 201],
    [928, 924, 965, 934, 840, 480, 924, 920, 924, 486],
    [10, 625, 712, 552, 712, 259, 394, 131, 726, 516],
    [882, 1022, 32, 524, 267, 861, 974, 882, 108, 521],
    [304, 841, 306, 415, 69, 376, 928, 510, 381, 104],
    [0, 0, 0, 484, 83, 0, 307, 262, 0, 0],
]
# Each value is an integer tensor of shape (1, num_rows, 10).
EXPECTED_ENCODER_CODES = {
    "facebook/encodec_24khz": {
        bandwidth: torch.tensor([_ENCODER_CODE_ROWS_24KHZ[:num_rows]])
        for bandwidth, num_rows in (("1.5", 2), ("3.0", 4), ("6.0", 8), ("12.0", 16), ("24.0", 32))
    },
    "facebook/encodec_48khz": {
        bandwidth: torch.tensor([_ENCODER_CODE_ROWS_48KHZ[:num_rows]])
        for bandwidth, num_rows in (("3.0", 2), ("6.0", 4), ("12.0", 8), ("24.0", 16))
    },
}
|
|
# Expected scales: `None` for every 24 kHz bandwidth, and the same 7-element tensor for every
# 48 kHz bandwidth (a fresh tensor object is created per key).
EXPECTED_ENCODER_SCALES = {
    "facebook/encodec_24khz": dict.fromkeys(("1.5", "3.0", "6.0", "12.0", "24.0")),
    "facebook/encodec_48khz": {
        bandwidth: torch.tensor(
            [5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02, 4.483359e-02, 1.000000e-08]
        )
        for bandwidth in ("3.0", "6.0", "12.0", "24.0")
    },
}
|
|
# Expected decoder output values, keyed by model id then bandwidth.
# The 24 kHz entries hold a single row, the 48 kHz entries hold two rows
# (presumably one row per audio channel -- confirm against the generating script).
EXPECTED_DECODER_OUTPUTS = {
    "facebook/encodec_24khz": {
        "1.5": torch.tensor([
            [0.0003, -0.0002, -0.0000, -0.0004, 0.0004, 0.0003, -0.0000, 0.0001, 0.0005, 0.0001, -0.0015, -0.0007, -0.0002, -0.0018, -0.0003, 0.0013, 0.0011, 0.0008, 0.0008, 0.0008, 0.0008, 0.0002, -0.0003, -0.0004, -0.0006, -0.0009, -0.0010, -0.0012, -0.0011, -0.0006, -0.0006, -0.0005, 0.0000, 0.0001, 0.0003, 0.0002, -0.0001, -0.0002, -0.0008, -0.0012, -0.0011, -0.0012, -0.0013, -0.0003, 0.0002, 0.0006, 0.0006, 0.0006, 0.0009, 0.0010],
        ]),
        "3.0": torch.tensor([
            [0.0003, -0.0002, -0.0000, -0.0004, 0.0004, 0.0003, -0.0000, 0.0001, 0.0006, 0.0002, -0.0015, -0.0008, -0.0002, -0.0018, -0.0003, 0.0013, 0.0011, 0.0008, 0.0008, 0.0008, 0.0008, 0.0002, -0.0003, -0.0004, -0.0005, -0.0008, -0.0010, -0.0012, -0.0011, -0.0006, -0.0006, -0.0005, -0.0000, 0.0001, 0.0003, 0.0002, -0.0001, -0.0002, -0.0008, -0.0013, -0.0011, -0.0013, -0.0014, -0.0003, 0.0002, 0.0006, 0.0006, 0.0006, 0.0009, 0.0010],
        ]),
        "6.0": torch.tensor([
            [0.0004, -0.0001, 0.0001, -0.0003, 0.0004, 0.0003, 0.0000, 0.0001, 0.0007, 0.0002, -0.0013, -0.0007, -0.0002, -0.0015, -0.0001, 0.0014, 0.0014, 0.0011, 0.0010, 0.0010, 0.0009, 0.0004, 0.0000, 0.0000, 0.0000, -0.0000, -0.0001, -0.0004, -0.0004, -0.0001, -0.0002, -0.0002, 0.0002, 0.0005, 0.0009, 0.0010, 0.0008, 0.0007, 0.0002, -0.0003, -0.0004, -0.0008, -0.0008, 0.0000, 0.0006, 0.0010, 0.0012, 0.0012, 0.0013, 0.0014],
        ]),
        "12.0": torch.tensor([
            [0.0004, -0.0001, 0.0001, -0.0004, 0.0003, 0.0002, -0.0000, 0.0001, 0.0006, 0.0002, -0.0013, -0.0006, -0.0001, -0.0014, 0.0001, 0.0018, 0.0018, 0.0014, 0.0012, 0.0013, 0.0011, 0.0006, 0.0000, 0.0000, -0.0000, -0.0001, -0.0001, -0.0004, -0.0004, -0.0000, -0.0000, -0.0000, 0.0005, 0.0007, 0.0011, 0.0011, 0.0009, 0.0007, 0.0002, -0.0003, -0.0004, -0.0007, -0.0007, 0.0002, 0.0009, 0.0013, 0.0015, 0.0014, 0.0015, 0.0016],
        ]),
        "24.0": torch.tensor([
            [0.0005, 0.0001, 0.0004, -0.0001, 0.0003, 0.0002, 0.0000, 0.0001, 0.0007, 0.0005, -0.0011, -0.0005, -0.0001, -0.0018, -0.0000, 0.0021, 0.0019, 0.0013, 0.0011, 0.0012, 0.0012, 0.0006, -0.0000, -0.0001, -0.0000, -0.0000, -0.0001, -0.0004, -0.0004, -0.0000, -0.0001, -0.0002, 0.0003, 0.0004, 0.0008, 0.0007, 0.0006, 0.0007, 0.0001, -0.0004, -0.0003, -0.0006, -0.0008, 0.0004, 0.0011, 0.0015, 0.0016, 0.0015, 0.0016, 0.0018],
        ]),
    },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor([
            [0.0034, 0.0028, 0.0037, 0.0041, 0.0029, 0.0022, 0.0021, 0.0020, 0.0021, 0.0023, 0.0021, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0021, 0.0023, 0.0025, 0.0022, 0.0017, 0.0015, 0.0017, 0.0020, 0.0024, 0.0031, 0.0039, 0.0045, 0.0046, 0.0042, 0.0034, 0.0027, 0.0023, 0.0022, 0.0023, 0.0024, 0.0022, 0.0023, 0.0024, 0.0027, 0.0027, 0.0027, 0.0025, 0.0024, 0.0024, 0.0026, 0.0028, 0.0027, 0.0024, 0.0022],
            [-0.0031, -0.0027, -0.0018, -0.0017, -0.0024, -0.0029, -0.0030, -0.0026, -0.0021, -0.0018, -0.0018, -0.0019, -0.0017, -0.0014, -0.0012, -0.0010, -0.0008, -0.0004, -0.0001, -0.0004, -0.0012, -0.0015, -0.0014, -0.0013, -0.0011, -0.0005, 0.0002, 0.0007, 0.0008, 0.0004, -0.0003, -0.0010, -0.0012, -0.0011, -0.0009, -0.0009, -0.0009, -0.0008, -0.0006, -0.0005, -0.0005, -0.0005, -0.0006, -0.0008, -0.0008, -0.0006, -0.0005, -0.0007, -0.0010, -0.0012],
        ]),
        "6.0": torch.tensor([
            [0.0052, 0.0049, 0.0057, 0.0058, 0.0048, 0.0043, 0.0042, 0.0041, 0.0041, 0.0042, 0.0040, 0.0038, 0.0038, 0.0038, 0.0037, 0.0037, 0.0037, 0.0037, 0.0038, 0.0037, 0.0035, 0.0034, 0.0036, 0.0039, 0.0043, 0.0047, 0.0053, 0.0057, 0.0057, 0.0055, 0.0050, 0.0046, 0.0043, 0.0041, 0.0042, 0.0042, 0.0041, 0.0041, 0.0042, 0.0043, 0.0043, 0.0043, 0.0041, 0.0040, 0.0040, 0.0041, 0.0042, 0.0042, 0.0040, 0.0039],
            [0.0001, 0.0006, 0.0013, 0.0011, 0.0005, 0.0001, -0.0001, 0.0001, 0.0003, 0.0005, 0.0005, 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.0010, 0.0013, 0.0015, 0.0014, 0.0010, 0.0008, 0.0010, 0.0012, 0.0015, 0.0019, 0.0023, 0.0026, 0.0026, 0.0024, 0.0020, 0.0016, 0.0013, 0.0013, 0.0014, 0.0015, 0.0015, 0.0016, 0.0017, 0.0017, 0.0017, 0.0016, 0.0015, 0.0013, 0.0013, 0.0013, 0.0013, 0.0012, 0.0010, 0.0009],
        ]),
        "12.0": torch.tensor([
            [0.0014, 0.0012, 0.0021, 0.0024, 0.0017, 0.0013, 0.0012, 0.0011, 0.0011, 0.0012, 0.0011, 0.0010, 0.0009, 0.0009, 0.0008, 0.0008, 0.0009, 0.0010, 0.0012, 0.0012, 0.0009, 0.0008, 0.0010, 0.0013, 0.0017, 0.0024, 0.0031, 0.0036, 0.0036, 0.0033, 0.0028, 0.0023, 0.0020, 0.0020, 0.0022, 0.0022, 0.0022, 0.0022, 0.0023, 0.0024, 0.0024, 0.0023, 0.0021, 0.0021, 0.0021, 0.0023, 0.0024, 0.0024, 0.0022, 0.0021],
            [-0.0034, -0.0029, -0.0020, -0.0020, -0.0024, -0.0027, -0.0030, -0.0030, -0.0028, -0.0025, -0.0025, -0.0025, -0.0025, -0.0025, -0.0023, -0.0022, -0.0020, -0.0017, -0.0013, -0.0014, -0.0017, -0.0019, -0.0018, -0.0015, -0.0011, -0.0006, 0.0000, 0.0005, 0.0005, 0.0002, -0.0003, -0.0008, -0.0010, -0.0009, -0.0007, -0.0006, -0.0006, -0.0005, -0.0005, -0.0005, -0.0005, -0.0007, -0.0008, -0.0009, -0.0009, -0.0008, -0.0007, -0.0008, -0.0010, -0.0011],
        ]),
        "24.0": torch.tensor([
            [0.0010, 0.0008, 0.0018, 0.0021, 0.0014, 0.0011, 0.0009, 0.0007, 0.0006, 0.0006, 0.0005, 0.0003, 0.0003, 0.0002, 0.0001, 0.0001, 0.0001, 0.0002, 0.0002, 0.0001, -0.0002, -0.0004, -0.0003, 0.0000, 0.0005, 0.0011, 0.0018, 0.0022, 0.0022, 0.0018, 0.0012, 0.0007, 0.0004, 0.0003, 0.0004, 0.0006, 0.0006, 0.0007, 0.0007, 0.0009, 0.0008, 0.0007, 0.0005, 0.0004, 0.0004, 0.0006, 0.0007, 0.0007, 0.0005, 0.0004],
            [-0.0039, -0.0035, -0.0027, -0.0026, -0.0028, -0.0031, -0.0035, -0.0035, -0.0034, -0.0033, -0.0032, -0.0032, -0.0031, -0.0031, -0.0029, -0.0028, -0.0026, -0.0024, -0.0021, -0.0021, -0.0024, -0.0025, -0.0024, -0.0021, -0.0017, -0.0011, -0.0006, -0.0002, -0.0002, -0.0004, -0.0009, -0.0013, -0.0015, -0.0015, -0.0014, -0.0013, -0.0012, -0.0011, -0.0010, -0.0010, -0.0011, -0.0012, -0.0014, -0.0015, -0.0015, -0.0014, -0.0013, -0.0014, -0.0016, -0.0017],
        ]),
    },
}
|
|
# Expected codec error per model id and bandwidth.
# NOTE(review): the 24 kHz "12.0" and "24.0" entries are stored with fewer significant digits
# than the others; presumably they were rounded by hand -- confirm against the generating script.
EXPECTED_CODEC_ERROR = {
    "facebook/encodec_24khz": {
        "1.5": 0.0022229827009141445,
        "3.0": 0.001862662611529231,
        "6.0": 0.0015231302240863442,
        "12.0": 0.0013,
        "24.0": 0.0012,
    },
    "facebook/encodec_48khz": {
        "3.0": 0.000840399123262614,
        "6.0": 0.0006692984024994075,
        "12.0": 0.0005328940460458398,
        "24.0": 0.0004473362350836396,
    },
}
|
|
# -- test_batch
|
|
EXPECTED_ENCODER_CODES_BATCH = {
|
|
"facebook/encodec_24khz": {
|
|
"1.5": torch.tensor(
|
|
[
|
|
[
|
|
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475, 475, 404, 475, 475, 475, 835, 475, 475, 835, 835,
|
|
106, 106, 738, 106, 738, 106, 408, 408, 738, 408, 408, 408, 738, 408, 408, 408, 408, 738, 408,
|
|
1017, 604, 64, 303, 394, 5, 570, 991, 570, 969, 814],
|
|
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007, 729, 1007, 961, 1007, 1007, 961, 969, 1007,
|
|
1007, 424, 518, 1007, 544, 1007, 518, 913, 424, 424, 544, 424, 518, 518, 518, 302, 424, 424, 424,
|
|
544, 424, 114, 200, 787, 931, 343, 434, 315, 487, 872, 769, 463],
|
|
|
|
],
|
|
[
|
|
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 408, 835, 738, 408, 408, 408, 408, 408,
|
|
408, 738, 408, 408, 408, 408, 408, 408, 408, 408, 738, 408, 408, 408, 408, 408, 408, 408, 408, 408,
|
|
339, 834, 819, 875, 957, 670, 811, 670, 237, 53],
|
|
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913, 518, 913, 518, 518, 544, 424, 424, 518, 424, 424,
|
|
424, 544, 424, 424, 424, 518, 424, 518, 518, 937, 544, 424, 518, 302, 518, 424, 424, 518, 424, 424,
|
|
913, 857, 841, 363, 463, 78, 176, 645, 255, 571],
|
|
|
|
],
|
|
|
|
]
|
|
|
|
),
|
|
"3.0": torch.tensor(
|
|
[
|
|
[
|
|
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
|
|
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
|
|
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
|
|
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
|
|
|
|
],
|
|
[
|
|
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
|
|
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
|
|
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
|
|
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
|
|
|
|
],
|
|
]
|
|
),
|
|
"6.0": torch.tensor(
|
|
[
|
|
[
|
|
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
|
|
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
|
|
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
|
|
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
|
|
[712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
|
|
[939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
|
|
[853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
|
|
[899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
|
|
|
|
],
|
|
[
|
|
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
|
|
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
|
|
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
|
|
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
|
|
[373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
|
|
[293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
|
|
[800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
|
|
[916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
|
|
|
|
],
|
|
]
|
|
),
|
|
"12.0": torch.tensor(
|
|
[
|
|
[
|
|
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
|
|
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
|
|
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
|
|
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
|
|
[712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
|
|
[939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
|
|
[853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
|
|
[899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
|
|
[817, 470, 588, 675, 675, 588, 960, 927, 909, 466],
|
|
[953, 776, 717, 630, 359, 717, 861, 630, 861, 359],
|
|
[623, 740, 1000, 388, 420, 388, 740, 818, 958, 743],
|
|
[413, 835, 742, 249, 892, 352, 190, 498, 866, 890],
|
|
[817, 351, 804, 751, 938, 535, 434, 879, 351, 971],
|
|
[792, 495, 935, 848, 792, 795, 942, 935, 723, 531],
|
|
[622, 681, 477, 713, 752, 871, 713, 514, 993, 777],
|
|
[928, 799, 962, 1005, 860, 439, 312, 922, 982, 922],
|
|
],
|
|
[
|
|
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
|
|
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
|
|
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
|
|
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
|
|
[373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
|
|
[293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
|
|
[800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
|
|
[916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
|
|
[420, 628, 918, 628, 628, 628, 248, 628, 909, 811],
|
|
[736, 717, 994, 974, 477, 874, 963, 979, 355, 979],
|
|
[1002, 1002, 894, 875, 388, 709, 534, 408, 881, 709],
|
|
[735, 828, 763, 742, 640, 835, 828, 375, 840, 375],
|
|
[898, 938, 556, 658, 410, 951, 486, 658, 877, 877],
|
|
[ 0, 797, 428, 694, 428, 920, 1022, 1022, 809, 797],
|
|
[622, 421, 422, 776, 911, 911, 958, 421, 776, 421],
|
|
[1005, 312, 922, 755, 834, 461, 461, 702, 597, 974],
|
|
|
|
],
|
|
]
|
|
),
|
|
"24.0": torch.tensor(
|
|
[
|
|
[
|
|
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
|
|
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
|
|
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
|
|
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
|
|
[712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
|
|
[939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
|
|
[853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
|
|
[899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
|
|
[817, 470, 588, 675, 675, 588, 960, 927, 909, 466],
|
|
[953, 776, 717, 630, 359, 717, 861, 630, 861, 359],
|
|
[623, 740, 1000, 388, 420, 388, 740, 818, 958, 743],
|
|
[413, 835, 742, 249, 892, 352, 190, 498, 866, 890],
|
|
[817, 351, 804, 751, 938, 535, 434, 879, 351, 971],
|
|
[792, 495, 935, 848, 792, 795, 942, 935, 723, 531],
|
|
[622, 681, 477, 713, 752, 871, 713, 514, 993, 777],
|
|
[928, 799, 962, 1005, 860, 439, 312, 922, 982, 922],
|
|
[939, 637, 861, 506, 861, 61, 475, 264, 1019, 260],
|
|
[166, 215, 69, 69, 890, 69, 284, 828, 396, 180],
|
|
[561, 896, 841, 144, 580, 659, 886, 514, 686, 451],
|
|
[691, 691, 239, 735, 62, 287, 383, 972, 550, 505],
|
|
[451, 811, 238, 251, 250, 841, 734, 329, 551, 846],
|
|
[313, 601, 494, 763, 811, 565, 748, 441, 601, 480],
|
|
[653, 242, 630, 572, 701, 973, 632, 374, 561, 521],
|
|
[984, 987, 419, 454, 386, 507, 532, 636, 515, 671],
|
|
[647, 550, 515, 292, 876, 1011, 719, 549, 691, 911],
|
|
[683, 536, 656, 603, 698, 867, 987, 857, 886, 491],
|
|
[444, 937, 826, 555, 585, 710, 466, 852, 655, 591],
|
|
[658, 952, 903, 508, 739, 596, 420, 721, 464, 306],
|
|
[665, 334, 765, 532, 618, 278, 836, 838, 517, 597],
|
|
[613, 674, 596, 904, 987, 977, 938, 615, 672, 776],
|
|
[689, 386, 749, 658, 250, 869, 957, 806, 750, 659],
|
|
[652, 509, 910, 826, 566, 622, 951, 696, 900, 895],
|
|
],
|
|
[
|
|
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
|
|
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
|
|
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
|
|
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
|
|
[373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
|
|
[293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
|
|
[800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
|
|
[916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
|
|
[420, 628, 918, 628, 628, 628, 248, 628, 909, 811],
|
|
[736, 717, 994, 974, 477, 874, 963, 979, 355, 979],
|
|
[1002, 1002, 894, 875, 388, 709, 534, 408, 881, 709],
|
|
[735, 828, 763, 742, 640, 835, 828, 375, 840, 375],
|
|
[898, 938, 556, 658, 410, 951, 486, 658, 877, 877],
|
|
[ 0, 797, 428, 694, 428, 920, 1022, 1022, 809, 797],
|
|
[622, 421, 422, 776, 911, 911, 958, 421, 776, 421],
|
|
[1005, 312, 922, 755, 834, 461, 461, 702, 597, 974],
|
|
[248, 248, 637, 248, 977, 506, 546, 270, 670, 506],
|
|
[547, 447, 15, 134, 1009, 215, 134, 396, 260, 160],
|
|
[635, 497, 686, 765, 264, 497, 244, 675, 624, 656],
|
|
[864, 571, 616, 511, 588, 781, 525, 258, 674, 503],
|
|
[449, 757, 857, 451, 658, 486, 299, 299, 251, 596],
|
|
[809, 628, 255, 568, 623, 301, 639, 546, 617, 623],
|
|
[551, 497, 908, 539, 661, 710, 640, 539, 646, 315],
|
|
[689, 507, 875, 515, 613, 637, 527, 515, 662, 637],
|
|
[983, 686, 456, 768, 601, 561, 768, 653, 500, 688],
|
|
[493, 566, 664, 782, 683, 683, 721, 603, 323, 497],
|
|
[1015, 552, 411, 423, 607, 646, 687, 1018, 689, 607],
|
|
[516, 293, 471, 294, 293, 294, 608, 538, 803, 717],
|
|
[974, 994, 952, 637, 637, 927, 535, 571, 602, 535],
|
|
[776, 789, 476, 944, 652, 959, 589, 679, 321, 623],
|
|
[776, 931, 720, 1009, 676, 731, 386, 676, 701, 676],
|
|
[684, 543, 716, 392, 661, 517, 792, 588, 922, 676],
|
|
],
|
|
]
|
|
)
|
|
},
|
|
"facebook/encodec_48khz": {
|
|
"3.0": torch.tensor([[[790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
|
|
[989, 989, 77, 546, 989, 546, 989, 160, 546, 989]],
|
|
[[214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
|
|
[289, 289, 989, 764, 289, 289, 882, 882, 882, 882]]]),
|
|
"6.0": torch.tensor([[[ 790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
|
|
[ 989, 989, 77, 546, 989, 546, 989, 160, 546, 989],
|
|
[ 977, 977, 977, 977, 538, 977, 977, 960, 977, 977],
|
|
[ 376, 376, 962, 962, 607, 962, 963, 896, 962, 376]],
|
|
[[ 214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
|
|
[ 289, 289, 989, 764, 289, 289, 882, 882, 882, 882],
|
|
[1022, 1022, 471, 925, 821, 821, 267, 925, 925, 267],
|
|
[ 979, 992, 914, 921, 0, 0, 1023, 963, 963, 1023]]]),
|
|
"12.0": torch.tensor([[[ 790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
|
|
[ 989, 989, 77, 546, 989, 546, 989, 160, 546, 989],
|
|
[ 977, 977, 977, 977, 538, 977, 977, 960, 977, 977],
|
|
[ 376, 376, 962, 962, 607, 962, 963, 896, 962, 376],
|
|
[ 979, 979, 979, 1012, 979, 1012, 921, 0, 1002, 695],
|
|
[ 824, 1018, 762, 957, 824, 762, 762, 1007, 957, 336],
|
|
[ 681, 973, 973, 452, 211, 681, 802, 679, 547, 884],
|
|
[ 950, 1017, 1016, 1017, 986, 1017, 229, 607, 1017, 689]],
|
|
|
|
[[ 214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
|
|
[ 289, 289, 989, 764, 289, 289, 882, 882, 882, 882],
|
|
[1022, 1022, 471, 925, 821, 821, 267, 925, 925, 267],
|
|
[ 979, 992, 914, 921, 0, 0, 1023, 963, 963, 1023],
|
|
[ 403, 940, 976, 1018, 677, 1002, 979, 677, 677, 677],
|
|
[1018, 794, 762, 444, 485, 485, 974, 548, 548, 1018],
|
|
[ 679, 243, 679, 1005, 1005, 973, 1014, 1005, 1005, 1014],
|
|
[ 810, 13, 1017, 537, 522, 702, 202, 1017, 1017, 15]]]),
|
|
"24.0": torch.tensor(
|
|
[
|
|
[
|
|
[790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
|
|
[989, 989, 77, 546, 989, 546, 989, 160, 546, 989],
|
|
[977, 977, 977, 977, 538, 977, 977, 960, 977, 977],
|
|
[376, 376, 962, 962, 607, 962, 963, 896, 962, 376],
|
|
[979, 979, 979, 1012, 979, 1012, 921, 0, 1002, 695],
|
|
[824, 1018, 762, 957, 824, 762, 762, 1007, 957, 336],
|
|
[681, 973, 973, 452, 211, 681, 802, 679, 547, 884],
|
|
[950, 1017, 1016, 1017, 986, 1017, 229, 607, 1017, 689],
|
|
[1004, 1011, 669, 1023, 1023, 1023, 905, 297, 810, 970],
|
|
[982, 681, 982, 629, 662, 919, 878, 476, 629, 982],
|
|
[727, 727, 959, 959, 979, 959, 530, 959, 337, 961],
|
|
[924, 456, 924, 486, 924, 959, 102, 924, 805, 924],
|
|
[649, 542, 993, 993, 949, 787, 56, 886, 949, 405],
|
|
[864, 1022, 1022, 1022, 460, 753, 805, 309, 1022, 32],
|
|
[953, 0, 0, 180, 352, 10, 581, 516, 322, 452],
|
|
[300, 0, 1020, 307, 0, 543, 924, 627, 258, 262],
|
|
],
|
|
[
|
|
[214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
|
|
[289, 289, 989, 764, 289, 289, 882, 882, 882, 882],
|
|
[1022, 1022, 471, 925, 821, 821, 267, 925, 925, 267],
|
|
[979, 992, 914, 921, 0, 0, 1023, 963, 963, 1023],
|
|
[403, 940, 976, 1018, 677, 1002, 979, 677, 677, 677],
|
|
[1018, 794, 762, 444, 485, 485, 974, 548, 548, 1018],
|
|
[679, 243, 679, 1005, 1005, 973, 1014, 1005, 1005, 1014],
|
|
[810, 13, 1017, 537, 522, 702, 202, 1017, 1017, 15],
|
|
[728, 252, 970, 984, 971, 950, 673, 902, 1011, 810],
|
|
[332, 1014, 476, 854, 1014, 861, 332, 411, 411, 408],
|
|
[959, 727, 611, 979, 611, 727, 999, 497, 821, 0],
|
|
[995, 698, 924, 688, 102, 510, 924, 970, 344, 961],
|
|
[ 81, 516, 847, 924, 10, 240, 1005, 726, 993, 378],
|
|
[467, 496, 484, 496, 456, 1022, 337, 600, 456, 1022],
|
|
[789, 65, 937, 976, 159, 953, 343, 764, 179, 159],
|
|
[ 10, 790, 483, 10, 1020, 352, 848, 333, 83, 848],
|
|
],
|
|
]
|
|
)
|
|
}
|
|
}
|
|
# Expected encoder scales for the batched integration test, keyed first by hub model id and
# then by bandwidth string. A `None` entry means the scales are not compared for that pair.
EXPECTED_ENCODER_SCALES_BATCH = {
    "facebook/encodec_24khz": {bandwidth: None for bandwidth in ("1.5", "3.0", "6.0", "12.0", "24.0")},
    # The reference values are literally identical at every 48 kHz bandwidth, so the literal is
    # written once; the comprehension still builds a separate tensor object for each key.
    "facebook/encodec_48khz": {
        bandwidth: torch.tensor(
            [
                [[1.027247e-01], [7.877284e-02]],
                [[1.014922e-01], [8.696266e-02]],
                [[6.308002e-02], [7.748771e-02]],
                [[6.899278e-02], [1.045912e-01]],
                [[6.440169e-02], [8.843135e-02]],
                [[4.139878e-02], [1.000000e-08]],
                [[5.848629e-02], [1.000000e-08]],
                [[2.329416e-04], [1.000000e-08]],
                [[1.000000e-08], [1.000000e-08]],
            ]
        )
        for bandwidth in ("3.0", "6.0", "12.0", "24.0")
    },
}
|
|
# Expected decoder outputs for the batched integration test, keyed first by hub model id and
# then by bandwidth string. Each tensor is nested as (2, channels, 50): the 24 kHz entries
# have one channel row per item, the 48 kHz entries have two.
# NOTE(review): presumably these are the leading samples of the decoded audio for each batch
# item; the comparison that consumes them lies outside this excerpt — confirm.
EXPECTED_DECODER_OUTPUTS_BATCH = {
    "facebook/encodec_24khz": {
        "1.5": torch.tensor(
            [
                [[0.0010, 0.0004, 0.0005, 0.0002, 0.0005, -0.0001, -0.0003, -0.0001, 0.0003, 0.0001, -0.0014, -0.0009, -0.0007, -0.0023, -0.0009, 0.0008, 0.0007, 0.0003, 0.0001, 0.0001, 0.0003, -0.0001, -0.0003, -0.0004, -0.0005, -0.0007, -0.0009, -0.0011, -0.0010, -0.0006, -0.0007, -0.0007, -0.0005, -0.0005, -0.0003, -0.0002, -0.0002, -0.0001, -0.0005, -0.0008, -0.0005, -0.0007, -0.0009, -0.0002, 0.0003, 0.0005, 0.0004, 0.0001, 0.0003, 0.0004]],
                [[-0.0001, -0.0000, 0.0003, 0.0001, 0.0005, 0.0001, -0.0006, -0.0002, 0.0002, 0.0002, -0.0031, -0.0004, 0.0006, -0.0066, -0.0032, 0.0044, 0.0025, -0.0019, -0.0017, 0.0001, 0.0019, -0.0010, -0.0014, -0.0009, -0.0007, -0.0009, -0.0019, -0.0024, -0.0019, -0.0001, -0.0017, -0.0022, -0.0004, 0.0005, -0.0014, -0.0023, 0.0002, 0.0015, -0.0022, -0.0033, 0.0024, 0.0009, -0.0041, 0.0000, 0.0030, 0.0020, -0.0015, -0.0018, 0.0014, 0.0007]],
            ]
        ),
        "3.0": torch.tensor(
            [
                [[0.0013, 0.0007, 0.0009, 0.0005, 0.0006, 0.0002, -0.0001, 0.0000, 0.0005, 0.0003, -0.0012, -0.0006, -0.0003, -0.0019, -0.0003, 0.0015, 0.0013, 0.0009, 0.0008, 0.0007, 0.0008, 0.0004, 0.0001, -0.0000, -0.0001, -0.0002, -0.0003, -0.0004, -0.0004, 0.0001, -0.0000, -0.0000, 0.0003, 0.0003, 0.0005, 0.0005, 0.0004, 0.0005, 0.0001, -0.0003, -0.0002, -0.0004, -0.0006, 0.0003, 0.0009, 0.0012, 0.0013, 0.0012, 0.0014, 0.0015]],
                [[0.0000, -0.0003, 0.0005, 0.0004, 0.0011, 0.0013, 0.0002, 0.0005, 0.0002, 0.0006, -0.0025, -0.0005, 0.0004, -0.0069, -0.0027, 0.0038, 0.0013, -0.0015, -0.0005, 0.0003, 0.0014, -0.0006, -0.0002, -0.0010, -0.0008, -0.0001, -0.0006, -0.0012, -0.0016, 0.0010, 0.0001, -0.0010, -0.0002, 0.0013, -0.0002, -0.0017, 0.0005, 0.0019, -0.0019, -0.0035, 0.0022, -0.0001, -0.0040, 0.0012, 0.0015, 0.0012, 0.0001, -0.0010, 0.0005, 0.0004]],
            ]
        ),
        "6.0": torch.tensor(
            [
                [[0.0010, 0.0005, 0.0007, 0.0001, 0.0003, -0.0000, -0.0002, -0.0001, 0.0003, 0.0001, -0.0014, -0.0007, -0.0004, -0.0019, -0.0004, 0.0013, 0.0012, 0.0008, 0.0007, 0.0007, 0.0008, 0.0003, 0.0001, 0.0001, -0.0000, -0.0001, -0.0001, -0.0002, -0.0001, 0.0002, 0.0002, 0.0001, 0.0005, 0.0005, 0.0008, 0.0008, 0.0007, 0.0008, 0.0004, 0.0001, 0.0002, -0.0001, -0.0002, 0.0006, 0.0012, 0.0015, 0.0016, 0.0014, 0.0016, 0.0017]],
                [[-0.0005, -0.0001, 0.0003, 0.0001, 0.0010, 0.0012, 0.0002, 0.0004, 0.0012, 0.0003, -0.0023, -0.0003, -0.0005, -0.0063, -0.0026, 0.0040, 0.0024, -0.0018, -0.0005, 0.0016, 0.0004, -0.0008, 0.0009, 0.0002, -0.0015, -0.0003, 0.0004, -0.0011, -0.0013, 0.0012, 0.0001, -0.0019, 0.0007, 0.0021, -0.0009, -0.0016, 0.0015, 0.0013, -0.0022, -0.0015, 0.0016, -0.0014, -0.0033, 0.0017, 0.0025, -0.0004, -0.0005, 0.0010, 0.0005, 0.0001]],
            ]
        ),
        "12.0": torch.tensor(
            [
                [[0.0003, 0.0002, 0.0004, -0.0004, -0.0003, -0.0007, -0.0008, -0.0006, -0.0001, -0.0002, -0.0016, -0.0009, -0.0004, -0.0021, -0.0003, 0.0015, 0.0016, 0.0012, 0.0011, 0.0010, 0.0010, 0.0005, 0.0002, 0.0001, 0.0000, -0.0001, -0.0002, -0.0004, -0.0004, 0.0000, -0.0000, -0.0002, 0.0001, 0.0001, 0.0004, 0.0003, 0.0002, 0.0004, -0.0001, -0.0005, -0.0004, -0.0006, -0.0007, 0.0003, 0.0009, 0.0013, 0.0015, 0.0015, 0.0017, 0.0018]],
                [[-0.0008, -0.0003, 0.0003, -0.0001, 0.0008, 0.0013, 0.0004, 0.0008, 0.0015, 0.0006, -0.0021, -0.0001, -0.0003, -0.0062, -0.0022, 0.0043, 0.0028, -0.0013, -0.0002, 0.0017, 0.0010, -0.0001, 0.0008, 0.0001, -0.0010, 0.0003, 0.0008, -0.0006, -0.0007, 0.0012, 0.0003, -0.0013, 0.0007, 0.0019, -0.0002, -0.0013, 0.0011, 0.0016, -0.0016, -0.0017, 0.0014, -0.0006, -0.0029, 0.0011, 0.0028, 0.0006, -0.0004, 0.0005, 0.0008, 0.0003]],
            ]
        ),
        "24.0": torch.tensor(
            [
                [[0.0009, 0.0004, 0.0007, 0.0002, 0.0004, -0.0001, -0.0003, -0.0002, 0.0002, 0.0001, -0.0015, -0.0009, -0.0006, -0.0024, -0.0005, 0.0016, 0.0014, 0.0010, 0.0009, 0.0008, 0.0008, 0.0004, 0.0001, 0.0000, -0.0001, -0.0002, -0.0003, -0.0006, -0.0006, -0.0003, -0.0005, -0.0006, -0.0003, -0.0004, -0.0001, -0.0002, -0.0003, -0.0001, -0.0006, -0.0011, -0.0008, -0.0010, -0.0012, -0.0000, 0.0007, 0.0011, 0.0012, 0.0011, 0.0013, 0.0014]],
                [[-0.0009, -0.0004, 0.0001, -0.0003, 0.0007, 0.0012, 0.0003, 0.0006, 0.0017, 0.0008, -0.0020, 0.0001, -0.0002, -0.0064, -0.0023, 0.0047, 0.0029, -0.0016, -0.0004, 0.0019, 0.0010, -0.0002, 0.0007, -0.0001, -0.0013, 0.0005, 0.0012, -0.0007, -0.0008, 0.0013, -0.0001, -0.0022, 0.0004, 0.0020, -0.0004, -0.0014, 0.0017, 0.0020, -0.0018, -0.0016, 0.0015, -0.0015, -0.0036, 0.0014, 0.0030, 0.0004, 0.0002, 0.0015, 0.0011, 0.0007]],
            ]
        ),
    },
    "facebook/encodec_48khz": {
        "3.0": torch.tensor(
            [
                [
                    [
                        0.005083, 0.004669, 0.005723, 0.005600, 0.004231, 0.003830,
                        0.003684, 0.003349, 0.003032, 0.003055, 0.002768, 0.002370,
                        0.002384, 0.002450, 0.002391, 0.002363, 0.002357, 0.002435,
                        0.002568, 0.002463, 0.002137, 0.002092, 0.002440, 0.002772,
                        0.003035, 0.003473, 0.003963, 0.004288, 0.004315, 0.004087,
                        0.003618, 0.003166, 0.002874, 0.002775, 0.002820, 0.002758,
                        0.002565, 0.002498, 0.002583, 0.002671, 0.002656, 0.002613,
                        0.002433, 0.002236, 0.002215, 0.002302, 0.002287, 0.002113,
                        0.001909, 0.001767,
                    ],
                    [
                        -0.003928, -0.002733, -0.001330, -0.001914, -0.002927, -0.003272,
                        -0.003677, -0.003615, -0.003341, -0.002907, -0.002764, -0.002742,
                        -0.002593, -0.002308, -0.002024, -0.001856, -0.001672, -0.001256,
                        -0.000929, -0.001217, -0.001864, -0.002118, -0.002025, -0.001932,
                        -0.001816, -0.001572, -0.001214, -0.000885, -0.000829, -0.000976,
                        -0.001417, -0.001874, -0.002030, -0.001952, -0.001858, -0.001863,
                        -0.001895, -0.001843, -0.001801, -0.001792, -0.001812, -0.001865,
                        -0.002008, -0.002120, -0.002132, -0.002093, -0.002170, -0.002370,
                        -0.002587, -0.002749,
                    ],
                ],
                [
                    [
                        0.004229, 0.003422, 0.005044, 0.006059, 0.005242, 0.004623,
                        0.004231, 0.004050, 0.004314, 0.004701, 0.004559, 0.004105,
                        0.003874, 0.003713, 0.003355, 0.003055, 0.003235, 0.003927,
                        0.004500, 0.004195, 0.003328, 0.002804, 0.002628, 0.002456,
                        0.002693, 0.003883, 0.005604, 0.006791, 0.006702, 0.005427,
                        0.003622, 0.002328, 0.002173, 0.002871, 0.003505, 0.003410,
                        0.002851, 0.002511, 0.002534, 0.002685, 0.002714, 0.002538,
                        0.002110, 0.001697, 0.001786, 0.002415, 0.002940, 0.002856,
                        0.002348, 0.001883,
                    ],
                    [
                        -0.003444, -0.002916, -0.000590, 0.000157, -0.000702, -0.001472,
                        -0.002032, -0.001891, -0.001283, -0.000670, -0.000590, -0.000875,
                        -0.001090, -0.001095, -0.001172, -0.001287, -0.000907, 0.000111,
                        0.000858, 0.000471, -0.000532, -0.001127, -0.001463, -0.001853,
                        -0.001762, -0.000666, 0.000964, 0.002054, 0.001914, 0.000743,
                        -0.000876, -0.001990, -0.001951, -0.001042, -0.000229, -0.000171,
                        -0.000558, -0.000752, -0.000704, -0.000609, -0.000594, -0.000723,
                        -0.001085, -0.001455, -0.001374, -0.000795, -0.000350, -0.000480,
                        -0.000993, -0.001432,
                    ],
                ],
            ]
        ),
        "6.0": torch.tensor(
            [
                [
                    [
                        5.892794e-03, 5.767163e-03, 7.065284e-03, 7.068626e-03,
                        5.825328e-03, 5.601424e-03, 5.582351e-03, 5.209565e-03,
                        4.829186e-03, 4.809568e-03, 4.663883e-03, 4.402087e-03,
                        4.337528e-03, 4.311915e-03, 4.236566e-03, 4.209972e-03,
                        4.179818e-03, 4.196202e-03, 4.309553e-03, 4.267083e-03,
                        4.052189e-03, 4.068719e-03, 4.381632e-03, 4.692366e-03,
                        4.998885e-03, 5.466312e-03, 5.895300e-03, 6.115717e-03,
                        6.055626e-03, 5.773376e-03, 5.316667e-03, 4.826934e-03,
                        4.450697e-03, 4.315911e-03, 4.310716e-03, 4.202125e-03,
                        4.008702e-03, 3.957694e-03, 4.017603e-03, 4.060654e-03,
                        4.036821e-03, 3.923071e-03, 3.659022e-03, 3.427053e-03,
                        3.387271e-03, 3.462438e-03, 3.434755e-03, 3.247944e-03,
                        3.009581e-03, 2.800536e-03,
                    ],
                    [
                        -1.867314e-03, -6.082351e-04, 9.374358e-04, 5.555808e-04,
                        -3.020080e-04, -5.281629e-04, -9.364292e-04, -1.057594e-03,
                        -9.703087e-04, -6.292185e-04, -4.193477e-04, -3.605868e-04,
                        -2.948678e-04, -1.198237e-04, 4.924605e-05, 1.602105e-04,
                        3.162385e-04, 6.700790e-04, 9.868707e-04, 8.484383e-04,
                        4.327767e-04, 3.108105e-04, 4.244343e-04, 5.422112e-04,
                        7.239584e-04, 1.008546e-03, 1.265120e-03, 1.447669e-03,
                        1.436084e-03, 1.271058e-03, 8.684017e-04, 4.149990e-04,
                        2.143449e-04, 2.508474e-04, 3.018488e-04, 2.782424e-04,
                        2.369677e-04, 3.040710e-04, 3.242530e-04, 2.599912e-04,
                        2.211208e-04, 1.311762e-04, -9.807519e-05, -2.752687e-04,
                        -3.114068e-04, -2.832832e-04, -3.900219e-04, -6.142824e-04,
                        -8.507833e-04, -1.055882e-03,
                    ],
                ],
                [
                    [
                        3.971702e-04, -2.164055e-04, 1.562327e-03, 2.695718e-03,
                        2.374928e-03, 2.145125e-03, 1.870762e-03, 1.852614e-03,
                        2.074345e-03, 2.312302e-03, 2.222824e-03, 1.876336e-03,
                        1.609606e-03, 1.420574e-03, 1.193270e-03, 9.592943e-04,
                        1.132237e-03, 1.776782e-03, 2.258269e-03, 1.945908e-03,
                        9.930646e-04, 1.733529e-04, -2.533881e-04, -3.138177e-04,
                        3.226010e-04, 1.859203e-03, 3.879325e-03, 5.267750e-03,
                        5.101699e-03, 3.609465e-03, 1.653315e-03, 2.709297e-04,
                        -3.190451e-05, 5.129501e-04, 1.224789e-03, 1.397457e-03,
                        1.110794e-03, 8.736057e-04, 8.860155e-04, 1.055910e-03,
                        1.100855e-03, 8.834896e-04, 3.825913e-04, -3.267327e-05,
                        6.586456e-05, 7.147206e-04, 1.394876e-03, 1.535393e-03,
                        1.192172e-03, 7.061819e-04,
                    ],
                    [
                        -6.897163e-03, -6.407891e-03, -4.015491e-03, -3.082125e-03,
                        -3.434983e-03, -3.885052e-03, -4.456392e-03, -4.296550e-03,
                        -3.861045e-03, -3.553474e-03, -3.547473e-03, -3.800863e-03,
                        -4.123025e-03, -4.237277e-03, -4.244958e-03, -4.263899e-03,
                        -3.808572e-03, -2.811858e-03, -2.147519e-03, -2.516703e-03,
                        -3.550721e-03, -4.353373e-03, -4.846224e-03, -4.960613e-03,
                        -4.273535e-03, -2.714785e-03, -7.043980e-04, 6.689885e-04,
                        5.069164e-04, -9.122533e-04, -2.816979e-03, -4.124952e-03,
                        -4.235019e-03, -3.491365e-03, -2.676077e-03, -2.381226e-03,
                        -2.492559e-03, -2.634424e-03, -2.632524e-03, -2.528266e-03,
                        -2.536691e-03, -2.746170e-03, -3.187869e-03, -3.553530e-03,
                        -3.462211e-03, -2.862707e-03, -2.273719e-03, -2.201617e-03,
                        -2.565818e-03, -3.044683e-03,
                    ],
                ],
            ]
        ),
        "12.0": torch.tensor(
            [
                [
                    [
                        2.237194e-03, 2.508208e-03, 3.986347e-03, 4.020395e-03,
                        2.889890e-03, 2.733388e-03, 2.684146e-03, 2.251372e-03,
                        1.787451e-03, 1.720550e-03, 1.689184e-03, 1.495478e-03,
                        1.321027e-03, 1.185375e-03, 1.098422e-03, 1.055453e-03,
                        9.591801e-04, 9.328910e-04, 1.026154e-03, 1.031992e-03,
                        9.155220e-04, 9.732856e-04, 1.282264e-03, 1.624059e-03,
                        1.920021e-03, 2.333685e-03, 2.730524e-03, 2.919153e-03,
                        2.856711e-03, 2.632692e-03, 2.256703e-03, 1.901129e-03,
                        1.684760e-03, 1.638201e-03, 1.644909e-03, 1.569378e-03,
                        1.448412e-03, 1.478291e-03, 1.580583e-03, 1.633777e-03,
                        1.597190e-03, 1.475462e-03, 1.242885e-03, 1.065243e-03,
                        1.052842e-03, 1.103825e-03, 1.059115e-03, 9.251673e-04,
                        7.235570e-04, 5.053390e-04,
                    ],
                    [
                        -4.534880e-03, -3.111026e-03, -1.486247e-03, -1.739966e-03,
                        -2.399862e-03, -2.583335e-03, -3.157276e-03, -3.517166e-03,
                        -3.598212e-03, -3.303007e-03, -3.037215e-03, -2.982930e-03,
                        -3.026671e-03, -2.958387e-03, -2.836909e-03, -2.775315e-03,
                        -2.719575e-03, -2.431532e-03, -2.090512e-03, -2.095603e-03,
                        -2.366266e-03, -2.404480e-03, -2.235661e-03, -2.063206e-03,
                        -1.888533e-03, -1.640449e-03, -1.407782e-03, -1.250053e-03,
                        -1.275359e-03, -1.373277e-03, -1.601508e-03, -1.838720e-03,
                        -1.876643e-03, -1.736149e-03, -1.622051e-03, -1.578928e-03,
                        -1.564748e-03, -1.455850e-03, -1.391748e-03, -1.418254e-03,
                        -1.462577e-03, -1.554713e-03, -1.730076e-03, -1.829485e-03,
                        -1.816249e-03, -1.772218e-03, -1.855736e-03, -2.013720e-03,
                        -2.196174e-03, -2.378810e-03,
                    ],
                ],
                [
                    [
                        8.993230e-04, 6.808847e-04, 2.595528e-03, 3.586462e-03,
                        3.023965e-03, 2.479527e-03, 1.868662e-03, 1.565682e-03,
                        1.563900e-03, 1.666364e-03, 1.715061e-03, 1.609638e-03,
                        1.294764e-03, 8.647116e-04, 5.122397e-04, 2.899101e-04,
                        3.817413e-04, 8.303743e-04, 1.253686e-03, 1.179640e-03,
                        6.591807e-04, 1.167982e-04, -3.405492e-04, -5.258832e-04,
                        -4.165239e-05, 1.393227e-03, 3.473584e-03, 4.953051e-03,
                        4.779391e-03, 3.182305e-03, 1.140233e-03, -2.133392e-04,
                        -4.233644e-04, 2.426380e-04, 1.126914e-03, 1.557022e-03,
                        1.490265e-03, 1.264647e-03, 1.170405e-03, 1.237709e-03,
                        1.112253e-03, 6.990263e-04, 1.700171e-04, -1.761244e-04,
                        1.852706e-05, 8.140961e-04, 1.621285e-03, 1.813497e-03,
                        1.394625e-03, 7.860070e-04,
                    ],
                    [
                        -4.677887e-03, -3.966209e-03, -1.634288e-03, -8.592710e-04,
                        -1.395248e-03, -2.189968e-03, -3.198638e-03, -3.410639e-03,
                        -3.241918e-03, -3.051681e-03, -2.845973e-03, -2.786646e-03,
                        -3.078280e-03, -3.367662e-03, -3.450923e-03, -3.427895e-03,
                        -3.058358e-03, -2.258006e-03, -1.607386e-03, -1.647450e-03,
                        -2.164357e-03, -2.647080e-03, -3.110953e-03, -3.304542e-03,
                        -2.798792e-03, -1.407999e-03, 5.630683e-04, 1.961336e-03,
                        1.813856e-03, 3.529640e-04, -1.526076e-03, -2.695498e-03,
                        -2.702039e-03, -1.889018e-03, -9.337939e-04, -3.885011e-04,
                        -2.970786e-04, -4.415356e-04, -5.492531e-04, -5.430978e-04,
                        -7.051138e-04, -1.102020e-03, -1.577104e-03, -1.846151e-03,
                        -1.623901e-03, -8.853760e-04, -1.772702e-04, -4.866864e-05,
                        -4.633263e-04, -1.017192e-03,
                    ],
                ],
            ]
        ),
        "24.0": torch.tensor(
            [
                [
                    [0.0004, 0.0008, 0.0024, 0.0024, 0.0013, 0.0013, 0.0013, 0.0009, 0.0005, 0.0005, 0.0006, 0.0005, 0.0005, 0.0003, 0.0003, 0.0003, 0.0002, 0.0002, 0.0003, 0.0004, 0.0003, 0.0004, 0.0008, 0.0012, 0.0015, 0.0018, 0.0021, 0.0022, 0.0021, 0.0019, 0.0016, 0.0014, 0.0012, 0.0011, 0.0012, 0.0012, 0.0012, 0.0012, 0.0013, 0.0014, 0.0014, 0.0013, 0.0011, 0.0009, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0007],
                    [-0.0055, -0.0040, -0.0024, -0.0026, -0.0031, -0.0031, -0.0036, -0.0039, -0.0039, -0.0035, -0.0031, -0.0029, -0.0028, -0.0027, -0.0026, -0.0024, -0.0023, -0.0020, -0.0017, -0.0016, -0.0017, -0.0017, -0.0015, -0.0012, -0.0010, -0.0008, -0.0006, -0.0004, -0.0004, -0.0005, -0.0006, -0.0007, -0.0006, -0.0004, -0.0002, -0.0001, 0.0001, 0.0002, 0.0003, 0.0004, 0.0004, 0.0003, 0.0001, 0.0001, 0.0000, 0.0001, 0.0000, -0.0001, -0.0002, -0.0004],
                ],
                [
                    [-0.0024, -0.0029, -0.0009, 0.0002, -0.0002, -0.0007, -0.0012, -0.0013, -0.0012, -0.0011, -0.0011, -0.0012, -0.0016, -0.0021, -0.0024, -0.0026, -0.0024, -0.0018, -0.0013, -0.0015, -0.0022, -0.0029, -0.0035, -0.0038, -0.0031, -0.0015, 0.0008, 0.0025, 0.0023, 0.0006, -0.0016, -0.0030, -0.0032, -0.0024, -0.0015, -0.0010, -0.0009, -0.0011, -0.0010, -0.0009, -0.0010, -0.0014, -0.0020, -0.0023, -0.0020, -0.0011, -0.0001, 0.0001, -0.0003, -0.0009],
                    [-0.0086, -0.0081, -0.0059, -0.0050, -0.0053, -0.0061, -0.0071, -0.0071, -0.0069, -0.0067, -0.0066, -0.0066, -0.0070, -0.0073, -0.0074, -0.0073, -0.0069, -0.0060, -0.0053, -0.0055, -0.0061, -0.0067, -0.0072, -0.0074, -0.0067, -0.0052, -0.0031, -0.0015, -0.0016, -0.0029, -0.0048, -0.0059, -0.0059, -0.0051, -0.0041, -0.0036, -0.0034, -0.0034, -0.0034, -0.0033, -0.0035, -0.0039, -0.0043, -0.0046, -0.0043, -0.0035, -0.0027, -0.0025, -0.0029, -0.0034],
                ],
            ]
        ),
    },
}
|
|
# ---- error over whole batch
# One scalar reference error per hub model id (outer key) and bandwidth string (inner key).
# NOTE(review): presumably the RMSE between decoded and input audio, mirroring the
# single-sample test; the code that consumes this table lies outside this excerpt — confirm.
EXPECTED_CODEC_ERROR_BATCH = {
    "facebook/encodec_24khz": {
        "1.5": 0.0011174238752573729,
        "3.0": 0.0009308119188062847,
        "6.0": 0.0008,
        "12.0": 0.0006830253987573087,
        "24.0": 0.000642190920189023,
    },
    "facebook/encodec_48khz": {
        "3.0": 0.00039895583176985383,
        "6.0": 0.0003249854489695281,
        "12.0": 0.0002540576097089797,
        "24.0": 0.00021899679268244654,
    },
}
|
|
# fmt: on
|
|
|
|
|
|
@slow
|
|
@require_torch
|
|
class EncodecIntegrationTest(unittest.TestCase):
|
|
# One test case is generated per (model_id, bandwidth) pair found in EXPECTED_ENCODER_CODES.
# The first tuple element is the generated-name suffix: the model id's basename followed by
# the bandwidth with "." replaced by "p".
@parameterized.expand(
    [
        (f"{os.path.basename(model_id)}_{bandwidth.replace('.', 'p')}", model_id, bandwidth)
        for model_id, per_bandwidth in EXPECTED_ENCODER_CODES.items()
        for bandwidth in per_bandwidth
    ]
)
def test_integration(self, name, model_id, bandwidth):
    # Single-sample check of a pretrained checkpoint: encode the first sample of the dummy
    # LibriSpeech split, decode it again, and compare codes, scales, decoded audio and the
    # codec error against the stored references. `bandwidth` is the string key of the
    # reference tables and is converted to float when passed to the model.
    tolerances = {"rtol": 1e-4, "atol": 1e-4}

    # load model
    model = EncodecModel.from_pretrained(model_id).to(torch_device)
    processor = AutoProcessor.from_pretrained(model_id)

    # load audio, resampled to the processor's sampling rate
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    dataset = dataset.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
    waveform = dataset[0]["audio"]["array"]
    num_channels = model.config.audio_channels
    if num_channels > 1:
        # the dataset sample is replicated once per channel the model expects
        waveform = np.array([waveform] * num_channels)
    inputs = processor(
        raw_audio=waveform,
        sampling_rate=processor.sampling_rate,
        return_tensors="pt",
        padding=True,
    ).to(torch_device)
    input_values = inputs["input_values"]
    padding_mask = inputs["padding_mask"]

    model = model.eval()
    with torch.no_grad():
        target_bandwidth = float(bandwidth)

        # Compare encoder outputs with expected values
        encoded_frames = model.encode(input_values, padding_mask, bandwidth=target_bandwidth)
        # take index 0 of every frame's codes and join the frames along the last axis
        first_item_codes = [frame_codes[0] for frame_codes in encoded_frames["audio_codes"]]
        codes = torch.cat(first_item_codes, dim=-1).unsqueeze(0)
        expected_codes = EXPECTED_ENCODER_CODES[model_id][bandwidth]
        # only the leading entries covered by the reference tensor are compared
        torch.testing.assert_close(
            codes[..., : expected_codes.shape[-1]],
            expected_codes.to(torch_device),
            **tolerances,
        )

        expected_scales = EXPECTED_ENCODER_SCALES[model_id][bandwidth]
        if expected_scales is not None:
            first_item_scales = [frame_scale[0].squeeze() for frame_scale in encoded_frames["audio_scales"]]
            scales = torch.tensor(first_item_scales)
            torch.testing.assert_close(scales, expected_scales, **tolerances)

        # Compare decoder outputs with expected values
        decoded_frames = model.decode(
            encoded_frames["audio_codes"],
            encoded_frames["audio_scales"],
            padding_mask,
            last_frame_pad_length=encoded_frames["last_frame_pad_length"],
        )
        decoded_audio = decoded_frames["audio_values"]
        expected_audio = EXPECTED_DECODER_OUTPUTS[model_id][bandwidth]
        torch.testing.assert_close(
            decoded_audio[0][..., : expected_audio.shape[-1]],
            expected_audio.to(torch_device),
            **tolerances,
        )

        # Compare codec error with expected values
        codec_error = compute_rmse(decoded_audio, input_values)
        torch.testing.assert_close(codec_error, EXPECTED_CODEC_ERROR[model_id][bandwidth], **tolerances)

        # make sure forward and enc-dec give same result
        forward_output = model(input_values, padding_mask, bandwidth=target_bandwidth)
        torch.testing.assert_close(forward_output["audio_values"], decoded_audio, **tolerances)
|
|
|
|
@parameterized.expand(
    [
        (f"{os.path.basename(model_id)}_{bandwidth.replace('.', 'p')}", model_id, bandwidth)
        for model_id, expected_by_bandwidth in EXPECTED_ENCODER_CODES_BATCH.items()
        for bandwidth in expected_by_bandwidth
    ]
)
def test_batch(self, name, model_id, bandwidth):
    """Compare batched encode / decode / forward results of a pretrained checkpoint with stored references.

    One test case is generated per (model_id, bandwidth) pair found in `EXPECTED_ENCODER_CODES_BATCH`.

    Args:
        name: Case label produced by the `parameterized.expand` decorator; not used in the body.
        model_id: Checkpoint identifier handed to `from_pretrained` and used as key into the
            `EXPECTED_*_BATCH` tables.
        bandwidth: Bandwidth as a string; used verbatim as table key and converted with `float()`
            before being passed to the model.
    """
    # checkpoint and its processor
    codec = EncodecModel.from_pretrained(model_id).to(torch_device)
    processor = AutoProcessor.from_pretrained(model_id)

    # the last two clips of the dummy validation split, resampled to the processor's sampling rate
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    dataset = dataset.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
    clips = dataset[-2:]["audio"]

    num_channels = codec.config.audio_channels
    if num_channels == 1:
        audio_samples = [clip["array"] for clip in clips]
    else:
        # repeat each clip along a new leading axis until the model's channel count is reached
        audio_samples = [
            np.concatenate([clip["array"][np.newaxis]] * num_channels, axis=0) for clip in clips
        ]

    batch = processor(
        raw_audio=audio_samples,
        sampling_rate=processor.sampling_rate,
        return_tensors="pt",
        padding=True,
    ).to(torch_device)

    # every comparison below uses the same tolerances
    tolerances = {"rtol": 1e-4, "atol": 1e-4}

    codec.eval()
    with torch.no_grad():
        input_values = batch["input_values"]
        padding_mask = batch["padding_mask"]
        target_bandwidth = float(bandwidth)

        # --- encoder: codes ---
        encoder_out = codec.encode(input_values, padding_mask, bandwidth=target_bandwidth)
        # move dim 0 in front of the last dim, then merge those two trailing dims into one
        # NOTE(review): dim 0 is presumably the frame/chunk axis -- confirm against `EncodecModel.encode`
        reordered = encoder_out["audio_codes"].permute(1, 2, 0, 3)
        flat_codes = reordered.reshape(reordered.size(0), reordered.size(1), -1)
        expected_codes = EXPECTED_ENCODER_CODES_BATCH[model_id][bandwidth]
        # only the leading slice covered by the stored reference is compared
        torch.testing.assert_close(
            flat_codes[..., : expected_codes.shape[-1]],
            expected_codes.to(torch_device),
            **tolerances,
        )

        # --- encoder: scales (skipped when the reference entry is None) ---
        expected_scales = EXPECTED_ENCODER_SCALES_BATCH[model_id][bandwidth]
        if expected_scales is not None:
            stacked_scales = torch.stack(encoder_out["audio_scales"])
            torch.testing.assert_close(stacked_scales, expected_scales.to(torch_device), **tolerances)

        # --- decoder ---
        decoder_out = codec.decode(
            encoder_out["audio_codes"],
            encoder_out["audio_scales"],
            padding_mask,
            last_frame_pad_length=encoder_out["last_frame_pad_length"],
        )
        reconstructed = decoder_out["audio_values"]
        expected_audio = EXPECTED_DECODER_OUTPUTS_BATCH[model_id][bandwidth]
        torch.testing.assert_close(
            reconstructed[..., : expected_audio.shape[-1]],
            expected_audio.to(torch_device),
            **tolerances,
        )

        # --- codec error between reconstruction and input ---
        rmse = compute_rmse(reconstructed, input_values)
        torch.testing.assert_close(rmse, EXPECTED_CODEC_ERROR_BATCH[model_id][bandwidth], **tolerances)

        # --- a plain forward pass must reproduce the explicit encode -> decode result ---
        forward_out = codec(input_values, padding_mask, bandwidth=target_bandwidth)
        torch.testing.assert_close(forward_out["audio_values"], reconstructed, **tolerances)
|