# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Encodec model."""
import copy
import inspect
import os
import tempfile
import unittest
import numpy as np
from datasets import Audio, load_dataset
from parameterized import parameterized
from transformers import AutoProcessor, EncodecConfig
from transformers.testing_utils import (
is_torch_available,
require_torch,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import EncodecFeatureExtractor, EncodecModel
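# Helper that builds the input dict for the model tester. The mask and head-mask
# arguments are accepted for signature compatibility with other model testers but
# are not used by Encodec.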
def prepare_inputs_dict(
config,
input_ids=None,
input_values=None,
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if input_ids is not None:
encoder_dict = {"input_ids": input_ids}
else:
encoder_dict = {"input_values": input_values}
decoder_dict = {"decoder_input_ids": decoder_input_ids} if decoder_input_ids is not None else {}
return {**encoder_dict, **decoder_dict}
@require_torch
class EncodecModelTester:
def __init__(
self,
parent,
# `batch_size` needs to be an even number if the model has some outputs with batch dim != 0.
batch_size=12,
num_channels=2,
is_training=False,
intermediate_size=40,
hidden_size=32,
num_filters=8,
num_residual_layers=1,
upsampling_ratios=[8, 4],
num_lstm_layers=1,
codebook_size=64,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.is_training = is_training
self.intermediate_size = intermediate_size
self.hidden_size = hidden_size
self.num_filters = num_filters
self.num_residual_layers = num_residual_layers
self.upsampling_ratios = upsampling_ratios
self.num_lstm_layers = num_lstm_layers
self.codebook_size = codebook_size
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def prepare_config_and_inputs_for_model_class(self, model_class):
config, inputs_dict = self.prepare_config_and_inputs()
inputs_dict["audio_codes"] = ids_tensor([1, self.batch_size, 1, self.num_channels], self.codebook_size).type(
torch.int32
)
inputs_dict["audio_scales"] = [None]
return config, inputs_dict
def prepare_config_and_inputs_for_normalization(self):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
config = self.get_config()
config.normalize = True
processor = EncodecFeatureExtractor(feature_size=config.audio_channels, sampling_rate=config.sampling_rate)
input_values = input_values.tolist()
inputs_dict = processor(
input_values, sampling_rate=config.sampling_rate, padding=True, return_tensors="pt"
).to(torch_device)
return config, inputs_dict
def get_config(self):
return EncodecConfig(
audio_channels=self.num_channels,
chunk_in_sec=None,
hidden_size=self.hidden_size,
num_filters=self.num_filters,
num_residual_layers=self.num_residual_layers,
upsampling_ratios=self.upsampling_ratios,
num_lstm_layers=self.num_lstm_layers,
codebook_size=self.codebook_size,
)
def create_and_check_model_forward(self, config, inputs_dict):
model = EncodecModel(config=config).to(torch_device).eval()
result = model(**inputs_dict)
self.parent.assertEqual(
result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size)
)
@require_torch
class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (EncodecModel,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_headmasking = False
test_resize_embeddings = False
pipeline_model_mapping = {"feature-extraction": EncodecModel} if is_torch_available() else {}
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does not have attention and does not support returning hidden states
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if "output_attentions" in inputs_dict:
inputs_dict.pop("output_attentions")
if "output_hidden_states" in inputs_dict:
inputs_dict.pop("output_hidden_states")
return inputs_dict
def setUp(self):
self.model_tester = EncodecModelTester(self)
self.config_tester = ConfigTester(
self, config_class=EncodecConfig, hidden_size=37, common_properties=[], has_text_modality=False
)
def test_config(self):
self.config_tester.run_common_tests()
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict, so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_values", "padding_mask", "bandwidth"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@unittest.skip(reason="The EncodecModel is not transformers based, thus it does not have `inputs_embeds` logics")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="The EncodecModel is not transformers based, thus it does not have `inputs_embeds` logics")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(
reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
)
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(
reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
)
def test_torchscript_output_attentions(self):
pass
@unittest.skip(
reason="The EncodecModel is not transformers based, thus it does not have the usual `hidden_states` logic"
)
def test_torchscript_output_hidden_state(self):
pass
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
main_input_name = model_class.main_input_name
try:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict:
if key not in model_state_dict:
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
# Avoid memory leaks. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, there is still a leak of ~0.04MB.)
self.clear_torch_jit_class_registry()
@unittest.skip(
reason="The EncodecModel is not transformers based, thus it does not have the usual `attention` logic"
)
def test_attention_outputs(self):
pass
def test_feed_forward_chunking(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
# original_config.norm_type = "time_group_norm"
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
config.chunk_length_s = None
config.overlap = None
config.sampling_rate = 20
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
inputs["input_values"] = inputs["input_values"].repeat(1, 1, 10)
hidden_states_no_chunk = model(**inputs)[1]
torch.manual_seed(0)
config.chunk_length_s = 2
config.overlap = 0
config.sampling_rate = 20
model = model_class(config)
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(**inputs)[1]
torch.testing.assert_close(hidden_states_no_chunk, hidden_states_with_chunk, rtol=1e-1, atol=1e-2)
@unittest.skip(
reason="The EncodecModel is not transformers based, thus it does not have the usual `hidden_states` logic"
)
def test_hidden_states_output(self):
pass
def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_determinism(first, second):
# outputs are not tensors but lists (since the sequences do not all have the same frame_length)
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
if isinstance(first, tuple) and isinstance(second, tuple):
for tensor1, tensor2 in zip(first, second):
check_determinism(tensor1, tensor2)
else:
check_determinism(first, second)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
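# zero out NaNs so the tuple/dict comparison below only flags real numeric differences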
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def assert_nested_tensors_close(a, b):
if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
assert len(a) == len(b), f"Length mismatch: {len(a)} vs {len(b)}"
for i, (x, y) in enumerate(zip(a, b)):
assert_nested_tensors_close(x, y)
elif torch.is_tensor(a) and torch.is_tensor(b):
a_clean = set_nan_tensor_to_zero(a)
b_clean = set_nan_tensor_to_zero(b)
assert torch.allclose(a_clean, b_clean, atol=1e-5), (
"Tuple and dict output are not equal. Difference:"
f" Max diff: {torch.max(torch.abs(a_clean - b_clean))}. "
f"Tuple has nan: {torch.isnan(a).any()} and inf: {torch.isinf(a)}. "
f"Dict has nan: {torch.isnan(b).any()} and inf: {torch.isinf(b)}."
)
else:
raise ValueError(f"Mismatch between {a} vs {b}")
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
self.assertTrue(isinstance(tuple_output, tuple))
self.assertTrue(isinstance(dict_output, dict))
# cast dict_output.values() to list as it is an odict_values object
assert_nested_tensors_close(tuple_output, list(dict_output.values()))
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = ["conv"]
ignore_init = ["lstm"]
if param.requires_grad:
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif not any(x in name for x in ignore_init):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_identity_shortcut(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_conv_shortcut = False
self.model_tester.create_and_check_model_forward(config, inputs_dict)
def test_model_forward_with_normalization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_normalization()
self.model_tester.create_and_check_model_forward(config, inputs_dict)
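# Module-level helpers for the integration tests below: `normalize` L2-normalizes a
# waveform, and `compute_rmse` returns the RMSE between two L2-normalized waveforms
# truncated to their common length. This is the "codec error" checked against the
# EXPECTED_CODEC_ERROR* tables.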
def normalize(arr):
norm = np.linalg.norm(arr)
normalized_arr = arr / norm
return normalized_arr
def compute_rmse(arr1, arr2):
arr1_np = arr1.cpu().numpy().squeeze()
arr2_np = arr2.cpu().numpy().squeeze()
max_length = min(arr1.shape[-1], arr2.shape[-1])
arr1_np = arr1_np[..., :max_length]
arr2_np = arr2_np[..., :max_length]
arr1_normalized = normalize(arr1_np)
arr2_normalized = normalize(arr2_np)
return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean())
"""
Integration tests for the Encodec model.
Code used to generate the expected outputs can be found in the following gists:
- test_integration: https://gist.github.com/ebezzam/2a34e249e729881130d1f5a42229d31f#file-test_encodec-py
- test_batch: https://gist.github.com/ebezzam/2a34e249e729881130d1f5a42229d31f#file-test_encodec_batch-py
"""
# fmt: off
# first key is model_id from hub, second key is bandwidth
# -- test_integration
EXPECTED_ENCODER_CODES = {
"facebook/encodec_24khz": {
"1.5": torch.tensor([[[ 62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
[1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424]]]),
"3.0": torch.tensor(
[
[
[62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
[1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
[786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
[741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
],
]
),
"6.0": torch.tensor(
[
[
[62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
[1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
[786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
[741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
[528, 446, 198, 190, 446, 622, 646, 448, 646, 448],
[1011, 140, 185, 986, 683, 986, 435, 41, 140, 939],
[896, 772, 562, 772, 485, 528, 896, 853, 562, 772],
[899, 975, 468, 468, 468, 701, 1013, 828, 518, 899],
],
]
),
"12.0": torch.tensor(
[
[
[62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
[1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
[786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
[741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
[528, 446, 198, 190, 446, 622, 646, 448, 646, 448],
[1011, 140, 185, 986, 683, 986, 435, 41, 140, 939],
[896, 772, 562, 772, 485, 528, 896, 853, 562, 772],
[899, 975, 468, 468, 468, 701, 1013, 828, 518, 899],
[827, 807, 938, 320, 699, 470, 909, 628, 301, 827],
[963, 801, 630, 477, 717, 354, 205, 359, 874, 744],
[1000, 1000, 388, 1000, 408, 740, 568, 364, 709, 843],
[413, 835, 382, 840, 742, 1019, 375, 962, 835, 742],
[971, 410, 998, 485, 798, 410, 351, 485, 485, 920],
[848, 694, 662, 784, 848, 427, 1022, 848, 920, 694],
[420, 911, 889, 911, 993, 776, 948, 477, 911, 911],
[587, 755, 834, 962, 860, 425, 982, 982, 425, 461],
],
]
),
"24.0": torch.tensor(
[
[
[62, 835, 835, 835, 835, 835, 835, 835, 408, 408],
[1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424],
[786, 678, 821, 786, 36, 36, 786, 212, 937, 937],
[741, 741, 741, 993, 741, 1018, 993, 919, 741, 741],
[528, 446, 198, 190, 446, 622, 646, 448, 646, 448],
[1011, 140, 185, 986, 683, 986, 435, 41, 140, 939],
[896, 772, 562, 772, 485, 528, 896, 853, 562, 772],
[899, 975, 468, 468, 468, 701, 1013, 828, 518, 899],
[827, 807, 938, 320, 699, 470, 909, 628, 301, 827],
[963, 801, 630, 477, 717, 354, 205, 359, 874, 744],
[1000, 1000, 388, 1000, 408, 740, 568, 364, 709, 843],
[413, 835, 382, 840, 742, 1019, 375, 962, 835, 742],
[971, 410, 998, 485, 798, 410, 351, 485, 485, 920],
[848, 694, 662, 784, 848, 427, 1022, 848, 920, 694],
[420, 911, 889, 911, 993, 776, 948, 477, 911, 911],
[587, 755, 834, 962, 860, 425, 982, 982, 425, 461],
[270, 160, 26, 131, 597, 506, 670, 637, 248, 160],
[ 15, 215, 134, 69, 215, 155, 1012, 1009, 260, 417],
[580, 561, 686, 896, 497, 637, 580, 245, 896, 264],
[511, 239, 560, 691, 571, 627, 571, 571, 258, 619],
[591, 942, 591, 251, 250, 250, 857, 486, 295, 295],
[565, 546, 654, 301, 301, 623, 639, 568, 565, 282],
[539, 317, 639, 539, 651, 539, 538, 640, 615, 615],
[637, 556, 637, 582, 640, 515, 515, 632, 254, 613],
[305, 643, 500, 550, 522, 500, 550, 561, 522, 305],
[954, 456, 584, 755, 505, 782, 661, 671, 497, 505],
[577, 464, 637, 647, 552, 552, 624, 647, 624, 647],
[728, 748, 931, 608, 538, 1015, 294, 294, 666, 538],
[602, 535, 666, 665, 655, 979, 574, 535, 571, 781],
[321, 620, 557, 566, 511, 910, 672, 623, 853, 674],
[621, 556, 947, 474, 610, 752, 1002, 597, 474, 474],
[605, 948, 657, 588, 485, 633, 459, 968, 939, 325],
],
]
),
},
"facebook/encodec_48khz": {
"3.0": torch.tensor([[[214, 214, 214, 214, 214, 118, 214, 214, 214, 214],
[989, 989, 611, 77, 77, 989, 976, 976, 976, 77]]]),
"6.0": torch.tensor([[[ 214, 214, 214, 214, 214, 118, 214, 214, 214, 214],
[ 989, 989, 611, 77, 77, 989, 976, 976, 976, 77],
[ 977, 1009, 538, 925, 925, 977, 1022, 1022, 1022, 925],
[ 376, 1012, 1023, 725, 725, 1023, 376, 962, 376, 847]]]),
"12.0": torch.tensor([[[ 214, 214, 214, 214, 214, 118, 214, 214, 214, 214],
[ 989, 989, 611, 77, 77, 989, 976, 976, 976, 77],
[ 977, 1009, 538, 925, 925, 977, 1022, 1022, 1022, 925],
[ 376, 1012, 1023, 725, 725, 1023, 376, 962, 376, 847],
[ 979, 1012, 323, 695, 1018, 1023, 979, 1023, 979, 650],
[ 945, 762, 528, 865, 824, 945, 945, 945, 957, 957],
[ 904, 973, 1014, 681, 582, 1014, 1014, 1014, 1014, 681],
[ 229, 392, 796, 392, 977, 1017, 250, 1017, 250, 1017]]]),
"24.0": torch.tensor([[[ 214, 214, 214, 214, 214, 118, 214, 214, 214, 214],
[ 989, 989, 611, 77, 77, 989, 976, 976, 976, 77],
[ 977, 1009, 538, 925, 925, 977, 1022, 1022, 1022, 925],
[ 376, 1012, 1023, 725, 725, 1023, 376, 962, 376, 847],
[ 979, 1012, 323, 695, 1018, 1023, 979, 1023, 979, 650],
[ 945, 762, 528, 865, 824, 945, 945, 945, 957, 957],
[ 904, 973, 1014, 681, 582, 1014, 1014, 1014, 1014, 681],
[ 229, 392, 796, 392, 977, 1017, 250, 1017, 250, 1017],
[ 902, 436, 935, 1011, 1023, 1023, 1023, 154, 1023, 392],
[ 982, 878, 961, 832, 629, 431, 919, 629, 919, 792],
[ 727, 727, 401, 727, 979, 587, 727, 487, 413, 201],
[ 928, 924, 965, 934, 840, 480, 924, 920, 924, 486],
[ 10, 625, 712, 552, 712, 259, 394, 131, 726, 516],
[ 882, 1022, 32, 524, 267, 861, 974, 882, 108, 521],
[ 304, 841, 306, 415, 69, 376, 928, 510, 381, 104],
[ 0, 0, 0, 484, 83, 0, 307, 262, 0, 0]]])
}
}
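# Expected scale factors returned by `model.encode`. The 24 kHz checkpoint does not
# normalize its input, so the scales are None; the 48 kHz checkpoint normalizes each
# chunk and returns one scale per encoded frame.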
EXPECTED_ENCODER_SCALES = {
"facebook/encodec_24khz": {
"1.5": None,
"3.0": None,
"6.0": None,
"12.0": None,
"24.0": None
},
"facebook/encodec_48khz": {
"3.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
4.483359e-02, 1.000000e-08]),
"6.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
4.483359e-02, 1.000000e-08]),
"12.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
4.483359e-02, 1.000000e-08]),
"24.0": torch.tensor([5.365404e-02, 8.153407e-02, 6.266369e-02, 6.688326e-02, 5.458422e-02,
4.483359e-02, 1.000000e-08])
}
}
EXPECTED_DECODER_OUTPUTS = {
"facebook/encodec_24khz": {
"1.5": torch.tensor(
[[ 0.0003, -0.0002, -0.0000, -0.0004, 0.0004, 0.0003, -0.0000, 0.0001, 0.0005, 0.0001, -0.0015, -0.0007, -0.0002, -0.0018, -0.0003, 0.0013, 0.0011, 0.0008, 0.0008, 0.0008, 0.0008, 0.0002, -0.0003, -0.0004, -0.0006, -0.0009, -0.0010, -0.0012, -0.0011, -0.0006, -0.0006, -0.0005, 0.0000, 0.0001, 0.0003, 0.0002, -0.0001, -0.0002, -0.0008, -0.0012, -0.0011, -0.0012, -0.0013, -0.0003, 0.0002, 0.0006, 0.0006, 0.0006, 0.0009, 0.0010]]
),
"3.0": torch.tensor(
[[ 0.0003, -0.0002, -0.0000, -0.0004, 0.0004, 0.0003, -0.0000, 0.0001, 0.0006, 0.0002, -0.0015, -0.0008, -0.0002, -0.0018, -0.0003, 0.0013, 0.0011, 0.0008, 0.0008, 0.0008, 0.0008, 0.0002, -0.0003, -0.0004, -0.0005, -0.0008, -0.0010, -0.0012, -0.0011, -0.0006, -0.0006, -0.0005, -0.0000, 0.0001, 0.0003, 0.0002, -0.0001, -0.0002, -0.0008, -0.0013, -0.0011, -0.0013, -0.0014, -0.0003, 0.0002, 0.0006, 0.0006, 0.0006, 0.0009, 0.0010]]
),
"6.0": torch.tensor(
[[ 0.0004, -0.0001, 0.0001, -0.0003, 0.0004, 0.0003, 0.0000, 0.0001, 0.0007, 0.0002, -0.0013, -0.0007, -0.0002, -0.0015, -0.0001, 0.0014, 0.0014, 0.0011, 0.0010, 0.0010, 0.0009, 0.0004, 0.0000, 0.0000, 0.0000, -0.0000, -0.0001, -0.0004, -0.0004, -0.0001, -0.0002, -0.0002, 0.0002, 0.0005, 0.0009, 0.0010, 0.0008, 0.0007, 0.0002, -0.0003, -0.0004, -0.0008, -0.0008, 0.0000, 0.0006, 0.0010, 0.0012, 0.0012, 0.0013, 0.0014]]
),
"12.0": torch.tensor(
[[ 0.0004, -0.0001, 0.0001, -0.0004, 0.0003, 0.0002, -0.0000, 0.0001, 0.0006, 0.0002, -0.0013, -0.0006, -0.0001, -0.0014, 0.0001, 0.0018, 0.0018, 0.0014, 0.0012, 0.0013, 0.0011, 0.0006, 0.0000, 0.0000, -0.0000, -0.0001, -0.0001, -0.0004, -0.0004, -0.0000, -0.0000, -0.0000, 0.0005, 0.0007, 0.0011, 0.0011, 0.0009, 0.0007, 0.0002, -0.0003, -0.0004, -0.0007, -0.0007, 0.0002, 0.0009, 0.0013, 0.0015, 0.0014, 0.0015, 0.0016]]
),
"24.0": torch.tensor(
[[ 0.0005, 0.0001, 0.0004, -0.0001, 0.0003, 0.0002, 0.0000, 0.0001, 0.0007, 0.0005, -0.0011, -0.0005, -0.0001, -0.0018, -0.0000, 0.0021, 0.0019, 0.0013, 0.0011, 0.0012, 0.0012, 0.0006, -0.0000, -0.0001, -0.0000, -0.0000, -0.0001, -0.0004, -0.0004, -0.0000, -0.0001, -0.0002, 0.0003, 0.0004, 0.0008, 0.0007, 0.0006, 0.0007, 0.0001, -0.0004, -0.0003, -0.0006, -0.0008, 0.0004, 0.0011, 0.0015, 0.0016, 0.0015, 0.0016, 0.0018]]
)
},
"facebook/encodec_48khz": {
"3.0": torch.tensor(
[
[0.0034, 0.0028, 0.0037, 0.0041, 0.0029, 0.0022, 0.0021, 0.0020, 0.0021, 0.0023, 0.0021, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0021, 0.0023, 0.0025, 0.0022, 0.0017, 0.0015, 0.0017, 0.0020, 0.0024, 0.0031, 0.0039, 0.0045, 0.0046, 0.0042, 0.0034, 0.0027, 0.0023, 0.0022, 0.0023, 0.0024, 0.0022, 0.0023, 0.0024, 0.0027, 0.0027, 0.0027, 0.0025, 0.0024, 0.0024, 0.0026, 0.0028, 0.0027, 0.0024, 0.0022],
[ -0.0031, -0.0027, -0.0018, -0.0017, -0.0024, -0.0029, -0.0030, -0.0026, -0.0021, -0.0018, -0.0018, -0.0019, -0.0017, -0.0014, -0.0012, -0.0010, -0.0008, -0.0004, -0.0001, -0.0004, -0.0012, -0.0015, -0.0014, -0.0013, -0.0011, -0.0005, 0.0002, 0.0007, 0.0008, 0.0004, -0.0003, -0.0010, -0.0012, -0.0011, -0.0009, -0.0009, -0.0009, -0.0008, -0.0006, -0.0005, -0.0005, -0.0005, -0.0006, -0.0008, -0.0008, -0.0006, -0.0005, -0.0007, -0.0010, -0.0012],
]
),
"6.0": torch.tensor(
[
[0.0052, 0.0049, 0.0057, 0.0058, 0.0048, 0.0043, 0.0042, 0.0041, 0.0041, 0.0042, 0.0040, 0.0038, 0.0038, 0.0038, 0.0037, 0.0037, 0.0037, 0.0037, 0.0038, 0.0037, 0.0035, 0.0034, 0.0036, 0.0039, 0.0043, 0.0047, 0.0053, 0.0057, 0.0057, 0.0055, 0.0050, 0.0046, 0.0043, 0.0041, 0.0042, 0.0042, 0.0041, 0.0041, 0.0042, 0.0043, 0.0043, 0.0043, 0.0041, 0.0040, 0.0040, 0.0041, 0.0042, 0.0042, 0.0040, 0.0039],
[ 0.0001, 0.0006, 0.0013, 0.0011, 0.0005, 0.0001, -0.0001, 0.0001, 0.0003, 0.0005, 0.0005, 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.0010, 0.0013, 0.0015, 0.0014, 0.0010, 0.0008, 0.0010, 0.0012, 0.0015, 0.0019, 0.0023, 0.0026, 0.0026, 0.0024, 0.0020, 0.0016, 0.0013, 0.0013, 0.0014, 0.0015, 0.0015, 0.0016, 0.0017, 0.0017, 0.0017, 0.0016, 0.0015, 0.0013, 0.0013, 0.0013, 0.0013, 0.0012, 0.0010, 0.0009],
]
),
"12.0": torch.tensor(
[
[0.0014, 0.0012, 0.0021, 0.0024, 0.0017, 0.0013, 0.0012, 0.0011, 0.0011, 0.0012, 0.0011, 0.0010, 0.0009, 0.0009, 0.0008, 0.0008, 0.0009, 0.0010, 0.0012, 0.0012, 0.0009, 0.0008, 0.0010, 0.0013, 0.0017, 0.0024, 0.0031, 0.0036, 0.0036, 0.0033, 0.0028, 0.0023, 0.0020, 0.0020, 0.0022, 0.0022, 0.0022, 0.0022, 0.0023, 0.0024, 0.0024, 0.0023, 0.0021, 0.0021, 0.0021, 0.0023, 0.0024, 0.0024, 0.0022, 0.0021],
[ -0.0034, -0.0029, -0.0020, -0.0020, -0.0024, -0.0027, -0.0030, -0.0030, -0.0028, -0.0025, -0.0025, -0.0025, -0.0025, -0.0025, -0.0023, -0.0022, -0.0020, -0.0017, -0.0013, -0.0014, -0.0017, -0.0019, -0.0018, -0.0015, -0.0011, -0.0006, 0.0000, 0.0005, 0.0005, 0.0002, -0.0003, -0.0008, -0.0010, -0.0009, -0.0007, -0.0006, -0.0006, -0.0005, -0.0005, -0.0005, -0.0005, -0.0007, -0.0008, -0.0009, -0.0009, -0.0008, -0.0007, -0.0008, -0.0010, -0.0011],
]
),
"24.0": torch.tensor(
[
[ 0.0010, 0.0008, 0.0018, 0.0021, 0.0014, 0.0011, 0.0009, 0.0007, 0.0006, 0.0006, 0.0005, 0.0003, 0.0003, 0.0002, 0.0001, 0.0001, 0.0001, 0.0002, 0.0002, 0.0001, -0.0002, -0.0004, -0.0003, 0.0000, 0.0005, 0.0011, 0.0018, 0.0022, 0.0022, 0.0018, 0.0012, 0.0007, 0.0004, 0.0003, 0.0004, 0.0006, 0.0006, 0.0007, 0.0007, 0.0009, 0.0008, 0.0007, 0.0005, 0.0004, 0.0004, 0.0006, 0.0007, 0.0007, 0.0005, 0.0004],
[-0.0039, -0.0035, -0.0027, -0.0026, -0.0028, -0.0031, -0.0035, -0.0035, -0.0034, -0.0033, -0.0032, -0.0032, -0.0031, -0.0031, -0.0029, -0.0028, -0.0026, -0.0024, -0.0021, -0.0021, -0.0024, -0.0025, -0.0024, -0.0021, -0.0017, -0.0011, -0.0006, -0.0002, -0.0002, -0.0004, -0.0009, -0.0013, -0.0015, -0.0015, -0.0014, -0.0013, -0.0012, -0.0011, -0.0010, -0.0010, -0.0011, -0.0012, -0.0014, -0.0015, -0.0015, -0.0014, -0.0013, -0.0014, -0.0016, -0.0017],
]
)
}
}
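# Codec error = RMSE between the L2-normalized decoded audio and the L2-normalized
# input (see `compute_rmse` above), per model and bandwidth.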
EXPECTED_CODEC_ERROR = {
"facebook/encodec_24khz": {
"1.5": 0.0022229827009141445,
"3.0": 0.001862662611529231,
"6.0": 0.0015231302240863442,
"12.0": 0.0013,
"24.0": 0.0012,
},
"facebook/encodec_48khz": {
"3.0": 0.000840399123262614,
"6.0": 0.0006692984024994075,
"12.0": 0.0005328940460458398,
"24.0": 0.0004473362350836396,
}
}
# -- test_batch
EXPECTED_ENCODER_CODES_BATCH = {
"facebook/encodec_24khz": {
"1.5": torch.tensor(
[
[
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475, 475, 404, 475, 475, 475, 835, 475, 475, 835, 835,
106, 106, 738, 106, 738, 106, 408, 408, 738, 408, 408, 408, 738, 408, 408, 408, 408, 738, 408,
1017, 604, 64, 303, 394, 5, 570, 991, 570, 969, 814],
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007, 729, 1007, 961, 1007, 1007, 961, 969, 1007,
1007, 424, 518, 1007, 544, 1007, 518, 913, 424, 424, 544, 424, 518, 518, 518, 302, 424, 424, 424,
544, 424, 114, 200, 787, 931, 343, 434, 315, 487, 872, 769, 463],
],
[
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 835, 408, 835, 738, 408, 408, 408, 408, 408,
408, 738, 408, 408, 408, 408, 408, 408, 408, 408, 738, 408, 408, 408, 408, 408, 408, 408, 408, 408,
339, 834, 819, 875, 957, 670, 811, 670, 237, 53],
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913, 518, 913, 518, 518, 544, 424, 424, 518, 424, 424,
424, 544, 424, 424, 424, 518, 424, 518, 518, 937, 544, 424, 518, 302, 518, 424, 424, 518, 424, 424,
913, 857, 841, 363, 463, 78, 176, 645, 255, 571],
],
]
),
"3.0": torch.tensor(
[
[
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
],
[
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
],
]
),
"6.0": torch.tensor(
[
[
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
[712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
[939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
[853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
[899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
],
[
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
[373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
[293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
[800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
[916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
],
]
),
"12.0": torch.tensor(
[
[
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
[712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
[939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
[853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
[899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
[817, 470, 588, 675, 675, 588, 960, 927, 909, 466],
[953, 776, 717, 630, 359, 717, 861, 630, 861, 359],
[623, 740, 1000, 388, 420, 388, 740, 818, 958, 743],
[413, 835, 742, 249, 892, 352, 190, 498, 866, 890],
[817, 351, 804, 751, 938, 535, 434, 879, 351, 971],
[792, 495, 935, 848, 792, 795, 942, 935, 723, 531],
[622, 681, 477, 713, 752, 871, 713, 514, 993, 777],
[928, 799, 962, 1005, 860, 439, 312, 922, 982, 922],
],
[
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
[373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
[293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
[800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
[916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
[420, 628, 918, 628, 628, 628, 248, 628, 909, 811],
[736, 717, 994, 974, 477, 874, 963, 979, 355, 979],
[1002, 1002, 894, 875, 388, 709, 534, 408, 881, 709],
[735, 828, 763, 742, 640, 835, 828, 375, 840, 375],
[898, 938, 556, 658, 410, 951, 486, 658, 877, 877],
[ 0, 797, 428, 694, 428, 920, 1022, 1022, 809, 797],
[622, 421, 422, 776, 911, 911, 958, 421, 776, 421],
[1005, 312, 922, 755, 834, 461, 461, 702, 597, 974],
],
]
),
"24.0": torch.tensor(
[
[
[62, 106, 475, 475, 404, 404, 475, 404, 404, 475],
[424, 969, 913, 1007, 544, 1007, 1007, 1007, 969, 1007],
[212, 832, 212, 36, 36, 36, 767, 653, 982, 1016],
[956, 741, 838, 1019, 739, 780, 838, 1019, 1014, 1019],
[712, 862, 712, 448, 528, 646, 446, 373, 694, 373],
[939, 881, 939, 19, 334, 881, 1005, 763, 632, 781],
[853, 464, 772, 782, 782, 983, 890, 874, 983, 782],
[899, 475, 173, 701, 701, 947, 468, 1019, 882, 518],
[817, 470, 588, 675, 675, 588, 960, 927, 909, 466],
[953, 776, 717, 630, 359, 717, 861, 630, 861, 359],
[623, 740, 1000, 388, 420, 388, 740, 818, 958, 743],
[413, 835, 742, 249, 892, 352, 190, 498, 866, 890],
[817, 351, 804, 751, 938, 535, 434, 879, 351, 971],
[792, 495, 935, 848, 792, 795, 942, 935, 723, 531],
[622, 681, 477, 713, 752, 871, 713, 514, 993, 777],
[928, 799, 962, 1005, 860, 439, 312, 922, 982, 922],
[939, 637, 861, 506, 861, 61, 475, 264, 1019, 260],
[166, 215, 69, 69, 890, 69, 284, 828, 396, 180],
[561, 896, 841, 144, 580, 659, 886, 514, 686, 451],
[691, 691, 239, 735, 62, 287, 383, 972, 550, 505],
[451, 811, 238, 251, 250, 841, 734, 329, 551, 846],
[313, 601, 494, 763, 811, 565, 748, 441, 601, 480],
[653, 242, 630, 572, 701, 973, 632, 374, 561, 521],
[984, 987, 419, 454, 386, 507, 532, 636, 515, 671],
[647, 550, 515, 292, 876, 1011, 719, 549, 691, 911],
[683, 536, 656, 603, 698, 867, 987, 857, 886, 491],
[444, 937, 826, 555, 585, 710, 466, 852, 655, 591],
[658, 952, 903, 508, 739, 596, 420, 721, 464, 306],
[665, 334, 765, 532, 618, 278, 836, 838, 517, 597],
[613, 674, 596, 904, 987, 977, 938, 615, 672, 776],
[689, 386, 749, 658, 250, 869, 957, 806, 750, 659],
[652, 509, 910, 826, 566, 622, 951, 696, 900, 895],
],
[
[835, 835, 835, 835, 835, 835, 835, 835, 835, 835],
[857, 857, 544, 518, 937, 518, 913, 913, 518, 913],
[705, 989, 934, 989, 678, 934, 934, 786, 934, 786],
[366, 1018, 398, 398, 398, 398, 673, 741, 398, 741],
[373, 373, 375, 373, 373, 222, 862, 373, 190, 373],
[293, 949, 435, 435, 435, 293, 949, 881, 632, 986],
[800, 528, 528, 853, 782, 485, 772, 900, 528, 853],
[916, 237, 828, 701, 518, 835, 948, 315, 948, 315],
[420, 628, 918, 628, 628, 628, 248, 628, 909, 811],
[736, 717, 994, 974, 477, 874, 963, 979, 355, 979],
[1002, 1002, 894, 875, 388, 709, 534, 408, 881, 709],
[735, 828, 763, 742, 640, 835, 828, 375, 840, 375],
[898, 938, 556, 658, 410, 951, 486, 658, 877, 877],
[ 0, 797, 428, 694, 428, 920, 1022, 1022, 809, 797],
[622, 421, 422, 776, 911, 911, 958, 421, 776, 421],
[1005, 312, 922, 755, 834, 461, 461, 702, 597, 974],
[248, 248, 637, 248, 977, 506, 546, 270, 670, 506],
[547, 447, 15, 134, 1009, 215, 134, 396, 260, 160],
[635, 497, 686, 765, 264, 497, 244, 675, 624, 656],
[864, 571, 616, 511, 588, 781, 525, 258, 674, 503],
[449, 757, 857, 451, 658, 486, 299, 299, 251, 596],
[809, 628, 255, 568, 623, 301, 639, 546, 617, 623],
[551, 497, 908, 539, 661, 710, 640, 539, 646, 315],
[689, 507, 875, 515, 613, 637, 527, 515, 662, 637],
[983, 686, 456, 768, 601, 561, 768, 653, 500, 688],
[493, 566, 664, 782, 683, 683, 721, 603, 323, 497],
[1015, 552, 411, 423, 607, 646, 687, 1018, 689, 607],
[516, 293, 471, 294, 293, 294, 608, 538, 803, 717],
[974, 994, 952, 637, 637, 927, 535, 571, 602, 535],
[776, 789, 476, 944, 652, 959, 589, 679, 321, 623],
[776, 931, 720, 1009, 676, 731, 386, 676, 701, 676],
[684, 543, 716, 392, 661, 517, 792, 588, 922, 676],
],
]
)
},
"facebook/encodec_48khz": {
"3.0": torch.tensor([[[790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
[989, 989, 77, 546, 989, 546, 989, 160, 546, 989]],
[[214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
[289, 289, 989, 764, 289, 289, 882, 882, 882, 882]]]),
"6.0": torch.tensor([[[ 790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
[ 989, 989, 77, 546, 989, 546, 989, 160, 546, 989],
[ 977, 977, 977, 977, 538, 977, 977, 960, 977, 977],
[ 376, 376, 962, 962, 607, 962, 963, 896, 962, 376]],
[[ 214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
[ 289, 289, 989, 764, 289, 289, 882, 882, 882, 882],
[1022, 1022, 471, 925, 821, 821, 267, 925, 925, 267],
[ 979, 992, 914, 921, 0, 0, 1023, 963, 963, 1023]]]),
"12.0": torch.tensor([[[ 790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
[ 989, 989, 77, 546, 989, 546, 989, 160, 546, 989],
[ 977, 977, 977, 977, 538, 977, 977, 960, 977, 977],
[ 376, 376, 962, 962, 607, 962, 963, 896, 962, 376],
[ 979, 979, 979, 1012, 979, 1012, 921, 0, 1002, 695],
[ 824, 1018, 762, 957, 824, 762, 762, 1007, 957, 336],
[ 681, 973, 973, 452, 211, 681, 802, 679, 547, 884],
[ 950, 1017, 1016, 1017, 986, 1017, 229, 607, 1017, 689]],
[[ 214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
[ 289, 289, 989, 764, 289, 289, 882, 882, 882, 882],
[1022, 1022, 471, 925, 821, 821, 267, 925, 925, 267],
[ 979, 992, 914, 921, 0, 0, 1023, 963, 963, 1023],
[ 403, 940, 976, 1018, 677, 1002, 979, 677, 677, 677],
[1018, 794, 762, 444, 485, 485, 974, 548, 548, 1018],
[ 679, 243, 679, 1005, 1005, 973, 1014, 1005, 1005, 1014],
[ 810, 13, 1017, 537, 522, 702, 202, 1017, 1017, 15]]]),
"24.0": torch.tensor(
[
[
[790, 790, 790, 214, 214, 214, 799, 214, 214, 214],
[989, 989, 77, 546, 989, 546, 989, 160, 546, 989],
[977, 977, 977, 977, 538, 977, 977, 960, 977, 977],
[376, 376, 962, 962, 607, 962, 963, 896, 962, 376],
[979, 979, 979, 1012, 979, 1012, 921, 0, 1002, 695],
[824, 1018, 762, 957, 824, 762, 762, 1007, 957, 336],
[681, 973, 973, 452, 211, 681, 802, 679, 547, 884],
[950, 1017, 1016, 1017, 986, 1017, 229, 607, 1017, 689],
[1004, 1011, 669, 1023, 1023, 1023, 905, 297, 810, 970],
[982, 681, 982, 629, 662, 919, 878, 476, 629, 982],
[727, 727, 959, 959, 979, 959, 530, 959, 337, 961],
[924, 456, 924, 486, 924, 959, 102, 924, 805, 924],
[649, 542, 993, 993, 949, 787, 56, 886, 949, 405],
[864, 1022, 1022, 1022, 460, 753, 805, 309, 1022, 32],
[953, 0, 0, 180, 352, 10, 581, 516, 322, 452],
[300, 0, 1020, 307, 0, 543, 924, 627, 258, 262],
],
[
[214, 214, 214, 214, 214, 214, 214, 214, 214, 214],
[289, 289, 989, 764, 289, 289, 882, 882, 882, 882],
[1022, 1022, 471, 925, 821, 821, 267, 925, 925, 267],
[979, 992, 914, 921, 0, 0, 1023, 963, 963, 1023],
[403, 940, 976, 1018, 677, 1002, 979, 677, 677, 677],
[1018, 794, 762, 444, 485, 485, 974, 548, 548, 1018],
[679, 243, 679, 1005, 1005, 973, 1014, 1005, 1005, 1014],
[810, 13, 1017, 537, 522, 702, 202, 1017, 1017, 15],
[728, 252, 970, 984, 971, 950, 673, 902, 1011, 810],
[332, 1014, 476, 854, 1014, 861, 332, 411, 411, 408],
[959, 727, 611, 979, 611, 727, 999, 497, 821, 0],
[995, 698, 924, 688, 102, 510, 924, 970, 344, 961],
[ 81, 516, 847, 924, 10, 240, 1005, 726, 993, 378],
[467, 496, 484, 496, 456, 1022, 337, 600, 456, 1022],
[789, 65, 937, 976, 159, 953, 343, 764, 179, 159],
[ 10, 790, 483, 10, 1020, 352, 848, 333, 83, 848],
],
]
)
}
}
EXPECTED_ENCODER_SCALES_BATCH = {
"facebook/encodec_24khz": {
"1.5": None,
"3.0": None,
"6.0": None,
"12.0": None,
"24.0": None
},
"facebook/encodec_48khz": {
"3.0": torch.tensor([[[1.027247e-01],
[7.877284e-02]],
[[1.014922e-01],
[8.696266e-02]],
[[6.308002e-02],
[7.748771e-02]],
[[6.899278e-02],
[1.045912e-01]],
[[6.440169e-02],
[8.843135e-02]],
[[4.139878e-02],
[1.000000e-08]],
[[5.848629e-02],
[1.000000e-08]],
[[2.329416e-04],
[1.000000e-08]],
[[1.000000e-08],
[1.000000e-08]]]),
"6.0": torch.tensor([[[1.027247e-01],
[7.877284e-02]],
[[1.014922e-01],
[8.696266e-02]],
[[6.308002e-02],
[7.748771e-02]],
[[6.899278e-02],
[1.045912e-01]],
[[6.440169e-02],
[8.843135e-02]],
[[4.139878e-02],
[1.000000e-08]],
[[5.848629e-02],
[1.000000e-08]],
[[2.329416e-04],
[1.000000e-08]],
[[1.000000e-08],
[1.000000e-08]]]),
"12.0": torch.tensor([[[1.027247e-01],
[7.877284e-02]],
[[1.014922e-01],
[8.696266e-02]],
[[6.308002e-02],
[7.748771e-02]],
[[6.899278e-02],
[1.045912e-01]],
[[6.440169e-02],
[8.843135e-02]],
[[4.139878e-02],
[1.000000e-08]],
[[5.848629e-02],
[1.000000e-08]],
[[2.329416e-04],
[1.000000e-08]],
[[1.000000e-08],
[1.000000e-08]]]),
"24.0": torch.tensor([[[1.027247e-01],
[7.877284e-02]],
[[1.014922e-01],
[8.696266e-02]],
[[6.308002e-02],
[7.748771e-02]],
[[6.899278e-02],
[1.045912e-01]],
[[6.440169e-02],
[8.843135e-02]],
[[4.139878e-02],
[1.000000e-08]],
[[5.848629e-02],
[1.000000e-08]],
[[2.329416e-04],
[1.000000e-08]],
[[1.000000e-08],
[1.000000e-08]]])
}
}
EXPECTED_DECODER_OUTPUTS_BATCH = {
"facebook/encodec_24khz": {
"1.5": torch.tensor(
[
[[ 0.0010, 0.0004, 0.0005, 0.0002, 0.0005, -0.0001, -0.0003, -0.0001, 0.0003, 0.0001, -0.0014, -0.0009, -0.0007, -0.0023, -0.0009, 0.0008, 0.0007, 0.0003, 0.0001, 0.0001, 0.0003, -0.0001, -0.0003, -0.0004, -0.0005, -0.0007, -0.0009, -0.0011, -0.0010, -0.0006, -0.0007, -0.0007, -0.0005, -0.0005, -0.0003, -0.0002, -0.0002, -0.0001, -0.0005, -0.0008, -0.0005, -0.0007, -0.0009, -0.0002, 0.0003, 0.0005, 0.0004, 0.0001, 0.0003, 0.0004]],
[[ -0.0001, -0.0000, 0.0003, 0.0001, 0.0005, 0.0001, -0.0006, -0.0002, 0.0002, 0.0002, -0.0031, -0.0004, 0.0006, -0.0066, -0.0032, 0.0044, 0.0025, -0.0019, -0.0017, 0.0001, 0.0019, -0.0010, -0.0014, -0.0009, -0.0007, -0.0009, -0.0019, -0.0024, -0.0019, -0.0001, -0.0017, -0.0022, -0.0004, 0.0005, -0.0014, -0.0023, 0.0002, 0.0015, -0.0022, -0.0033, 0.0024, 0.0009, -0.0041, 0.0000, 0.0030, 0.0020, -0.0015, -0.0018, 0.0014, 0.0007]],
]
),
"3.0": torch.tensor(
[
[[ 0.0013, 0.0007, 0.0009, 0.0005, 0.0006, 0.0002, -0.0001, 0.0000, 0.0005, 0.0003, -0.0012, -0.0006, -0.0003, -0.0019, -0.0003, 0.0015, 0.0013, 0.0009, 0.0008, 0.0007, 0.0008, 0.0004, 0.0001, -0.0000, -0.0001, -0.0002, -0.0003, -0.0004, -0.0004, 0.0001, -0.0000, -0.0000, 0.0003, 0.0003, 0.0005, 0.0005, 0.0004, 0.0005, 0.0001, -0.0003, -0.0002, -0.0004, -0.0006, 0.0003, 0.0009, 0.0012, 0.0013, 0.0012, 0.0014, 0.0015]],
[[ 0.0000, -0.0003, 0.0005, 0.0004, 0.0011, 0.0013, 0.0002, 0.0005, 0.0002, 0.0006, -0.0025, -0.0005, 0.0004, -0.0069, -0.0027, 0.0038, 0.0013, -0.0015, -0.0005, 0.0003, 0.0014, -0.0006, -0.0002, -0.0010, -0.0008, -0.0001, -0.0006, -0.0012, -0.0016, 0.0010, 0.0001, -0.0010, -0.0002, 0.0013, -0.0002, -0.0017, 0.0005, 0.0019, -0.0019, -0.0035, 0.0022, -0.0001, -0.0040, 0.0012, 0.0015, 0.0012, 0.0001, -0.0010, 0.0005, 0.0004]],
]
),
"6.0": torch.tensor(
[
[[ 0.0010, 0.0005, 0.0007, 0.0001, 0.0003, -0.0000, -0.0002, -0.0001, 0.0003, 0.0001, -0.0014, -0.0007, -0.0004, -0.0019, -0.0004, 0.0013, 0.0012, 0.0008, 0.0007, 0.0007, 0.0008, 0.0003, 0.0001, 0.0001, -0.0000, -0.0001, -0.0001, -0.0002, -0.0001, 0.0002, 0.0002, 0.0001, 0.0005, 0.0005, 0.0008, 0.0008, 0.0007, 0.0008, 0.0004, 0.0001, 0.0002, -0.0001, -0.0002, 0.0006, 0.0012, 0.0015, 0.0016, 0.0014, 0.0016, 0.0017]],
[[ -0.0005, -0.0001, 0.0003, 0.0001, 0.0010, 0.0012, 0.0002, 0.0004, 0.0012, 0.0003, -0.0023, -0.0003, -0.0005, -0.0063, -0.0026, 0.0040, 0.0024, -0.0018, -0.0005, 0.0016, 0.0004, -0.0008, 0.0009, 0.0002, -0.0015, -0.0003, 0.0004, -0.0011, -0.0013, 0.0012, 0.0001, -0.0019, 0.0007, 0.0021, -0.0009, -0.0016, 0.0015, 0.0013, -0.0022, -0.0015, 0.0016, -0.0014, -0.0033, 0.0017, 0.0025, -0.0004, -0.0005, 0.0010, 0.0005, 0.0001]],
]
),
"12.0": torch.tensor(
[
[[ 0.0003, 0.0002, 0.0004, -0.0004, -0.0003, -0.0007, -0.0008, -0.0006, -0.0001, -0.0002, -0.0016, -0.0009, -0.0004, -0.0021, -0.0003, 0.0015, 0.0016, 0.0012, 0.0011, 0.0010, 0.0010, 0.0005, 0.0002, 0.0001, 0.0000, -0.0001, -0.0002, -0.0004, -0.0004, 0.0000, -0.0000, -0.0002, 0.0001, 0.0001, 0.0004, 0.0003, 0.0002, 0.0004, -0.0001, -0.0005, -0.0004, -0.0006, -0.0007, 0.0003, 0.0009, 0.0013, 0.0015, 0.0015, 0.0017, 0.0018]],
[[ -0.0008, -0.0003, 0.0003, -0.0001, 0.0008, 0.0013, 0.0004, 0.0008, 0.0015, 0.0006, -0.0021, -0.0001, -0.0003, -0.0062, -0.0022, 0.0043, 0.0028, -0.0013, -0.0002, 0.0017, 0.0010, -0.0001, 0.0008, 0.0001, -0.0010, 0.0003, 0.0008, -0.0006, -0.0007, 0.0012, 0.0003, -0.0013, 0.0007, 0.0019, -0.0002, -0.0013, 0.0011, 0.0016, -0.0016, -0.0017, 0.0014, -0.0006, -0.0029, 0.0011, 0.0028, 0.0006, -0.0004, 0.0005, 0.0008, 0.0003]],
]
),
"24.0": torch.tensor(
[
[[ 0.0009, 0.0004, 0.0007, 0.0002, 0.0004, -0.0001, -0.0003, -0.0002, 0.0002, 0.0001, -0.0015, -0.0009, -0.0006, -0.0024, -0.0005, 0.0016, 0.0014, 0.0010, 0.0009, 0.0008, 0.0008, 0.0004, 0.0001, 0.0000, -0.0001, -0.0002, -0.0003, -0.0006, -0.0006, -0.0003, -0.0005, -0.0006, -0.0003, -0.0004, -0.0001, -0.0002, -0.0003, -0.0001, -0.0006, -0.0011, -0.0008, -0.0010, -0.0012, -0.0000, 0.0007, 0.0011, 0.0012, 0.0011, 0.0013, 0.0014]],
[[ -0.0009, -0.0004, 0.0001, -0.0003, 0.0007, 0.0012, 0.0003, 0.0006, 0.0017, 0.0008, -0.0020, 0.0001, -0.0002, -0.0064, -0.0023, 0.0047, 0.0029, -0.0016, -0.0004, 0.0019, 0.0010, -0.0002, 0.0007, -0.0001, -0.0013, 0.0005, 0.0012, -0.0007, -0.0008, 0.0013, -0.0001, -0.0022, 0.0004, 0.0020, -0.0004, -0.0014, 0.0017, 0.0020, -0.0018, -0.0016, 0.0015, -0.0015, -0.0036, 0.0014, 0.0030, 0.0004, 0.0002, 0.0015, 0.0011, 0.0007]],
]
)
},
"facebook/encodec_48khz": {
"3.0": torch.tensor([[[ 0.005083, 0.004669, 0.005723, 0.005600, 0.004231, 0.003830,
0.003684, 0.003349, 0.003032, 0.003055, 0.002768, 0.002370,
0.002384, 0.002450, 0.002391, 0.002363, 0.002357, 0.002435,
0.002568, 0.002463, 0.002137, 0.002092, 0.002440, 0.002772,
0.003035, 0.003473, 0.003963, 0.004288, 0.004315, 0.004087,
0.003618, 0.003166, 0.002874, 0.002775, 0.002820, 0.002758,
0.002565, 0.002498, 0.002583, 0.002671, 0.002656, 0.002613,
0.002433, 0.002236, 0.002215, 0.002302, 0.002287, 0.002113,
0.001909, 0.001767],
[-0.003928, -0.002733, -0.001330, -0.001914, -0.002927, -0.003272,
-0.003677, -0.003615, -0.003341, -0.002907, -0.002764, -0.002742,
-0.002593, -0.002308, -0.002024, -0.001856, -0.001672, -0.001256,
-0.000929, -0.001217, -0.001864, -0.002118, -0.002025, -0.001932,
-0.001816, -0.001572, -0.001214, -0.000885, -0.000829, -0.000976,
-0.001417, -0.001874, -0.002030, -0.001952, -0.001858, -0.001863,
-0.001895, -0.001843, -0.001801, -0.001792, -0.001812, -0.001865,
-0.002008, -0.002120, -0.002132, -0.002093, -0.002170, -0.002370,
-0.002587, -0.002749]],
[[ 0.004229, 0.003422, 0.005044, 0.006059, 0.005242, 0.004623,
0.004231, 0.004050, 0.004314, 0.004701, 0.004559, 0.004105,
0.003874, 0.003713, 0.003355, 0.003055, 0.003235, 0.003927,
0.004500, 0.004195, 0.003328, 0.002804, 0.002628, 0.002456,
0.002693, 0.003883, 0.005604, 0.006791, 0.006702, 0.005427,
0.003622, 0.002328, 0.002173, 0.002871, 0.003505, 0.003410,
0.002851, 0.002511, 0.002534, 0.002685, 0.002714, 0.002538,
0.002110, 0.001697, 0.001786, 0.002415, 0.002940, 0.002856,
0.002348, 0.001883],
[-0.003444, -0.002916, -0.000590, 0.000157, -0.000702, -0.001472,
-0.002032, -0.001891, -0.001283, -0.000670, -0.000590, -0.000875,
-0.001090, -0.001095, -0.001172, -0.001287, -0.000907, 0.000111,
0.000858, 0.000471, -0.000532, -0.001127, -0.001463, -0.001853,
-0.001762, -0.000666, 0.000964, 0.002054, 0.001914, 0.000743,
-0.000876, -0.001990, -0.001951, -0.001042, -0.000229, -0.000171,
-0.000558, -0.000752, -0.000704, -0.000609, -0.000594, -0.000723,
-0.001085, -0.001455, -0.001374, -0.000795, -0.000350, -0.000480,
-0.000993, -0.001432]]]),
"6.0": torch.tensor([[[ 5.892794e-03, 5.767163e-03, 7.065284e-03, 7.068626e-03,
5.825328e-03, 5.601424e-03, 5.582351e-03, 5.209565e-03,
4.829186e-03, 4.809568e-03, 4.663883e-03, 4.402087e-03,
4.337528e-03, 4.311915e-03, 4.236566e-03, 4.209972e-03,
4.179818e-03, 4.196202e-03, 4.309553e-03, 4.267083e-03,
4.052189e-03, 4.068719e-03, 4.381632e-03, 4.692366e-03,
4.998885e-03, 5.466312e-03, 5.895300e-03, 6.115717e-03,
6.055626e-03, 5.773376e-03, 5.316667e-03, 4.826934e-03,
4.450697e-03, 4.315911e-03, 4.310716e-03, 4.202125e-03,
4.008702e-03, 3.957694e-03, 4.017603e-03, 4.060654e-03,
4.036821e-03, 3.923071e-03, 3.659022e-03, 3.427053e-03,
3.387271e-03, 3.462438e-03, 3.434755e-03, 3.247944e-03,
3.009581e-03, 2.800536e-03],
[-1.867314e-03, -6.082351e-04, 9.374358e-04, 5.555808e-04,
-3.020080e-04, -5.281629e-04, -9.364292e-04, -1.057594e-03,
-9.703087e-04, -6.292185e-04, -4.193477e-04, -3.605868e-04,
-2.948678e-04, -1.198237e-04, 4.924605e-05, 1.602105e-04,
3.162385e-04, 6.700790e-04, 9.868707e-04, 8.484383e-04,
4.327767e-04, 3.108105e-04, 4.244343e-04, 5.422112e-04,
7.239584e-04, 1.008546e-03, 1.265120e-03, 1.447669e-03,
1.436084e-03, 1.271058e-03, 8.684017e-04, 4.149990e-04,
2.143449e-04, 2.508474e-04, 3.018488e-04, 2.782424e-04,
2.369677e-04, 3.040710e-04, 3.242530e-04, 2.599912e-04,
2.211208e-04, 1.311762e-04, -9.807519e-05, -2.752687e-04,
-3.114068e-04, -2.832832e-04, -3.900219e-04, -6.142824e-04,
-8.507833e-04, -1.055882e-03]],
[[ 3.971702e-04, -2.164055e-04, 1.562327e-03, 2.695718e-03,
2.374928e-03, 2.145125e-03, 1.870762e-03, 1.852614e-03,
2.074345e-03, 2.312302e-03, 2.222824e-03, 1.876336e-03,
1.609606e-03, 1.420574e-03, 1.193270e-03, 9.592943e-04,
1.132237e-03, 1.776782e-03, 2.258269e-03, 1.945908e-03,
9.930646e-04, 1.733529e-04, -2.533881e-04, -3.138177e-04,
3.226010e-04, 1.859203e-03, 3.879325e-03, 5.267750e-03,
5.101699e-03, 3.609465e-03, 1.653315e-03, 2.709297e-04,
-3.190451e-05, 5.129501e-04, 1.224789e-03, 1.397457e-03,
1.110794e-03, 8.736057e-04, 8.860155e-04, 1.055910e-03,
1.100855e-03, 8.834896e-04, 3.825913e-04, -3.267327e-05,
6.586456e-05, 7.147206e-04, 1.394876e-03, 1.535393e-03,
1.192172e-03, 7.061819e-04],
[-6.897163e-03, -6.407891e-03, -4.015491e-03, -3.082125e-03,
-3.434983e-03, -3.885052e-03, -4.456392e-03, -4.296550e-03,
-3.861045e-03, -3.553474e-03, -3.547473e-03, -3.800863e-03,
-4.123025e-03, -4.237277e-03, -4.244958e-03, -4.263899e-03,
-3.808572e-03, -2.811858e-03, -2.147519e-03, -2.516703e-03,
-3.550721e-03, -4.353373e-03, -4.846224e-03, -4.960613e-03,
-4.273535e-03, -2.714785e-03, -7.043980e-04, 6.689885e-04,
5.069164e-04, -9.122533e-04, -2.816979e-03, -4.124952e-03,
-4.235019e-03, -3.491365e-03, -2.676077e-03, -2.381226e-03,
-2.492559e-03, -2.634424e-03, -2.632524e-03, -2.528266e-03,
-2.536691e-03, -2.746170e-03, -3.187869e-03, -3.553530e-03,
-3.462211e-03, -2.862707e-03, -2.273719e-03, -2.201617e-03,
-2.565818e-03, -3.044683e-03]]]),
"12.0": torch.tensor([[[ 2.237194e-03, 2.508208e-03, 3.986347e-03, 4.020395e-03,
2.889890e-03, 2.733388e-03, 2.684146e-03, 2.251372e-03,
1.787451e-03, 1.720550e-03, 1.689184e-03, 1.495478e-03,
1.321027e-03, 1.185375e-03, 1.098422e-03, 1.055453e-03,
9.591801e-04, 9.328910e-04, 1.026154e-03, 1.031992e-03,
9.155220e-04, 9.732856e-04, 1.282264e-03, 1.624059e-03,
1.920021e-03, 2.333685e-03, 2.730524e-03, 2.919153e-03,
2.856711e-03, 2.632692e-03, 2.256703e-03, 1.901129e-03,
1.684760e-03, 1.638201e-03, 1.644909e-03, 1.569378e-03,
1.448412e-03, 1.478291e-03, 1.580583e-03, 1.633777e-03,
1.597190e-03, 1.475462e-03, 1.242885e-03, 1.065243e-03,
1.052842e-03, 1.103825e-03, 1.059115e-03, 9.251673e-04,
7.235570e-04, 5.053390e-04],
[-4.534880e-03, -3.111026e-03, -1.486247e-03, -1.739966e-03,
-2.399862e-03, -2.583335e-03, -3.157276e-03, -3.517166e-03,
-3.598212e-03, -3.303007e-03, -3.037215e-03, -2.982930e-03,
-3.026671e-03, -2.958387e-03, -2.836909e-03, -2.775315e-03,
-2.719575e-03, -2.431532e-03, -2.090512e-03, -2.095603e-03,
-2.366266e-03, -2.404480e-03, -2.235661e-03, -2.063206e-03,
-1.888533e-03, -1.640449e-03, -1.407782e-03, -1.250053e-03,
-1.275359e-03, -1.373277e-03, -1.601508e-03, -1.838720e-03,
-1.876643e-03, -1.736149e-03, -1.622051e-03, -1.578928e-03,
-1.564748e-03, -1.455850e-03, -1.391748e-03, -1.418254e-03,
-1.462577e-03, -1.554713e-03, -1.730076e-03, -1.829485e-03,
-1.816249e-03, -1.772218e-03, -1.855736e-03, -2.013720e-03,
-2.196174e-03, -2.378810e-03]],
[[ 8.993230e-04, 6.808847e-04, 2.595528e-03, 3.586462e-03,
3.023965e-03, 2.479527e-03, 1.868662e-03, 1.565682e-03,
1.563900e-03, 1.666364e-03, 1.715061e-03, 1.609638e-03,
1.294764e-03, 8.647116e-04, 5.122397e-04, 2.899101e-04,
3.817413e-04, 8.303743e-04, 1.253686e-03, 1.179640e-03,
6.591807e-04, 1.167982e-04, -3.405492e-04, -5.258832e-04,
-4.165239e-05, 1.393227e-03, 3.473584e-03, 4.953051e-03,
4.779391e-03, 3.182305e-03, 1.140233e-03, -2.133392e-04,
-4.233644e-04, 2.426380e-04, 1.126914e-03, 1.557022e-03,
1.490265e-03, 1.264647e-03, 1.170405e-03, 1.237709e-03,
1.112253e-03, 6.990263e-04, 1.700171e-04, -1.761244e-04,
1.852706e-05, 8.140961e-04, 1.621285e-03, 1.813497e-03,
1.394625e-03, 7.860070e-04],
[-4.677887e-03, -3.966209e-03, -1.634288e-03, -8.592710e-04,
-1.395248e-03, -2.189968e-03, -3.198638e-03, -3.410639e-03,
-3.241918e-03, -3.051681e-03, -2.845973e-03, -2.786646e-03,
-3.078280e-03, -3.367662e-03, -3.450923e-03, -3.427895e-03,
-3.058358e-03, -2.258006e-03, -1.607386e-03, -1.647450e-03,
-2.164357e-03, -2.647080e-03, -3.110953e-03, -3.304542e-03,
-2.798792e-03, -1.407999e-03, 5.630683e-04, 1.961336e-03,
1.813856e-03, 3.529640e-04, -1.526076e-03, -2.695498e-03,
-2.702039e-03, -1.889018e-03, -9.337939e-04, -3.885011e-04,
-2.970786e-04, -4.415356e-04, -5.492531e-04, -5.430978e-04,
-7.051138e-04, -1.102020e-03, -1.577104e-03, -1.846151e-03,
-1.623901e-03, -8.853760e-04, -1.772702e-04, -4.866864e-05,
-4.633263e-04, -1.017192e-03]]]),
"24.0": torch.tensor(
[
[
[0.0004, 0.0008, 0.0024, 0.0024, 0.0013, 0.0013, 0.0013, 0.0009, 0.0005, 0.0005, 0.0006, 0.0005, 0.0005, 0.0003, 0.0003, 0.0003, 0.0002, 0.0002, 0.0003, 0.0004, 0.0003, 0.0004, 0.0008, 0.0012, 0.0015, 0.0018, 0.0021, 0.0022, 0.0021, 0.0019, 0.0016, 0.0014, 0.0012, 0.0011, 0.0012, 0.0012, 0.0012, 0.0012, 0.0013, 0.0014, 0.0014, 0.0013, 0.0011, 0.0009, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0007],
[ -0.0055, -0.0040, -0.0024, -0.0026, -0.0031, -0.0031, -0.0036, -0.0039, -0.0039, -0.0035, -0.0031, -0.0029, -0.0028, -0.0027, -0.0026, -0.0024, -0.0023, -0.0020, -0.0017, -0.0016, -0.0017, -0.0017, -0.0015, -0.0012, -0.0010, -0.0008, -0.0006, -0.0004, -0.0004, -0.0005, -0.0006, -0.0007, -0.0006, -0.0004, -0.0002, -0.0001, 0.0001, 0.0002, 0.0003, 0.0004, 0.0004, 0.0003, 0.0001, 0.0001, 0.0000, 0.0001, 0.0000, -0.0001, -0.0002, -0.0004],
],
[
[-0.0024, -0.0029, -0.0009, 0.0002, -0.0002, -0.0007, -0.0012, -0.0013, -0.0012, -0.0011, -0.0011, -0.0012, -0.0016, -0.0021, -0.0024, -0.0026, -0.0024, -0.0018, -0.0013, -0.0015, -0.0022, -0.0029, -0.0035, -0.0038, -0.0031, -0.0015, 0.0008, 0.0025, 0.0023, 0.0006, -0.0016, -0.0030, -0.0032, -0.0024, -0.0015, -0.0010, -0.0009, -0.0011, -0.0010, -0.0009, -0.0010, -0.0014, -0.0020, -0.0023, -0.0020, -0.0011, -0.0001, 0.0001, -0.0003, -0.0009],
[-0.0086, -0.0081, -0.0059, -0.0050, -0.0053, -0.0061, -0.0071, -0.0071, -0.0069, -0.0067, -0.0066, -0.0066, -0.0070, -0.0073, -0.0074, -0.0073, -0.0069, -0.0060, -0.0053, -0.0055, -0.0061, -0.0067, -0.0072, -0.0074, -0.0067, -0.0052, -0.0031, -0.0015, -0.0016, -0.0029, -0.0048, -0.0059, -0.0059, -0.0051, -0.0041, -0.0036, -0.0034, -0.0034, -0.0034, -0.0033, -0.0035, -0.0039, -0.0043, -0.0046, -0.0043, -0.0035, -0.0027, -0.0025, -0.0029, -0.0034],
],
]
)
}
}
# ---- error over whole batch
EXPECTED_CODEC_ERROR_BATCH = {
"facebook/encodec_24khz": {
"1.5": 0.0011174238752573729,
"3.0": 0.0009308119188062847,
"6.0": 0.0008,
"12.0": 0.0006830253987573087,
"24.0": 0.000642190920189023,
},
"facebook/encodec_48khz": {
"3.0": 0.00039895583176985383,
"6.0": 0.0003249854489695281,
"12.0": 0.0002540576097089797,
"24.0": 0.00021899679268244654,
}
}
# fmt: on
@slow
@require_torch
class EncodecIntegrationTest(unittest.TestCase):
@parameterized.expand(
[
(f"{os.path.basename(model_id)}_{bandwidth.replace('.', 'p')}", model_id, bandwidth)
for model_id, v in EXPECTED_ENCODER_CODES.items()
for bandwidth in v
]
)
def test_integration(self, name, model_id, bandwidth):
# load model
model = EncodecModel.from_pretrained(model_id).to(torch_device)
processor = AutoProcessor.from_pretrained(model_id)
# load audio
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_array = librispeech_dummy[0]["audio"]["array"]
if model.config.audio_channels > 1:
audio_array = np.array([audio_array] * model.config.audio_channels)
inputs = processor(
raw_audio=audio_array,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
padding=True,
).to(torch_device)
model = model.eval()
with torch.no_grad():
# Compare encoder outputs with expected values
encoded_frames = model.encode(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
codes = torch.cat([encoded[0] for encoded in encoded_frames["audio_codes"]], dim=-1).unsqueeze(0)
torch.testing.assert_close(
codes[..., : EXPECTED_ENCODER_CODES[model_id][bandwidth].shape[-1]],
EXPECTED_ENCODER_CODES[model_id][bandwidth].to(torch_device),
rtol=1e-4,
atol=1e-4,
)
if EXPECTED_ENCODER_SCALES[model_id][bandwidth] is not None:
scales = torch.tensor([encoded[0].squeeze() for encoded in encoded_frames["audio_scales"]])
torch.testing.assert_close(scales, EXPECTED_ENCODER_SCALES[model_id][bandwidth], rtol=1e-4, atol=1e-4)
# Compare decoder outputs with expected values
decoded_frames = model.decode(
encoded_frames["audio_codes"],
encoded_frames["audio_scales"],
inputs["padding_mask"],
last_frame_pad_length=encoded_frames["last_frame_pad_length"],
)
torch.testing.assert_close(
decoded_frames["audio_values"][0][..., : EXPECTED_DECODER_OUTPUTS[model_id][bandwidth].shape[-1]],
EXPECTED_DECODER_OUTPUTS[model_id][bandwidth].to(torch_device),
rtol=1e-4,
atol=1e-4,
)
# Compare codec error with expected values
codec_error = compute_rmse(decoded_frames["audio_values"], inputs["input_values"])
torch.testing.assert_close(codec_error, EXPECTED_CODEC_ERROR[model_id][bandwidth], rtol=1e-4, atol=1e-4)
# make sure forward and encode-decode give the same result
full_enc = model(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
torch.testing.assert_close(
full_enc["audio_values"],
decoded_frames["audio_values"],
rtol=1e-4,
atol=1e-4,
)
@parameterized.expand(
[
(f"{os.path.basename(model_id)}_{bandwidth.replace('.', 'p')}", model_id, bandwidth)
for model_id, v in EXPECTED_ENCODER_CODES_BATCH.items()
for bandwidth in v
]
)
def test_batch(self, name, model_id, bandwidth):
# load model
model = EncodecModel.from_pretrained(model_id).to(torch_device)
processor = AutoProcessor.from_pretrained(model_id)
# load audio
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
n_channels = model.config.audio_channels
if n_channels == 1:
audio_samples = [audio_sample["array"] for audio_sample in librispeech_dummy[-2:]["audio"]]
else:
audio_samples = []
for _sample in librispeech_dummy[-2:]["audio"]:
# duplicate the mono channel to match the target number of channels
audio_array = np.concatenate([_sample["array"][np.newaxis]] * n_channels, axis=0)
audio_samples.append(audio_array)
inputs = processor(
raw_audio=audio_samples,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
padding=True,
).to(torch_device)
# apply model
model = model.eval()
with torch.no_grad():
# Compare encoder outputs with expected values
encoded_frames = model.encode(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
codes = encoded_frames["audio_codes"].permute(1, 2, 0, 3)
codes = codes.reshape(codes.size(0), codes.size(1), -1)
torch.testing.assert_close(
codes[..., : EXPECTED_ENCODER_CODES_BATCH[model_id][bandwidth].shape[-1]],
EXPECTED_ENCODER_CODES_BATCH[model_id][bandwidth].to(torch_device),
rtol=1e-4,
atol=1e-4,
)
if EXPECTED_ENCODER_SCALES_BATCH[model_id][bandwidth] is not None:
scales = torch.stack(encoded_frames["audio_scales"])
torch.testing.assert_close(
scales, EXPECTED_ENCODER_SCALES_BATCH[model_id][bandwidth].to(torch_device), rtol=1e-4, atol=1e-4
)
# Compare decoder outputs with expected values
decoded_frames = model.decode(
encoded_frames["audio_codes"],
encoded_frames["audio_scales"],
inputs["padding_mask"],
last_frame_pad_length=encoded_frames["last_frame_pad_length"],
)
torch.testing.assert_close(
decoded_frames["audio_values"][..., : EXPECTED_DECODER_OUTPUTS_BATCH[model_id][bandwidth].shape[-1]],
EXPECTED_DECODER_OUTPUTS_BATCH[model_id][bandwidth].to(torch_device),
rtol=1e-4,
atol=1e-4,
)
# Compare codec error with expected values
codec_error = compute_rmse(decoded_frames["audio_values"], inputs["input_values"])
torch.testing.assert_close(
codec_error, EXPECTED_CODEC_ERROR_BATCH[model_id][bandwidth], rtol=1e-4, atol=1e-4
)
# make sure forward and encode-decode give the same result
input_values_dec = model(inputs["input_values"], inputs["padding_mask"], bandwidth=float(bandwidth))
torch.testing.assert_close(
input_values_dec["audio_values"], decoded_frames["audio_values"], rtol=1e-4, atol=1e-4
)