This commit is contained in:
2025-10-09 16:47:16 +08:00
parent c8feb4deb5
commit e27e3f16bb
5248 changed files with 1778505 additions and 0 deletions

View File

@@ -0,0 +1,231 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Dia feature extractor."""
import itertools
import random
import unittest
import numpy as np
from transformers import DiaFeatureExtractor
from transformers.testing_utils import require_torch
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_torch_available():
import torch
global_rng = random.Random()
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
@require_torch
class DiaFeatureExtractionTester:
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.__init__
def __init__(
self,
parent,
batch_size=7,
min_seq_length=400,
max_seq_length=2000,
feature_size=1,
padding_value=0.0,
sampling_rate=16000,
hop_length=512,
):
self.parent = parent
self.batch_size = batch_size
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.hop_length = hop_length
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
self.feature_size = feature_size
self.padding_value = padding_value
self.sampling_rate = sampling_rate
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.prepare_feat_extract_dict
def prepare_feat_extract_dict(self):
return {
"feature_size": self.feature_size,
"padding_value": self.padding_value,
"sampling_rate": self.sampling_rate,
"hop_length": self.hop_length,
}
# Copied from tests.models.encodec.test_feature_extraction_encodec.EnCodecFeatureExtractionTester.prepare_inputs_for_common
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))
if equal_length:
audio_inputs = floats_list((self.batch_size, self.max_seq_length))
else:
# make sure that inputs increase in size
audio_inputs = [
_flatten(floats_list((x, self.feature_size)))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
]
if numpify:
audio_inputs = [np.asarray(x) for x in audio_inputs]
return audio_inputs
@require_torch
class DiaFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = DiaFeatureExtractor
def setUp(self):
self.feat_extract_tester = DiaFeatureExtractionTester(self)
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_call
def test_call(self):
# Tests that all call wrap to encode_plus and batch_encode_plus
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
# create three inputs of length 800, 1000, and 1200
audio_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
# Test not batched input
encoded_sequences_1 = feat_extract(audio_inputs[0], return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_audio_inputs[0], return_tensors="np").input_values
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
# Test batched
encoded_sequences_1 = feat_extract(audio_inputs, padding=True, return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_audio_inputs, padding=True, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_double_precision_pad
def test_double_precision_pad(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_audio_inputs = np.random.rand(100).astype(np.float64)
py_audio_inputs = np_audio_inputs.tolist()
for inputs in [py_audio_inputs, np_audio_inputs]:
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_values.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest._load_datasamples
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
audio_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in audio_samples]
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_integration with Dac->Dia
def test_integration(self):
# fmt: off
EXPECTED_INPUT_VALUES = torch.tensor(
[ 2.3803711e-03, 2.0751953e-03, 1.9836426e-03, 2.1057129e-03,
1.6174316e-03, 3.0517578e-04, 9.1552734e-05, 3.3569336e-04,
9.7656250e-04, 1.8310547e-03, 2.0141602e-03, 2.1057129e-03,
1.7395020e-03, 4.5776367e-04, -3.9672852e-04, 4.5776367e-04,
1.0070801e-03, 9.1552734e-05, 4.8828125e-04, 1.1596680e-03,
7.3242188e-04, 9.4604492e-04, 1.8005371e-03, 1.8310547e-03,
8.8500977e-04, 4.2724609e-04, 4.8828125e-04, 7.3242188e-04,
1.0986328e-03, 2.1057129e-03]
)
# fmt: on
input_audio = self._load_datasamples(1)
feature_extractor = DiaFeatureExtractor()
input_values = feature_extractor(input_audio, return_tensors="pt")["input_values"]
self.assertEqual(input_values.shape, (1, 1, 93696))
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
audio_input_end = torch.tensor(input_audio[0][-30:], dtype=torch.float32)
torch.testing.assert_close(input_values[0, 0, -46:-16], audio_input_end, rtol=1e-4, atol=1e-4)
def test_integration_stereo(self):
# fmt: off
EXPECTED_INPUT_VALUES = torch.tensor(
[2.3804e-03, 2.0752e-03, 1.9836e-03, 2.1057e-03, 1.6174e-03,
3.0518e-04, 9.1553e-05, 3.3569e-04, 9.7656e-04, 1.8311e-03,
2.0142e-03, 2.1057e-03, 1.7395e-03, 4.5776e-04, -3.9673e-04,
4.5776e-04, 1.0071e-03, 9.1553e-05, 4.8828e-04, 1.1597e-03,
7.3242e-04, 9.4604e-04, 1.8005e-03, 1.8311e-03, 8.8501e-04,
4.2725e-04, 4.8828e-04, 7.3242e-04, 1.0986e-03, 2.1057e-03]
)
# fmt: on
input_audio = self._load_datasamples(1)
input_audio = [np.tile(input_audio[0][None], reps=(2, 1))]
feature_extractor = DiaFeatureExtractor(feature_size=2)
input_values = feature_extractor(input_audio, return_tensors="pt").input_values
self.assertEqual(input_values.shape, (1, 1, 93696))
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_truncation_and_padding with Dac->Dia
def test_truncation_and_padding(self):
input_audio = self._load_datasamples(2)
# would be easier if the stride was like
feature_extractor = DiaFeatureExtractor()
# pad and trunc raise an error ?
with self.assertRaisesRegex(
ValueError,
"^Both padding and truncation were set. Make sure you only set one.$",
):
truncated_outputs = feature_extractor(
input_audio, padding="max_length", truncation=True, return_tensors="pt"
).input_values
# force truncate to max_length
truncated_outputs = feature_extractor(
input_audio, truncation=True, max_length=48000, return_tensors="pt"
).input_values
self.assertEqual(truncated_outputs.shape, (2, 1, 48128))
# pad:
padded_outputs = feature_extractor(input_audio, padding=True, return_tensors="pt").input_values
self.assertEqual(padded_outputs.shape, (2, 1, 93696))
# force pad to max length
truncated_outputs = feature_extractor(
input_audio, padding="max_length", max_length=100000, return_tensors="pt"
).input_values
self.assertEqual(truncated_outputs.shape, (2, 1, 100352))
# force no pad
with self.assertRaisesRegex(
ValueError,
"^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$",
):
truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values
truncated_outputs = feature_extractor(input_audio[0], padding=False, return_tensors="pt").input_values
self.assertEqual(truncated_outputs.shape, (1, 1, 93680))

View File

@@ -0,0 +1,751 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Dia model."""
import copy
import pathlib
import tempfile
import unittest
import pytest
from transformers.models.dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
from transformers.testing_utils import (
cleanup,
is_flaky,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_datasets_available():
from datasets import Audio, load_dataset
if is_torch_available():
import torch
from transformers import (
DiaForConditionalGeneration,
DiaModel,
DiaProcessor,
PretrainedConfig,
PreTrainedModel,
)
from transformers.cache_utils import (
Cache,
StaticCache,
)
from transformers.models.dia.modeling_dia import DiaDecoder, DiaEncoder
if is_torchaudio_available():
import torchaudio
if is_soundfile_available():
import soundfile as sf
@require_torch
class DiaModelTester:
def __init__(
self,
parent,
batch_size=3, # need batch_size != num_hidden_layers
seq_length=7,
max_length=50,
is_training=True,
vocab_size=100,
hidden_size=16,
intermediate_size=37,
num_hidden_layers=2,
num_attention_heads=2,
head_dim=8,
decoder_hidden_size=32, # typically larger than encoder
hidden_act="silu",
eos_token_id=97, # special tokens all occur after eos
pad_token_id=98,
bos_token_id=99,
delay_pattern=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.max_length = max_length
self.is_training = is_training
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.decoder_hidden_size = decoder_hidden_size
self.hidden_act = hidden_act
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
# Set default delay pattern if not provided
self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 1, 2]
self.num_channels = len(self.delay_pattern)
def get_config(self):
encoder_config = DiaEncoderConfig(
max_position_embeddings=self.max_length,
num_hidden_layers=self.num_hidden_layers,
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_attention_heads, # same as num_attention_heads for testing
head_dim=self.head_dim,
intermediate_size=self.intermediate_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
)
decoder_config = DiaDecoderConfig(
max_position_embeddings=self.max_length,
num_hidden_layers=self.num_hidden_layers,
hidden_size=self.decoder_hidden_size,
intermediate_size=self.intermediate_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=1, # GQA
head_dim=self.head_dim,
cross_num_attention_heads=self.num_attention_heads,
cross_head_dim=self.head_dim,
cross_num_key_value_heads=1, # GQA
cross_hidden_size=self.hidden_size, # match encoder hidden size
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
num_channels=self.num_channels,
)
config = DiaConfig(
encoder_config=encoder_config,
decoder_config=decoder_config,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
delay_pattern=self.delay_pattern,
)
return config
def prepare_config_and_inputs(self) -> tuple[DiaConfig, dict]:
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = input_ids.ne(self.pad_token_id)
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length, self.num_channels], self.vocab_size)
decoder_attention_mask = decoder_input_ids[..., 0].ne(self.pad_token_id)
config = self.get_config()
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self) -> tuple[DiaConfig, dict]:
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def create_and_check_model_forward(self, config, inputs_dict):
model = DiaModel(config=config).to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict["decoder_input_ids"]
# first forward pass
last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
self.parent.assertTrue(
last_hidden_state.shape, (self.batch_size, self.seq_length, config.decoder_config.hidden_size)
)
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = DiaModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = DiaEncoder.from_pretrained(tmpdirname).to(torch_device)
encoder_last_hidden_state_2 = encoder(
input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"]
)[0]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 3e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = DiaDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
input_ids=inputs_dict["decoder_input_ids"],
attention_mask=inputs_dict["decoder_attention_mask"],
encoder_hidden_states=encoder_last_hidden_state,
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 3e-3)
@require_torch
class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DiaModel, DiaForConditionalGeneration) if is_torch_available() else ()
# We only allow greedy search / sampling with one sequence; see `skip_non_greedy_generate`
all_generative_model_classes = (DiaForConditionalGeneration,)
# TODO: support new pipeline behavior in tests
pipeline_model_mapping = {}
# pipeline_model_mapping = {"text-to-audio": DiaForConditionalGeneration} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
test_resize_embeddings = False
is_encoder_decoder = True
# Indicates VLMs usually but there are many audio models which are also composite
_is_composite = True
def setUp(self):
self.model_tester = DiaModelTester(self)
# Skipping `has_text_modality` but manually testing down below
self.config_tester = ConfigTester(self, has_text_modality=False, config_class=DiaConfig)
self.skip_non_greedy_generate()
def skip_non_greedy_generate(self):
skippable_tests = [
"test_sample_generate_dict_output", # return sequences > 1
"test_beam",
"test_contrastive",
"test_assisted",
"test_prompt_lookup",
"test_model_parallel_beam_search",
"test_generate_without_input_ids",
"test_generate_with_head_masking",
]
for test in skippable_tests:
if self._testMethodName.startswith(test):
self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.")
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
"""Overridden to account for the 2D flattened structure"""
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
inputs_dict["labels"] = torch.ones(
(
self.model_tester.batch_size * self.model_tester.num_channels,
self.model_tester.seq_length,
),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def test_config(self):
self.config_tester.run_common_tests()
# Manual testing because of composite configs
config = self.model_tester.prepare_config_and_inputs()[0]
self.assertTrue(hasattr(config.encoder_config, "vocab_size"), msg="Encoder `vocab_size` does not exist")
self.assertTrue(hasattr(config.decoder_config, "vocab_size"), msg="Decoder `vocab_size` does not exist")
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
@is_flaky
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
# Overriding shape checks as Dia has different shapes on encoder/decoder using a composite config
# + additional special cases where 3D x 2D meshes confuse the expected shape
def _check_logits(self, batch_size, logits, config):
batch_size *= len(config.delay_pattern) # Account for flattening
vocab_size = config.decoder_config.vocab_size
self.assertIsInstance(logits, tuple)
self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = vocab_size - logits[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
def _check_attentions_for_generate(
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (output_length - prompt_length))
use_cache = decoder_past_key_values is not None
has_static_cache = isinstance(decoder_past_key_values, StaticCache)
# When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
# token(s)
for generated_length, iter_attentions in enumerate(attentions):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
query_length = (
prompt_length + generated_length
if not has_static_cache
else decoder_past_key_values.get_max_cache_shape()
)
expected_shape = (
batch_size,
config.decoder_config.num_attention_heads, # Decoder config
model_input_length,
query_length,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
# Encoder config
encoder_expected_shape = (batch_size, config.encoder_config.num_attention_heads, prompt_length, prompt_length)
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions],
[encoder_expected_shape] * len(attentions),
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (output_length - prompt_length))
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
# new token(s)
for generated_length, iter_hidden_states in enumerate(hidden_states):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
# check hidden size
# we can have different hidden sizes between encoder and decoder --> check both
expected_shape_encoder = (batch_size, model_input_length, config.encoder_config.hidden_size)
expected_shape_decoder = (batch_size, model_input_length, config.decoder_config.hidden_size)
self.assertTrue(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
== [expected_shape_encoder] * len(iter_hidden_states)
or [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
== [expected_shape_decoder] * len(iter_hidden_states)
)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
# Encoder config
encoder_expected_shape = (batch_size, prompt_length, config.encoder_config.hidden_size)
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
[encoder_expected_shape] * len(hidden_states),
)
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
# we need the decoder config here
config = config.decoder_config
# (batch, head, seq_length, head_features)
expected_shape = (
batch_size,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
cache_length,
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads,
)
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[layer.keys.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
self.assertListEqual(
[layer.values.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
def _check_scores(self, batch_size, scores, generated_length, config):
# Special case where Dia keeps score in a 2D mesh of (bsz * channels, vocab)
vocab_size = config.decoder_config.vocab_size
expected_shape = (batch_size * len(config.delay_pattern), vocab_size)
self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), generated_length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
def test_sdpa_can_dispatch_composite_models(self):
"""
Overwritten as it relies on hardcoded namings atm - checking for our case here specifically
"""
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
sub_models_supporting_sdpa = [
(module._supports_sdpa or module._supports_attention_backend)
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_sdpa_all_modules = (
all(sub_models_supporting_sdpa)
if len(sub_models_supporting_sdpa) > 0
else (model._supports_sdpa or model._supports_attention_backend)
)
if not supports_sdpa_all_modules:
with self.assertRaises(ValueError):
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
else:
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
for key in model_sdpa.config:
if isinstance(getattr(model_sdpa.config, key), PretrainedConfig):
sub_config = getattr(model_sdpa.config, key)
self.assertTrue(sub_config._attn_implementation == "sdpa")
@pytest.mark.generate
@unittest.skip(reason="Custom processor `DiaEOSDelayPatternLogitsProcessor` forces eos token.")
def test_generate_continue_from_past_key_values(self):
"""Only a small change due to the expected shapes"""
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
# with cache, what is considered a prompt is different in the two cases.
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
model = model_class(config).to(torch_device)
model.eval()
generate_kwargs = {
"pad_token_id": -1,
"eos_token_id": -1,
"forced_eos_token_id": None,
"encoder_no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[1] # the only real modification in this test
inputs["decoder_input_ids"] = outputs_cached.sequences
if "decoder_attention_mask" in inputs:
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
inputs["decoder_attention_mask"],
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
mode="constant",
value=1,
)
first_caches_scores = outputs_cached.scores
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
full_cached_scores = first_caches_scores + outputs_cached.scores
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
@pytest.mark.generate
def test_prepare_inputs_for_generation_kwargs_forwards(self):
super().test_prepare_inputs_for_generation_kwargs_forwards(encoder_outputs=torch.randn(2, 2, 32))
@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
def test_hidden_states_output(self):
pass
@unittest.skip(
reason="Dia has too many mixed embedding types which would cause unintentional side effects, e.g. attempts at tying embeddings"
)
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
def test_torchscript_output_hidden_state(self):
pass
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
def test_torchscript_simple(self):
pass
@unittest.skip(reason="Encoder-Decoder cache can not be initialized.")
def test_multi_gpu_data_parallel_forward(self):
pass
class DiaForConditionalGenerationIntegrationTest(unittest.TestCase):
"""
See https://gist.github.com/vasqu/0e3b06360373a4e612aa3b9a7c09185e for generating the integration tests
NOTE: We add a single `eos` line for the last channel which is skipped in the original Dia
(It doesn't change the behaviour as we cut by the eos token position)
"""
def setUp(self):
# it's a dummy ckpt but should suffice for testing purposes
self.model_checkpoint = "AntonV/Dia-1.6B"
self.sampling_rate = 44100
# prepare audio
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=self.sampling_rate))
audio_sample_1 = librispeech_dummy[-1]["audio"]["array"]
audio_sample_2 = librispeech_dummy[-2]["audio"]["array"]
# 10 and 5 codebooks as prefix - saved as files as we need wav files for the original Dia
dac_chunk_len = 512
self.audio_prompt_1_path = "/tmp/dia_test_sample_1.mp3"
self.audio_prompt_2_path = "/tmp/dia_test_sample_2.mp3"
sf.write(self.audio_prompt_1_path, audio_sample_1[: (dac_chunk_len * 10)], self.sampling_rate)
sf.write(self.audio_prompt_2_path, audio_sample_2[: (dac_chunk_len * 5)], self.sampling_rate)
def tearDown(self):
pathlib.Path(self.audio_prompt_1_path).unlink()
pathlib.Path(self.audio_prompt_2_path).unlink()
cleanup(torch_device, gc_collect=True)
@slow
@require_torch_accelerator
def test_dia_model_integration_generate_tts(self):
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS = torch.tensor([[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 524, 1026, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 674, 967, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 674, 364, 360, 1026, 1026, 1026],
[ 568, 804, 10, 674, 364, 981, 728, 1026, 1026],
[ 568, 804, 10, 674, 364, 981, 741, 550, 1026],
[ 568, 804, 10, 674, 364, 981, 568, 378, 90],
[1024, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 1024, 10, 674, 364, 981, 568, 378, 731],
[1025, 1025, 1024, 674, 364, 981, 568, 378, 731],
[1025, 1025, 1025, 1024, 364, 981, 568, 378, 731],
[1025, 1025, 1025, 1025, 1024, 981, 568, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1024, 568, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]],
[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 698, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 697, 10, 524, 1026, 1026, 1026, 1026, 1026],
[ 592, 288, 476, 649, 967, 1026, 1026, 1026, 1026],
[ 592, 740, 386, 674, 364, 360, 1026, 1026, 1026],
[ 592, 402, 386, 347, 362, 981, 728, 1026, 1026],
[ 592, 402, 721, 728, 327, 981, 741, 550, 1026],
[ 592, 402, 721, 728, 460, 62, 676, 378, 90],
[1024, 402, 721, 728, 837, 595, 195, 982, 784],
[1025, 402, 721, 677, 497, 102, 692, 24, 330],
[1025, 402, 721, 677, 511, 102, 503, 871, 609],
[1025, 402, 721, 677, 511, 96, 801, 871, 894],
[1025, 402, 721, 677, 511, 745, 314, 498, 775],
[1025, 402, 721, 677, 511, 745, 314, 498, 105],
[1025, 402, 721, 677, 511, 745, 314, 861, 889],
[1025, 893, 721, 677, 511, 744, 314, 871, 353],
[1025, 1024, 888, 677, 511, 744, 314, 871, 332],
[1025, 1025, 1024, 518, 511, 744, 314, 871, 366],
[1025, 1025, 1025, 1024, 611, 744, 314, 871, 366],
[1025, 1025, 1025, 1025, 1024, 980, 314, 871, 366],
[1025, 1025, 1025, 1025, 1025, 1024, 45, 124, 366],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 871, 366],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 719],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]])
# fmt: on
torch.testing.assert_close(outputs.cpu(), EXPECTED_OUTPUT_TOKENS)
@slow
@require_torch_accelerator
def test_dia_model_integration_generate_audio_context(self):
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
audio_sample_1 = (
torchaudio.load(self.audio_prompt_1_path, channels_first=True, backend="soundfile")[0].squeeze().numpy()
)
audio_sample_2 = (
torchaudio.load(self.audio_prompt_2_path, channels_first=True, backend="soundfile")[0].squeeze().numpy()
)
audio = [audio_sample_1, audio_sample_2]
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
# dia has right padding while we have left padding (for faster prefill)
# additionally we have new tokens vs dia's max tokens (hence we compare each in the respective settings)
outputs_1 = model.generate(**inputs, max_new_tokens=22, do_sample=False)
outputs_2 = model.generate(**inputs, max_new_tokens=27, do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS_1 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 578, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 494, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 501, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 204, 34, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 254, 915, 863, 1026, 1026, 1026, 1026, 1026],
[ 330, 215, 458, 313, 50, 1026, 1026, 1026, 1026],
[ 330, 615, 529, 216, 801, 237, 1026, 1026, 1026],
[ 330, 580, 563, 233, 337, 37, 1018, 1026, 1026],
[ 330, 567, 530, 753, 607, 179, 954, 242, 1026],
[ 330, 627, 6, 1010, 500, 189, 598, 858, 247],
[1024, 432, 480, 530, 122, 3, 788, 149, 814],
[1025, 875, 826, 458, 98, 540, 181, 122, 608],
[1025, 495, 840, 413, 337, 784, 591, 150, 1017],
[1025, 808, 189, 137, 445, 0, 227, 658, 345],
[1025, 397, 89, 753, 1016, 173, 984, 0, 910],
[1025, 875, 460, 934, 50, 335, 670, 818, 722],
[1025, 875, 460, 762, 119, 372, 503, 858, 584],
[1025, 348, 555, 475, 469, 458, 963, 41, 664],
[1025, 1024, 852, 683, 761, 193, 595, 895, 885],
[1025, 1025, 1024, 135, 761, 902, 163, 623, 385],
[1025, 1025, 1025, 1024, 852, 282, 581, 623, 70],
[1025, 1025, 1025, 1025, 1024, 41, 661, 790, 977],
[1025, 1025, 1025, 1025, 1025, 1024, 580, 401, 464],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 756, 61],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 752],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
EXPECTED_OUTPUT_TOKENS_2 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 619, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 968, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1007, 458, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 35, 266, 68, 1026, 1026, 1026, 1026, 1026],
[ 315, 359, 285, 811, 154, 1026, 1026, 1026, 1026],
[ 315, 906, 407, 297, 785, 649, 1026, 1026, 1026],
[ 315, 249, 678, 868, 899, 257, 950, 1026, 1026],
[ 315, 249, 217, 471, 292, 908, 196, 469, 1026],
[ 315, 249, 825, 771, 839, 802, 633, 590, 531],
[1024, 249, 150, 53, 126, 76, 794, 626, 442],
[1025, 249, 825, 218, 359, 864, 526, 626, 770],
[1025, 249, 150, 137, 530, 845, 877, 600, 111],
[1025, 249, 150, 287, 730, 991, 135, 259, 39],
[1025, 249, 825, 104, 198, 1020, 719, 625, 208],
[1025, 249, 825, 997, 602, 256, 859, 322, 518],
[1025, 668, 825, 979, 584, 256, 98, 665, 589],
[1025, 954, 458, 54, 206, 52, 244, 822, 599],
[1025, 1024, 104, 914, 435, 579, 860, 92, 661],
[1025, 1025, 1024, 848, 126, 74, 304, 92, 753],
[1025, 1025, 1025, 1024, 362, 376, 304, 586, 753],
[1025, 1025, 1025, 1025, 1024, 633, 996, 586, 83],
[1025, 1025, 1025, 1025, 1025, 1024, 179, 898, 928],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 506, 102],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 79],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
# fmt: on
torch.testing.assert_close(outputs_1[0].cpu(), EXPECTED_OUTPUT_TOKENS_1)
torch.testing.assert_close(outputs_2[1, 5:].cpu(), EXPECTED_OUTPUT_TOKENS_2) # left padding

View File

@@ -0,0 +1,260 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from transformers import DacModel, DiaFeatureExtractor, DiaProcessor, DiaTokenizer
from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available
if is_torch_available:
import torch
# Copied from tests.utils.test_modeling_utils.check_models_equal
def check_models_equal(model1, model2):
models_are_equal = True
for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
if model1_p.data.ne(model2_p.data).sum() > 0:
models_are_equal = False
return models_are_equal
@require_torch
class DiaProcessorTest(unittest.TestCase):
def setUp(self):
self.checkpoint = "AntonV/Dia-1.6B"
self.audio_tokenizer_checkpoint = "descript/dac_44khz"
self.tmpdirname = tempfile.mkdtemp()
# Audio tokenizer is a bigger model so we will reuse this if possible
self.processor = DiaProcessor(
tokenizer=self.get_tokenizer(),
feature_extractor=self.get_feature_extractor(),
audio_tokenizer=self.get_audio_tokenizer(),
)
# Default audio values based on Dia and Dac
self.pad_id = 1025
self.bos_id = 1026
self.dac_chunk_len = 512
self.delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]
def get_tokenizer(self, **kwargs):
return DiaTokenizer.from_pretrained(self.checkpoint, **kwargs)
def get_feature_extractor(self, **kwargs):
return DiaFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
def get_audio_tokenizer(self, **kwargs):
return DacModel.from_pretrained(self.audio_tokenizer_checkpoint, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
del self.processor
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
audio_tokenizer = self.get_audio_tokenizer()
processor = DiaProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer
)
processor.save_pretrained(self.tmpdirname)
processor = DiaProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.tokenizer, DiaTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer.__class__.__name__)
self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer.name_or_path)
self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer))
self.assertIsInstance(processor.audio_tokenizer, DacModel)
def test_save_load_pretrained_additional_features(self):
processor = DiaProcessor(
tokenizer=self.get_tokenizer(),
feature_extractor=self.get_feature_extractor(),
audio_tokenizer=self.get_audio_tokenizer(),
)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer()
feature_extractor_add_kwargs = self.get_feature_extractor()
audio_tokenizer_add_kwargs = self.get_audio_tokenizer()
processor = DiaProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, DiaTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer_add_kwargs.__class__.__name__)
self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer_add_kwargs.name_or_path)
self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer_add_kwargs))
self.assertIsInstance(processor.audio_tokenizer, DacModel)
def test_tokenize(self):
tokenizer = self.get_tokenizer()
random_text = ["This is a processing test for tokenization", "[S1] Dia template style [S2] Nice"]
input_tokenizer = tokenizer(random_text, padding=True, return_tensors="pt")
input_processor = self.processor(random_text)
for key in input_tokenizer:
self.assertTrue((input_tokenizer[key] == input_processor[key]).all())
def test_no_audio(self):
random_text = ["Dummy Input"] * 2
input_processor = self.processor(random_text)
audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
# full mask with +1 for bos
self.assertTrue(audio_mask.sum() == (max(self.delay_pattern) + 1) * len(random_text))
self.assertTrue(
audio_tokens.shape
== (
len(random_text),
max(self.delay_pattern) + 1,
len(self.delay_pattern),
)
)
for channel_idx, delay in enumerate(self.delay_pattern):
expected_sequence = torch.ones(size=(audio_tokens.shape[:-1])) * self.pad_id
expected_sequence[:, : delay + 1] = self.bos_id
self.assertTrue((audio_tokens[..., channel_idx] == expected_sequence).all())
def test_audio(self):
audio_tokenizer = self.get_audio_tokenizer()
feature_extractor = self.get_feature_extractor()
random_text = ["Dummy Input"] * 2
# Dac only starts accepting audio from a certain length (ensured via >=1024)
raw_speeches = [np.random.rand(2048).astype(np.float32), np.random.rand(1024).astype(np.float32)]
input_processor = self.processor(random_text, raw_speeches)
audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
sequence_len = audio_mask.shape[1]
for batch_idx, speech in enumerate(raw_speeches):
raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
codebooks = audio_tokenizer(raw_audio).audio_codes.transpose(1, 2)
pad_len = sequence_len - audio_mask.sum(dim=-1)[batch_idx]
for channel_idx, delay in enumerate(self.delay_pattern):
# Left padding filled bos, right padding (delay) are pad
start_idx = pad_len + delay + 1
end_idx = start_idx + codebooks.shape[1]
encoded_sequence = audio_tokens[batch_idx, :, channel_idx]
expected_sequence = torch.ones(size=(sequence_len,)) * self.pad_id
expected_sequence[:start_idx] = self.bos_id
expected_sequence[start_idx:end_idx] = codebooks[0, :, channel_idx]
self.assertTrue((encoded_sequence == expected_sequence).all())
# Just to make sure the masking correctly only ignores bos tokens
self.assertTrue((audio_tokens[~audio_mask.bool()] == self.bos_id).all())
@parameterized.expand([([1, 1],), ([1, 5],), ([2, 4, 6],)])
def test_decode_audio(self, audio_lens):
feature_extractor = self.get_feature_extractor()
audio_tokenizer = self.get_audio_tokenizer()
random_text = ["Dummy Input"] * len(audio_lens)
raw_speeches = [np.random.rand(self.dac_chunk_len * l).astype(np.float32) for l in audio_lens]
# we need eos (given if training) to decode properly, also enforced via custom logits processor
input_processor = self.processor(random_text, raw_speeches, generation=False)
audio_tokens = input_processor["decoder_input_ids"]
decoded_speeches = self.processor.batch_decode(audio_tokens)
for batch_idx, speech in enumerate(raw_speeches):
raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
codebooks = audio_tokenizer(raw_audio).audio_codes
decoded_audio = decoded_speeches[batch_idx]
expected_audio = audio_tokenizer.decode(audio_codes=codebooks).audio_values
self.assertTrue((expected_audio == decoded_audio).all())
self.assertTrue(decoded_speeches[batch_idx].shape[-1] == audio_lens[batch_idx] * self.dac_chunk_len)
@parameterized.expand([(1, 2, [0, 1, 4]), (2, 4, [1, 3, 2]), (4, 8, [0, 5, 7])])
def test_delay_in_audio(self, bsz, seq_len, delay_pattern):
# static functions which are crucial, hence we also test them here
build_indices_fn = DiaProcessor.build_indices
delay_fn = DiaProcessor.apply_audio_delay
bos, pad = -2, -1
num_channels = len(delay_pattern)
audio_input = torch.arange(bsz * seq_len * num_channels).view(bsz, seq_len, num_channels)
# imitate a delay mask with zeroes
audio_input = torch.cat([audio_input, torch.zeros(size=(bsz, max(delay_pattern), num_channels))], dim=1)
precomputed_idx = build_indices_fn(
bsz=bsz,
seq_len=seq_len + max(delay_pattern),
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=False,
)
delayed_audio_out = delay_fn(
audio=audio_input,
pad_token_id=pad,
bos_token_id=bos,
precomputed_idx=precomputed_idx,
)
# every channel idx is shifted by delay_pattern[idx]
delayed_audio_res = audio_input.clone()
for idx, delay in enumerate(delay_pattern):
delayed_audio_res[:, :delay, idx] = bos
remaining_input = seq_len + max(delay_pattern) - delay
delayed_audio_res[:, delay:, idx] = audio_input[:, :remaining_input, idx]
self.assertTrue((delayed_audio_out == delayed_audio_res).all())
# we should get back to the original audio we had (when removing the delay pad)
bsz, new_seq_len, num_channels = delayed_audio_out.shape
precomputed_idx = build_indices_fn(
bsz=bsz,
seq_len=new_seq_len,
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=True,
)
reverted_audio_out = delay_fn(
audio=delayed_audio_out,
pad_token_id=pad,
bos_token_id=bos,
precomputed_idx=precomputed_idx,
)
reverted_audio_res = audio_input.clone()[:, :seq_len]
self.assertTrue((reverted_audio_out[:, :seq_len] == reverted_audio_res).all())

View File

@@ -0,0 +1,123 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.models.dia import DiaTokenizer
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
# Special tokens
PAD = 0
S1 = 1
S2 = 2
class DiaTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = DiaTokenizer
test_rust_tokenizer = False
@classmethod
def setUpClass(cls):
super().setUpClass()
tokenizer = DiaTokenizer()
tokenizer.save_pretrained(cls.tmpdirname)
def test_convert_token_and_id(self):
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
token = "i"
token_id = 105
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
self.assertEqual(vocab_keys[PAD], "<pad>")
self.assertEqual(vocab_keys[S1], "[S1]")
self.assertEqual(vocab_keys[S2], "[S2]")
self.assertEqual(len(vocab_keys), 256)
def test_vocab_size(self):
# utf-8 == 2**8 == 256
self.assertEqual(self.get_tokenizer().vocab_size, 256)
def test_full_tokenizer(self):
tokenizer = DiaTokenizer.from_pretrained(self.tmpdirname)
tokens = tokenizer.tokenize("Hello, world!")
self.assertListEqual(tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
tokens = tokenizer.tokenize("[S1] Hello [S2] Hello<pad>")
self.assertListEqual(
tokens,
["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [S1, 32, 72, 101, 108, 108, 111, 32, S2, 32, 72, 101, 108, 108, 111, PAD])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens, ["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"]
)
@slow
def test_tokenizer_integration(self):
# Overwritten as decoding will lead to all single bytes (i.e. characters) while usually the string format is expected
expected_encoding = {'input_ids': [[84, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 40, 102, 111, 114, 109, 101, 114, 108, 121, 32, 107, 110, 111, 119, 110, 32, 97, 115, 32, 112, 121, 116, 111, 114, 99, 104, 45, 116, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 97, 110, 100, 32, 112, 121, 116, 111, 114, 99, 104, 45, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 45, 98, 101, 114, 116, 41, 32, 112, 114, 111, 118, 105, 100, 101, 115, 32, 103, 101, 110, 101, 114, 97, 108, 45, 112, 117, 114, 112, 111, 115, 101, 32, 97, 114, 99, 104, 105, 116, 101, 99, 116, 117, 114, 101, 115, 32, 40, 66, 69, 82, 84, 44, 32, 71, 80, 84, 45, 50, 44, 32, 82, 111, 66, 69, 82, 84, 97, 44, 32, 88, 76, 77, 44, 32, 68, 105, 115, 116, 105, 108, 66, 101, 114, 116, 44, 32, 88, 76, 78, 101, 116, 46, 46, 46, 41, 32, 102, 111, 114, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 85, 110, 100, 101, 114, 115, 116, 97, 110, 100, 105, 110, 103, 32, 40, 78, 76, 85, 41, 32, 97, 110, 100, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 71, 101, 110, 101, 114, 97, 116, 105, 111, 110, 32, 40, 78, 76, 71, 41, 32, 119, 105, 116, 104, 32, 111, 118, 101, 114, 32, 51, 50, 43, 32, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 32, 109, 111, 100, 101, 108, 115, 32, 105, 110, 32, 49, 48, 48, 43, 32, 108, 97, 110, 103, 117, 97, 103, 101, 115, 32, 97, 110, 100, 32, 100, 101, 101, 112, 32, 105, 110, 116, 101, 114, 111, 112, 101, 114, 97, 98, 105, 108, 105, 116, 121, 32, 98, 101, 116, 119, 101, 101, 110, 32, 74, 97, 120, 44, 32, 80, 121, 84, 111, 114, 99, 104, 32, 97, 110, 100, 32, 84, 101, 110, 115, 111, 114, 70, 108, 111, 119, 46], [66, 69, 82, 84, 32, 105, 115, 32, 100, 101, 115, 105, 103, 110, 101, 100, 32, 116, 111, 32, 112, 114, 101, 45, 116, 114, 97, 105, 110, 32, 100, 101, 101, 112, 32, 98, 105, 100, 105, 114, 101, 99, 116, 105, 111, 110, 97, 108, 32, 114, 101, 112, 114, 101, 115, 101, 110, 116, 97, 116, 105, 111, 110, 115, 32, 102, 114, 111, 109, 32, 117, 110, 108, 97, 98, 101, 108, 101, 100, 32, 116, 101, 120, 116, 32, 98, 121, 32, 106, 111, 105, 110, 116, 108, 121, 32, 99, 111, 110, 100, 105, 116, 105, 111, 110, 105, 110, 103, 32, 111, 110, 32, 98, 111, 116, 104, 32, 108, 101, 102, 116, 32, 97, 110, 100, 32, 114, 105, 103, 104, 116, 32, 99, 111, 110, 116, 101, 120, 116, 32, 105, 110, 32, 97, 108, 108, 32, 108, 97, 121, 101, 114, 115, 46], [84, 104, 101, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 32, 106, 117, 109, 112, 115, 32, 111, 118, 101, 114, 32, 116, 104, 101, 32, 108, 97, 122, 121, 32, 100, 111, 103, 46]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip
sequences = [
"Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
"general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
"Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained "
"models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
"BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
"conditioning on both left and right context in all layers.",
"The quick brown fox jumps over the lazy dog.",
]
tokenizer_classes = [self.tokenizer_class]
if self.test_rust_tokenizer:
tokenizer_classes.append(self.rust_tokenizer_class)
for tokenizer_class in tokenizer_classes:
tokenizer = tokenizer_class.from_pretrained("AntonV/Dia-1.6B")
encoding = tokenizer(sequences)
encoding_data = encoding.data
self.assertDictEqual(encoding_data, expected_encoding)
# Byte decoding leads to characters so we need to join them
decoded_sequences = [
"".join(tokenizer.decode(seq, skip_special_tokens=True)) for seq in encoding["input_ids"]
]
for expected, decoded in zip(sequences, decoded_sequences):
if self.test_sentencepiece_ignore_case:
expected = expected.lower()
self.assertEqual(expected, decoded)
@unittest.skip(reason="Dia relies on whole input string due to the byte-level nature.")
def test_pretokenized_inputs(self):
pass
@unittest.skip
def test_tokenizer_slow_store_full_signature(self):
pass