init

0
transformers/tests/models/pop2piano/__init__.py
Normal file
267
transformers/tests/models/pop2piano/test_feature_extraction_pop2piano.py
Normal file
@@ -0,0 +1,267 @@
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np
from datasets import load_dataset

from transformers.testing_utils import (
    check_json_file_has_correct_format,
    require_essentia,
    require_librosa,
    require_scipy,
    require_torch,
)
from transformers.utils.import_utils import (
    is_essentia_available,
    is_librosa_available,
    is_scipy_available,
    is_torch_available,
)

from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


requirements_available = (
    is_torch_available() and is_essentia_available() and is_scipy_available() and is_librosa_available()
)

if requirements_available:
    import torch

    from transformers import Pop2PianoFeatureExtractor


class Pop2PianoFeatureExtractionTester:
    def __init__(
        self,
        parent,
        n_bars=2,
        sample_rate=22050,
        use_mel=True,
        padding_value=0,
        vocab_size_special=4,
        vocab_size_note=128,
        vocab_size_velocity=2,
        vocab_size_time=100,
    ):
        self.parent = parent
        self.n_bars = n_bars
        self.sample_rate = sample_rate
        self.use_mel = use_mel
        self.padding_value = padding_value
        self.vocab_size_special = vocab_size_special
        self.vocab_size_note = vocab_size_note
        self.vocab_size_velocity = vocab_size_velocity
        self.vocab_size_time = vocab_size_time

    def prepare_feat_extract_dict(self):
        return {
            "n_bars": self.n_bars,
            "sample_rate": self.sample_rate,
            "use_mel": self.use_mel,
            "padding_value": self.padding_value,
            "vocab_size_special": self.vocab_size_special,
            "vocab_size_note": self.vocab_size_note,
            "vocab_size_velocity": self.vocab_size_velocity,
            "vocab_size_time": self.vocab_size_time,
        }
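
# The dict above is consumed as keyword arguments wherever the tests below build an
# extractor, e.g. Pop2PianoFeatureExtractor(**self.feat_extract_tester.prepare_feat_extract_dict()).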

@require_torch
@require_essentia
@require_librosa
@require_scipy
class Pop2PianoFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
    feature_extraction_class = Pop2PianoFeatureExtractor if requirements_available else None

    def setUp(self):
        self.feat_extract_tester = Pop2PianoFeatureExtractionTester(self)

    def test_feat_extract_from_and_save_pretrained(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
            check_json_file_has_correct_format(saved_file)
            feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)

        dict_first = feat_extract_first.to_dict()
        dict_second = feat_extract_second.to_dict()
        mel_1 = feat_extract_first.use_mel
        mel_2 = feat_extract_second.use_mel
        self.assertTrue(np.allclose(mel_1, mel_2))
        self.assertEqual(dict_first, dict_second)

    def test_feat_extract_to_json_file(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "feat_extract.json")
            feat_extract_first.to_json_file(json_file_path)
            feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)

        dict_first = feat_extract_first.to_dict()
        dict_second = feat_extract_second.to_dict()
        mel_1 = feat_extract_first.use_mel
        mel_2 = feat_extract_second.use_mel
        self.assertTrue(np.allclose(mel_1, mel_2))
        self.assertEqual(dict_first, dict_second)

    def test_call(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_input = np.zeros([1000000], dtype=np.float32)

        input_features = feature_extractor(speech_input, sampling_rate=16_000, return_tensors="np")
        self.assertTrue(input_features.input_features.ndim == 3)
        self.assertEqual(input_features.input_features.shape[-1], 512)

        self.assertTrue(input_features.beatsteps.ndim == 2)
        self.assertTrue(input_features.extrapolated_beatstep.ndim == 2)

    def test_integration(self):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        speech_samples = ds.sort("id").select([0])["audio"]
        input_speech = [x["array"] for x in speech_samples][0]
        sampling_rate = [x["sampling_rate"] for x in speech_samples][0]
        feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
        input_features = feature_extractor(
            input_speech, sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features

        EXPECTED_INPUT_FEATURES = torch.tensor(
            [[-7.1493, -6.8701, -4.3214], [-5.9473, -5.7548, -3.8438], [-6.1324, -5.9018, -4.3778]]
        )
        torch.testing.assert_close(input_features[0, :3, :3], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)

    def test_attention_mask(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_input1 = np.zeros([1_000_000], dtype=np.float32)
        speech_input2 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
        input_features = feature_extractor(
            [speech_input1, speech_input2],
            sampling_rate=[44_100, 16_000],
            return_tensors="np",
            return_attention_mask=True,
        )

        self.assertTrue(hasattr(input_features, "attention_mask"))

        # check shapes
        self.assertTrue(input_features["attention_mask"].ndim == 2)
        self.assertEqual(input_features["attention_mask_beatsteps"].shape[0], 2)
        self.assertEqual(input_features["attention_mask_extrapolated_beatstep"].shape[0], 2)

        # check that the masks contain no values other than 0 and 1
        self.assertTrue(np.max(input_features["attention_mask"]) == 1)
        self.assertTrue(np.max(input_features["attention_mask_beatsteps"]) == 1)
        self.assertTrue(np.max(input_features["attention_mask_extrapolated_beatstep"]) == 1)

        self.assertTrue(np.min(input_features["attention_mask"]) == 0)
        self.assertTrue(np.min(input_features["attention_mask_beatsteps"]) == 0)
        self.assertTrue(np.min(input_features["attention_mask_extrapolated_beatstep"]) == 0)

    def test_batch_feature(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_input1 = np.zeros([1_000_000], dtype=np.float32)
        speech_input2 = np.ones([2_000_000], dtype=np.float32)
        speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)

        input_features = feature_extractor(
            [speech_input1, speech_input2, speech_input3],
            sampling_rate=[44_100, 16_000, 48_000],
            return_attention_mask=True,
        )

        self.assertEqual(len(input_features["input_features"].shape), 3)
        # check shape
        self.assertEqual(input_features["beatsteps"].shape[0], 3)
        self.assertEqual(input_features["extrapolated_beatstep"].shape[0], 3)

    def test_batch_feature_np(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_input1 = np.zeros([1_000_000], dtype=np.float32)
        speech_input2 = np.ones([2_000_000], dtype=np.float32)
        speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)

        input_features = feature_extractor(
            [speech_input1, speech_input2, speech_input3],
            sampling_rate=[44_100, 16_000, 48_000],
            return_tensors="np",
            return_attention_mask=True,
        )

        # check that the output is an np array
        self.assertEqual(type(input_features["input_features"]), np.ndarray)

        # check shape
        self.assertEqual(len(input_features["input_features"].shape), 3)

    def test_batch_feature_pt(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_input1 = np.zeros([1_000_000], dtype=np.float32)
        speech_input2 = np.ones([2_000_000], dtype=np.float32)
        speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)

        input_features = feature_extractor(
            [speech_input1, speech_input2, speech_input3],
            sampling_rate=[44_100, 16_000, 48_000],
            return_tensors="pt",
            return_attention_mask=True,
        )

        # check that the output is a pt tensor
        self.assertEqual(type(input_features["input_features"]), torch.Tensor)

        # check shape
        self.assertEqual(len(input_features["input_features"].shape), 3)

    @unittest.skip(
        "Pop2PianoFeatureExtractor does not support external padding (when processing audio in batches, padding is automatically applied to max_length)"
    )
    def test_padding_accepts_tensors_pt(self):
        pass

    @unittest.skip(
        "Pop2PianoFeatureExtractor does not support external padding (when processing audio in batches, padding is automatically applied to max_length)"
    )
    def test_padding_accepts_tensors_tf(self):
        pass

    @unittest.skip(
        "Pop2PianoFeatureExtractor does not support external padding (when processing audio in batches, padding is automatically applied to max_length)"
    )
    def test_padding_from_list(self):
        pass

    @unittest.skip(
        "Pop2PianoFeatureExtractor does not support external padding (when processing audio in batches, padding is automatically applied to max_length)"
    )
    def test_padding_from_array(self):
        pass

    @unittest.skip(reason="Pop2PianoFeatureExtractor does not support truncation")
    def test_attention_mask_with_truncation(self):
        pass

    @unittest.skip(reason="Pop2PianoFeatureExtractor does not support truncation")
    def test_truncation_from_array(self):
        pass

    @unittest.skip(reason="Pop2PianoFeatureExtractor does not support truncation")
    def test_truncation_from_list(self):
        pass
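
For orientation, the call pattern these feature-extraction tests exercise, as a minimal sketch (it assumes the same `sweetcocoa/pop2piano` checkpoint and a placeholder waveform, mirroring the tests above):

import numpy as np
from transformers import Pop2PianoFeatureExtractor

# placeholder waveform; the tests use zeros/ones/random arrays of similar shape
audio = np.zeros([1_000_000], dtype=np.float32)
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
features = feature_extractor(audio, sampling_rate=44_100, return_tensors="pt")
# `features` carries "input_features" (3-D), "beatsteps" and "extrapolated_beatstep"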
739
transformers/tests/models/pop2piano/test_modeling_pop2piano.py
Normal file
@@ -0,0 +1,739 @@
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Pop2Piano model."""

import copy
import tempfile
import unittest

import numpy as np
from datasets import load_dataset

from transformers import Pop2PianoConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.testing_utils import (
    require_essentia,
    require_librosa,
    require_scipy,
    require_torch,
    slow,
    torch_device,
)
from transformers.utils import is_essentia_available, is_librosa_available, is_scipy_available, is_torch_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import Pop2PianoForConditionalGeneration


@require_torch
class Pop2PianoModelTester:
    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        encoder_seq_length=7,
        decoder_seq_length=9,
        # For common tests
        is_training=False,
        use_attention_mask=True,
        use_labels=True,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        d_ff=37,
        relative_attention_num_buckets=8,
        dropout_rate=0.1,
        initializer_factor=0.002,
        eos_token_id=1,
        pad_token_id=0,
        decoder_start_token_id=0,
        scope=None,
        decoder_layers=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        # For common tests
        self.seq_length = self.decoder_seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.initializer_factor = initializer_factor
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.scope = None
        self.decoder_layers = decoder_layers

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
        decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

        attention_mask = None
        decoder_attention_mask = None
        if self.use_attention_mask:
            attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
            decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)

        lm_labels = (
            ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) if self.use_labels else None
        )

        return self.get_config(), input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels

    def get_pipeline_config(self):
        return Pop2PianoConfig(
            vocab_size=166,  # Pop2Piano forces 100 extra tokens
            d_model=self.hidden_size,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_decoder_layers=self.decoder_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
        )

    def get_config(self):
        return Pop2PianoConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_decoder_layers=self.decoder_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
        )
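
    # Note: `get_pipeline_config` above differs from `get_config` only in its fixed
    # vocab_size=166, since Pop2Piano forces 100 extra tokens (see the inline comment there).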

    def check_prepare_lm_labels_via_shift_left(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = Pop2PianoForConditionalGeneration(config=config)
        model.to(torch_device)
        model.eval()

        # make sure that lm_labels are correctly padded from the right
        lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)

        # add causal pad token mask
        triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
        lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
        decoder_input_ids = model._shift_right(lm_labels)

        for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
            # first item
            self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
            if i < decoder_input_ids_slice.shape[-1]:
                if i < decoder_input_ids.shape[-1] - 1:
                    # items before diagonal
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
                    )
                # pad items after diagonal
                if i < decoder_input_ids.shape[-1] - 2:
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
                    )
            else:
                # all items after square
                self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())

    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = Pop2PianoForConditionalGeneration(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        decoder_past = result.past_key_values
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
        # There should be `num_layers` key value embeddings stored in decoder_past
        self.parent.assertEqual(len(decoder_past), config.num_layers)
        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
        self.parent.assertEqual(len(decoder_past[0]), 4)

    def create_and_check_with_lm_head(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
        outputs = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels,
        )
        self.parent.assertEqual(len(outputs), 4)
        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
        self.parent.assertEqual(outputs["loss"].size(), ())

    def create_and_check_decoder_model_past(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = Pop2PianoForConditionalGeneration(config=config).get_decoder().to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        output, past_key_values = outputs.to_tuple()

        # create a hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append to next input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_decoder_model_attention_mask_past(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = Pop2PianoForConditionalGeneration(config=config).get_decoder()
        model.to(torch_device)
        model.eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()

        # create a hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = Pop2PianoForConditionalGeneration(config=config).get_decoder().to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and next_attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_model_fp16_forward(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).half().eval()
        output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[
            "encoder_last_hidden_state"
        ]
        self.parent.assertFalse(torch.isnan(output).any().item())

    def create_and_check_encoder_decoder_shared_weights(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        for model_class in [Pop2PianoForConditionalGeneration]:
            torch.manual_seed(0)
            model = model_class(config=config).to(torch_device).eval()
            # load state dict copies weights but does not tie them
            model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)

            torch.manual_seed(0)
            tied_config = copy.deepcopy(config)
            tied_config.tie_encoder_decoder = True
            tied_model = model_class(config=tied_config).to(torch_device).eval()

            model_result = model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )

            tied_model_result = tied_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )

            # check that the tied model has fewer parameters
            self.parent.assertLess(
                sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
            )
            random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()

            # check that outputs are equal
            self.parent.assertTrue(
                torch.allclose(
                    model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
                )
            )

            # check that outputs after saving and loading are equal
            with tempfile.TemporaryDirectory() as tmpdirname:
                tied_model.save_pretrained(tmpdirname)
                tied_model = model_class.from_pretrained(tmpdirname)
                tied_model.to(torch_device)
                tied_model.eval()

                # check that the tied model has fewer parameters
                self.parent.assertLess(
                    sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
                )
                random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()

                tied_model_result = tied_model(
                    input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    attention_mask=attention_mask,
                    decoder_attention_mask=decoder_attention_mask,
                )

                # check that outputs are equal
                self.parent.assertTrue(
                    torch.allclose(
                        model_result[0][0, :, random_slice_idx],
                        tied_model_result[0][0, :, random_slice_idx],
                        atol=1e-4,
                    )
                )

    def check_resize_embeddings_pop2piano_v1_1(
        self,
        config,
    ):
        prev_vocab_size = config.vocab_size

        config.tie_word_embeddings = False
        model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
        model.resize_token_embeddings(prev_vocab_size - 10)

        self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
        self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
        self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
            lm_labels,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "use_cache": False,
        }
        return config, inputs_dict


@require_torch
class Pop2PianoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (Pop2PianoForConditionalGeneration,) if is_torch_available() else ()
    # Doesn't run generation tests. Has a custom generation method with a different interface
    all_generative_model_classes = ()
    pipeline_model_mapping = (
        {"automatic-speech-recognition": Pop2PianoForConditionalGeneration} if is_torch_available() else {}
    )
    all_parallelizable_model_classes = ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = True
    test_model_parallel = False
    is_encoder_decoder = True

    def setUp(self):
        self.model_tester = Pop2PianoModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Pop2PianoConfig, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_shift_right(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_v1_1(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # check that gated gelu feed forward and different word embeddings work
        config = config_and_inputs[0]
        config.tie_word_embeddings = False
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

    def test_config_and_model_silu_gated(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        config = config_and_inputs[0]
        config.feed_forward_proj = "gated-silu"
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_with_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_with_lm_head(*config_and_inputs)

    def test_decoder_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)

    def test_decoder_model_past_with_attn_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)

    def test_decoder_model_past_with_3d_attn_mask(self):
        (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
            lm_labels,
        ) = self.model_tester.prepare_config_and_inputs()

        attention_mask = ids_tensor(
            [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
            vocab_size=2,
        )
        decoder_attention_mask = ids_tensor(
            [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
            vocab_size=2,
        )

        self.model_tester.create_and_check_decoder_model_attention_mask_past(
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
            lm_labels,
        )

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_encoder_decoder_shared_weights(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)

    @unittest.skipIf(torch_device == "cpu", "Can't do half precision")
    def test_model_fp16_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)

    def test_v1_1_resize_embeddings(self):
        config = self.model_tester.prepare_config_and_inputs()[0]
        self.model_tester.check_resize_embeddings_pop2piano_v1_1(config)

    @slow
    def test_model_from_pretrained(self):
        model_name = "sweetcocoa/pop2piano"
        model = Pop2PianoForConditionalGeneration.from_pretrained(model_name)
        self.assertIsNotNone(model)

    def test_pass_with_input_features(self):
        input_features = BatchFeature(
            {
                "input_features": torch.rand((75, 100, 512)).type(torch.float32),
                "beatsteps": torch.randint(size=(1, 955), low=0, high=100).type(torch.float32),
                "extrapolated_beatstep": torch.randint(size=(1, 900), low=0, high=100).type(torch.float32),
            }
        )
        model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
        model_opts = model.generate(input_features=input_features["input_features"], return_dict_in_generate=True)

        self.assertEqual(model_opts.sequences.ndim, 2)

    def test_pass_with_batched_input_features(self):
        input_features = BatchFeature(
            {
                "input_features": torch.rand((220, 70, 512)).type(torch.float32),
                "beatsteps": torch.randint(size=(5, 955), low=0, high=100).type(torch.float32),
                "extrapolated_beatstep": torch.randint(size=(5, 900), low=0, high=100).type(torch.float32),
                "attention_mask": torch.concatenate(
                    [
                        torch.ones([120, 70], dtype=torch.int32),
                        torch.zeros([1, 70], dtype=torch.int32),
                        torch.ones([50, 70], dtype=torch.int32),
                        torch.zeros([1, 70], dtype=torch.int32),
                        torch.ones([47, 70], dtype=torch.int32),
                        torch.zeros([1, 70], dtype=torch.int32),
                    ],
                    axis=0,
                ),
                "attention_mask_beatsteps": torch.ones((5, 955)).type(torch.int32),
                "attention_mask_extrapolated_beatstep": torch.ones((5, 900)).type(torch.int32),
            }
        )
        model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
        model_opts = model.generate(
            input_features=input_features["input_features"],
            attention_mask=input_features["attention_mask"],
            return_dict_in_generate=True,
        )

        self.assertEqual(model_opts.sequences.ndim, 2)


@require_torch
class Pop2PianoModelIntegrationTests(unittest.TestCase):
    @slow
    def test_mel_conditioner_integration(self):
        composer = "composer1"
        model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
        input_embeds = torch.ones([10, 100, 512])

        composer_value = model.generation_config.composer_to_feature_token[composer]
        composer_value = torch.tensor(composer_value)
        composer_value = composer_value.repeat(input_embeds.size(0))
        outputs = model.mel_conditioner(
            input_embeds, composer_value, min(model.generation_config.composer_to_feature_token.values())
        )

        # check shape
        self.assertEqual(outputs.size(), torch.Size([10, 101, 512]))

        # check values
        EXPECTED_OUTPUTS = torch.tensor(
            [[1.0475305318832397, 0.29052114486694336, -0.47778210043907166], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
        )

        torch.testing.assert_close(outputs[0, :3, :3], EXPECTED_OUTPUTS, rtol=1e-4, atol=1e-4)

    @slow
    @require_essentia
    @require_librosa
    @require_scipy
    def test_full_model_integration(self):
        if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
            from transformers import Pop2PianoProcessor

            speech_input1 = np.zeros([1_000_000], dtype=np.float32)
            sampling_rate = 44_100

            processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
            input_features = processor.feature_extractor(
                speech_input1, sampling_rate=sampling_rate, return_tensors="pt"
            )

            model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
            outputs = model.generate(
                input_features=input_features["input_features"], return_dict_in_generate=True
            ).sequences

            # check shapes
            self.assertEqual(outputs.size(0), 70)

            # check values
            self.assertEqual(outputs[0, :2].detach().cpu().numpy().tolist(), [0, 1])

    # This test uses a real music sample from the K-Pop genre.
    @slow
    @require_essentia
    @require_librosa
    @require_scipy
    def test_real_music(self):
        if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
            from transformers import Pop2PianoFeatureExtractor, Pop2PianoTokenizer

            model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
            model.eval()
            feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
            tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")
            ds = load_dataset("sweetcocoa/pop2piano_ci", split="test")

            output_fe = feature_extractor(
                ds["audio"][0]["array"], sampling_rate=ds["audio"][0]["sampling_rate"], return_tensors="pt"
            )
            output_model = model.generate(input_features=output_fe["input_features"], composer="composer1")
            output_tokenizer = tokenizer.batch_decode(token_ids=output_model, feature_extractor_output=output_fe)
            pretty_midi_object = output_tokenizer["pretty_midi_objects"][0]

            # check that the number of notes matches
            self.assertEqual(len(pretty_midi_object.instruments[0].notes), 59)
            predicted_timings = []
            for i in pretty_midi_object.instruments[0].notes:
                predicted_timings.append(i.start)

            # check note start timings (first 6 notes)
            EXPECTED_START_TIMINGS = [
                0.4876190423965454,
                0.7314285635948181,
                0.9752380847930908,
                1.4396371841430664,
                1.6718367338180542,
                1.904036283493042,
            ]

            self.assertTrue(np.allclose(EXPECTED_START_TIMINGS, predicted_timings[:6]))

            # check start timings of the last 6 notes
            EXPECTED_END_TIMINGS = [
                12.341403007507324,
                12.567797183990479,
                12.567797183990479,
                12.567797183990479,
                12.794191360473633,
                12.794191360473633,
            ]

            self.assertTrue(np.allclose(EXPECTED_END_TIMINGS, predicted_timings[-6:]))
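
The integration tests above boil down to this end-to-end flow, sketched here for reference (it reuses only calls that appear in `test_real_music`):

from datasets import load_dataset
from transformers import Pop2PianoFeatureExtractor, Pop2PianoForConditionalGeneration, Pop2PianoTokenizer

model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")

ds = load_dataset("sweetcocoa/pop2piano_ci", split="test")
features = feature_extractor(
    ds["audio"][0]["array"], sampling_rate=ds["audio"][0]["sampling_rate"], return_tensors="pt"
)
token_ids = model.generate(input_features=features["input_features"], composer="composer1")
# batch_decode turns generated token ids back into pretty_midi objects
midi = tokenizer.batch_decode(token_ids=token_ids, feature_extractor_output=features)["pretty_midi_objects"][0]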
245
transformers/tests/models/pop2piano/test_processing_pop2piano.py
Normal file
@@ -0,0 +1,245 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
import unittest

import numpy as np
import pytest
from datasets import load_dataset

from transformers.testing_utils import (
    require_essentia,
    require_librosa,
    require_pretty_midi,
    require_scipy,
    require_torch,
)
from transformers.tokenization_utils import BatchEncoding
from transformers.utils.import_utils import (
    is_essentia_available,
    is_librosa_available,
    is_pretty_midi_available,
    is_scipy_available,
    is_torch_available,
)


requirements_available = (
    is_torch_available()
    and is_essentia_available()
    and is_scipy_available()
    and is_librosa_available()
    and is_pretty_midi_available()
)

if requirements_available:
    import pretty_midi

    from transformers import (
        Pop2PianoFeatureExtractor,
        Pop2PianoForConditionalGeneration,
        Pop2PianoProcessor,
        Pop2PianoTokenizer,
    )


@require_scipy
@require_torch
@require_librosa
@require_essentia
@require_pretty_midi
class Pop2PianoProcessorTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()

        feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
        tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")
        processor = Pop2PianoProcessor(feature_extractor, tokenizer)

        processor.save_pretrained(cls.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return Pop2PianoTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_feature_extractor(self, **kwargs):
        return Pop2PianoFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)

    def test_save_load_pretrained_additional_features(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            processor = Pop2PianoProcessor(
                tokenizer=self.get_tokenizer(),
                feature_extractor=self.get_feature_extractor(),
            )
            processor.save_pretrained(tmpdir)

            tokenizer_add_kwargs = self.get_tokenizer(
                unk_token="-1",
                eos_token="1",
                pad_token="0",
                bos_token="2",
            )
            feature_extractor_add_kwargs = self.get_feature_extractor()

            processor = Pop2PianoProcessor.from_pretrained(
                tmpdir,
                unk_token="-1",
                eos_token="1",
                pad_token="0",
                bos_token="2",
            )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, Pop2PianoTokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, Pop2PianoFeatureExtractor)

    def get_inputs(self):
        """get inputs for both feature extractor and tokenizer"""
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        speech_samples = ds.sort("id").select([0])["audio"]
        input_speech = [x["array"] for x in speech_samples][0]
        sampling_rate = [x["sampling_rate"] for x in speech_samples][0]

        feature_extractor_outputs = self.get_feature_extractor()(
            audio=input_speech, sampling_rate=sampling_rate, return_tensors="pt"
        )
        model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
        token_ids = model.generate(input_features=feature_extractor_outputs["input_features"], composer="composer1")
        dummy_notes = [
            [
                pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
                pretty_midi.Note(start=0.673379, end=0.905578, pitch=73, velocity=77),
                pretty_midi.Note(start=0.905578, end=2.159456, pitch=73, velocity=77),
                pretty_midi.Note(start=1.114558, end=2.159456, pitch=78, velocity=77),
                pretty_midi.Note(start=1.323537, end=1.532517, pitch=80, velocity=77),
            ],
            [
                pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
            ],
        ]

        return input_speech, sampling_rate, token_ids, dummy_notes

    def test_feature_extractor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = Pop2PianoProcessor(
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
        )

        input_speech, sampling_rate, _, _ = self.get_inputs()

        feature_extractor_outputs = feature_extractor(
            audio=input_speech, sampling_rate=sampling_rate, return_tensors="np"
        )
        processor_outputs = processor(audio=input_speech, sampling_rate=sampling_rate, return_tensors="np")

        for key in feature_extractor_outputs:
            self.assertTrue(np.allclose(feature_extractor_outputs[key], processor_outputs[key], atol=1e-4))

    def test_processor_batch_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = Pop2PianoProcessor(
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
        )

        audio, sampling_rate, token_ids, _ = self.get_inputs()
        feature_extractor_output = feature_extractor(audio=audio, sampling_rate=sampling_rate, return_tensors="pt")

        encoded_processor = processor.batch_decode(
            token_ids=token_ids,
            feature_extractor_output=feature_extractor_output,
            return_midi=True,
        )

        encoded_tokenizer = tokenizer.batch_decode(
            token_ids=token_ids,
            feature_extractor_output=feature_extractor_output,
            return_midi=True,
        )
        # check start timings
        encoded_processor_start_timings = [token.start for token in encoded_processor["notes"]]
        encoded_tokenizer_start_timings = [token.start for token in encoded_tokenizer["notes"]]
        self.assertListEqual(encoded_processor_start_timings, encoded_tokenizer_start_timings)

        # check end timings
        encoded_processor_end_timings = [token.end for token in encoded_processor["notes"]]
        encoded_tokenizer_end_timings = [token.end for token in encoded_tokenizer["notes"]]
        self.assertListEqual(encoded_processor_end_timings, encoded_tokenizer_end_timings)

        # check pitch
        encoded_processor_pitch = [token.pitch for token in encoded_processor["notes"]]
        encoded_tokenizer_pitch = [token.pitch for token in encoded_tokenizer["notes"]]
        self.assertListEqual(encoded_processor_pitch, encoded_tokenizer_pitch)

        # check velocity
        encoded_processor_velocity = [token.velocity for token in encoded_processor["notes"]]
        encoded_tokenizer_velocity = [token.velocity for token in encoded_tokenizer["notes"]]
        self.assertListEqual(encoded_processor_velocity, encoded_tokenizer_velocity)

    def test_tokenizer_call(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = Pop2PianoProcessor(
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
        )

        _, _, _, notes = self.get_inputs()

        encoded_processor = processor(
            notes=notes,
        )

        self.assertTrue(isinstance(encoded_processor, BatchEncoding))

    def test_processor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = Pop2PianoProcessor(
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
        )

        audio, sampling_rate, _, notes = self.get_inputs()

        inputs = processor(
            audio=audio,
            sampling_rate=sampling_rate,
            notes=notes,
        )

        self.assertListEqual(
            list(inputs.keys()),
            ["input_features", "beatsteps", "extrapolated_beatstep", "token_ids"],
        )

        # test that it raises when no input is passed
        with pytest.raises(ValueError):
            processor()
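
A minimal sketch of the combined processor path these tests cover (audio, sampling_rate, notes and token_ids are placeholders for the values produced by `get_inputs` above):

from transformers import Pop2PianoProcessor

processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
# combined forward path: feature extraction on `audio` plus tokenization of `notes`
inputs = processor(audio=audio, sampling_rate=sampling_rate, notes=notes)
# inputs keys: "input_features", "beatsteps", "extrapolated_beatstep", "token_ids"
# decoding mirrors the tokenizer: processor.batch_decode(token_ids=..., feature_extractor_output=..., return_midi=True)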
414
transformers/tests/models/pop2piano/test_tokenization_pop2piano.py
Normal file
@@ -0,0 +1,414 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Please note that Pop2PianoTokenizer is too far from our usual tokenizers and thus cannot use the TokenizerTesterMixin class.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.testing_utils import (
|
||||
is_pretty_midi_available,
|
||||
is_torch_available,
|
||||
require_pretty_midi,
|
||||
require_torch,
|
||||
)
|
||||
from transformers.tokenization_utils import BatchEncoding
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
requirements_available = is_torch_available() and is_pretty_midi_available()
|
||||
if requirements_available:
|
||||
import pretty_midi
|
||||
|
||||
from transformers import Pop2PianoTokenizer
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_pretty_midi
|
||||
class Pop2PianoTokenizerTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")
|
||||
|
||||
def get_input_notes(self):
|
||||
notes = [
|
||||
[
|
||||
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
|
||||
pretty_midi.Note(start=0.673379, end=0.905578, pitch=73, velocity=77),
|
||||
pretty_midi.Note(start=0.905578, end=2.159456, pitch=73, velocity=77),
|
||||
pretty_midi.Note(start=1.114558, end=2.159456, pitch=78, velocity=77),
|
||||
pretty_midi.Note(start=1.323537, end=1.532517, pitch=80, velocity=77),
|
||||
],
|
||||
[
|
||||
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
|
||||
],
|
||||
]
|
||||
|
||||
return notes
|
||||
|
||||
def test_call(self):
|
||||
notes = self.get_input_notes()
|
||||
|
||||
output = self.tokenizer(
|
||||
notes,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=10,
|
||||
return_attention_mask=True,
|
||||
)
|
||||
|
||||
# check the output type
|
||||
self.assertTrue(isinstance(output, BatchEncoding))
|
||||
|
||||
# check the values
|
||||
expected_output_token_ids = torch.tensor(
|
||||
[[134, 133, 74, 135, 77, 132, 77, 133, 77, 82], [134, 133, 74, 136, 132, 74, 134, 134, 134, 134]]
|
||||
)
|
||||
expected_output_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])
|
||||
|
||||
torch.testing.assert_close(output["token_ids"], expected_output_token_ids, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(output["attention_mask"], expected_output_attention_mask, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_batch_decode(self):
|
||||
# test batch decode with model, feature-extractor outputs(beatsteps, extrapolated_beatstep)
|
||||
|
||||
# Please note that this test does not test the accuracy of the outputs, instead it is designed to make sure that
|
||||
# the tokenizer's batch_decode can deal with attention_mask in feature-extractor outputs. For the accuracy check
|
||||
# please see the `test_batch_decode_outputs` test.
|
||||
|
||||
model_output = torch.concatenate(
|
||||
[
|
||||
torch.randint(size=[120, 96], low=0, high=70, dtype=torch.long),
|
||||
torch.zeros(size=[1, 96], dtype=torch.long),
|
||||
torch.randint(size=[50, 96], low=0, high=40, dtype=torch.long),
|
||||
torch.zeros(size=[1, 96], dtype=torch.long),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
input_features = BatchFeature(
|
||||
{
|
||||
"beatsteps": torch.ones([2, 955]),
|
||||
"extrapolated_beatstep": torch.ones([2, 1000]),
|
||||
"attention_mask": torch.concatenate(
|
||||
[
|
||||
torch.ones([120, 96], dtype=torch.long),
|
||||
torch.zeros([1, 96], dtype=torch.long),
|
||||
torch.ones([50, 96], dtype=torch.long),
|
||||
torch.zeros([1, 96], dtype=torch.long),
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
"attention_mask_beatsteps": torch.ones([2, 955]),
|
||||
"attention_mask_extrapolated_beatstep": torch.ones([2, 1000]),
|
||||
}
|
||||
)
|
||||
|
||||
output = self.tokenizer.batch_decode(token_ids=model_output, feature_extractor_output=input_features)[
|
||||
"pretty_midi_objects"
|
||||
]
|
||||
|
||||
# check length
|
||||
self.assertTrue(len(output) == 2)
|
||||
|
||||
# check object type
|
||||
self.assertTrue(isinstance(output[0], pretty_midi.pretty_midi.PrettyMIDI))
|
||||
self.assertTrue(isinstance(output[1], pretty_midi.pretty_midi.PrettyMIDI))
|
||||
|
||||
    def test_batch_decode_outputs(self):
        # test batch_decode with model and feature-extractor outputs (beatsteps, extrapolated_beatstep)

        # Please note that this test checks the accuracy of the outputs of the tokenizer's `batch_decode` method.

        model_output = torch.tensor(
            [
                [134, 133, 74, 135, 77, 82, 84, 136, 132, 74, 77, 82, 84],
                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            ]
        )
        input_features = BatchEncoding(
            {
                "beatsteps": torch.tensor([[0.0697, 0.1103, 0.1509, 0.1916]]),
                "extrapolated_beatstep": torch.tensor([[0.0000, 0.0406, 0.0813, 0.1219]]),
            }
        )

        output = self.tokenizer.batch_decode(token_ids=model_output, feature_extractor_output=input_features)

        # check outputs
        self.assertEqual(len(output["notes"]), 4)

        predicted_start_timings, predicted_end_timings = [], []
        for i in output["notes"]:
            predicted_start_timings.append(i.start)
            predicted_end_timings.append(i.end)

        # Checking note start timings
        expected_start_timings = torch.tensor(
            [
                0.069700,
                0.110300,
                0.110300,
                0.110300,
            ]
        )
        predicted_start_timings = torch.tensor(predicted_start_timings)

        torch.testing.assert_close(expected_start_timings, predicted_start_timings, rtol=1e-4, atol=1e-4)

        # Checking note end timings
        expected_end_timings = torch.tensor(
            [
                0.191600,
                0.191600,
                0.191600,
                0.191600,
            ]
        )
        predicted_end_timings = torch.tensor(predicted_end_timings)

        torch.testing.assert_close(expected_end_timings, predicted_end_timings, rtol=1e-4, atol=1e-4)

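    # Why these expected values: decoded time tokens index into `beatsteps`, so note starts land on
    # beatstep values (0.0697 s, then 0.1103 s) and all four notes end at the last beatstep, 0.1916 s.
    # A rough sanity check one could add (sketch only, not part of the test):
    #
    #     assert torch.isin(predicted_start_timings, input_features["beatsteps"]).all()
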
    def test_get_vocab(self):
        vocab_dict = self.tokenizer.get_vocab()
        self.assertIsInstance(vocab_dict, dict)
        self.assertGreaterEqual(len(self.tokenizer), len(vocab_dict))

        vocab = [self.tokenizer.convert_ids_to_tokens(i) for i in range(len(self.tokenizer))]
        self.assertEqual(len(vocab), len(self.tokenizer))

        self.tokenizer.add_tokens(["asdfasdfasdfasdf"])
        vocab = [self.tokenizer.convert_ids_to_tokens(i) for i in range(len(self.tokenizer))]
        self.assertEqual(len(vocab), len(self.tokenizer))

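    # get_vocab is expected to stay consistent with the id/token conversion helpers used above;
    # a round-trip sketch one could add here (illustration only):
    #
    #     for token, index in vocab_dict.items():
    #         assert self.tokenizer.convert_tokens_to_ids(token) == index
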
    def test_save_and_load_tokenizer(self):
        tmpdirname = tempfile.mkdtemp()

        sample_notes = self.get_input_notes()

        self.tokenizer.add_tokens(["bim", "bambam"])
        additional_special_tokens = self.tokenizer.additional_special_tokens
        additional_special_tokens.append("new_additional_special_token")
        self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
        before_token_ids = self.tokenizer(sample_notes)["token_ids"]
        before_vocab = self.tokenizer.get_vocab()
        self.tokenizer.save_pretrained(tmpdirname)

        after_tokenizer = self.tokenizer.__class__.from_pretrained(tmpdirname)
        after_token_ids = after_tokenizer(sample_notes)["token_ids"]
        after_vocab = after_tokenizer.get_vocab()
        self.assertDictEqual(before_vocab, after_vocab)
        self.assertListEqual(before_token_ids, after_token_ids)
        self.assertIn("bim", after_vocab)
        self.assertIn("bambam", after_vocab)
        self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)

        shutil.rmtree(tmpdirname)

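    # The same round-trip can be written with a context manager so the directory is cleaned up even
    # when an assertion fails (a minimal alternative sketch):
    #
    #     with tempfile.TemporaryDirectory() as tmpdirname:
    #         self.tokenizer.save_pretrained(tmpdirname)
    #         reloaded = self.tokenizer.__class__.from_pretrained(tmpdirname)
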
    def test_pickle_tokenizer(self):
        tmpdirname = tempfile.mkdtemp()

        notes = self.get_input_notes()
        subwords = self.tokenizer(notes)["token_ids"]

        filename = os.path.join(tmpdirname, "tokenizer.bin")
        with open(filename, "wb") as handle:
            pickle.dump(self.tokenizer, handle)

        with open(filename, "rb") as handle:
            tokenizer_new = pickle.load(handle)

        subwords_loaded = tokenizer_new(notes)["token_ids"]

        self.assertListEqual(subwords, subwords_loaded)

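    # Pickling matters mainly for multiprocessing (e.g. DataLoader workers); the file round-trip
    # above is equivalent to this in-memory one:
    #
    #     tokenizer_new = pickle.loads(pickle.dumps(self.tokenizer))
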
    def test_padding_side_in_kwargs(self):
        tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", padding_side="left")
        self.assertEqual(tokenizer_p.padding_side, "left")

        tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", padding_side="right")
        self.assertEqual(tokenizer_p.padding_side, "right")

        self.assertRaises(
            ValueError,
            Pop2PianoTokenizer.from_pretrained,
            "sweetcocoa/pop2piano",
            padding_side="unauthorized",
        )

    def test_truncation_side_in_kwargs(self):
        tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", truncation_side="left")
        self.assertEqual(tokenizer_p.truncation_side, "left")

        tokenizer_p = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano", truncation_side="right")
        self.assertEqual(tokenizer_p.truncation_side, "right")

        self.assertRaises(
            ValueError,
            Pop2PianoTokenizer.from_pretrained,
            "sweetcocoa/pop2piano",
            truncation_side="unauthorized",
        )

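    # Besides the from_pretrained kwargs checked above, both attributes can be flipped on an existing
    # tokenizer, which is what the padding/truncation tests below rely on:
    #
    #     tokenizer_p.padding_side = "left"
    #     tokenizer_p.truncation_side = "right"
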
    def test_right_and_left_padding(self):
        tokenizer = self.tokenizer
        notes = self.get_input_notes()
        notes = notes[0]
        max_length = 20

        padding_idx = tokenizer.pad_token_id

        # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag
        # set to True
        tokenizer.padding_side = "right"
        padded_notes = tokenizer(notes, padding="max_length", max_length=max_length)["token_ids"]
        padded_notes_length = len(padded_notes)
        notes_without_padding = tokenizer(notes, padding="do_not_pad")["token_ids"]
        padding_size = max_length - len(notes_without_padding)

        self.assertEqual(padded_notes_length, max_length)
        self.assertEqual(notes_without_padding + [padding_idx] * padding_size, padded_notes)

        # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag
        # set to True
        tokenizer.padding_side = "left"
        padded_notes = tokenizer(notes, padding="max_length", max_length=max_length)["token_ids"]
        padded_notes_length = len(padded_notes)
        notes_without_padding = tokenizer(notes, padding="do_not_pad")["token_ids"]
        padding_size = max_length - len(notes_without_padding)

        self.assertEqual(padded_notes_length, max_length)
        self.assertEqual([padding_idx] * padding_size + notes_without_padding, padded_notes)

        # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and for no padding (padding=False)
        notes_without_padding = tokenizer(notes)["token_ids"]

        tokenizer.padding_side = "right"
        padded_notes_right = tokenizer(notes, padding=False)["token_ids"]
        self.assertEqual(len(padded_notes_right), len(notes_without_padding))
        self.assertEqual(padded_notes_right, notes_without_padding)

        tokenizer.padding_side = "left"
        padded_notes_left = tokenizer(notes, padding="longest")["token_ids"]
        self.assertEqual(len(padded_notes_left), len(notes_without_padding))
        self.assertEqual(padded_notes_left, notes_without_padding)

        tokenizer.padding_side = "right"
        padded_notes_right = tokenizer(notes, padding="longest")["token_ids"]
        self.assertEqual(len(padded_notes_right), len(notes_without_padding))
        self.assertEqual(padded_notes_right, notes_without_padding)

        tokenizer.padding_side = "left"
        padded_notes_left = tokenizer(notes, padding=False)["token_ids"]
        self.assertEqual(len(padded_notes_left), len(notes_without_padding))
        self.assertEqual(padded_notes_left, notes_without_padding)

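    # Concretely, assuming a pad id of 0 for illustration, a 3-token sequence padded to max_length 5
    # becomes:
    #
    #     right padding: [t1, t2, t3] -> [t1, t2, t3, 0, 0]
    #     left padding:  [t1, t2, t3] -> [0, 0, t1, t2, t3]
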
    def test_right_and_left_truncation(self):
        tokenizer = self.tokenizer
        notes = self.get_input_notes()
        notes = notes[0]
        truncation_size = 3

        # RIGHT TRUNCATION - Check that it correctly truncates when a maximum length is specified along with the
        # truncation flag set to True
        tokenizer.truncation_side = "right"
        full_encoded_notes = tokenizer(notes)["token_ids"]
        full_encoded_notes_length = len(full_encoded_notes)
        truncated_notes = tokenizer(notes, max_length=full_encoded_notes_length - truncation_size, truncation=True)[
            "token_ids"
        ]
        self.assertEqual(full_encoded_notes_length, len(truncated_notes) + truncation_size)
        self.assertEqual(full_encoded_notes[:-truncation_size], truncated_notes)

        # LEFT TRUNCATION - Check that it correctly truncates when a maximum length is specified along with the
        # truncation flag set to True
        tokenizer.truncation_side = "left"
        full_encoded_notes = tokenizer(notes)["token_ids"]
        full_encoded_notes_length = len(full_encoded_notes)
        truncated_notes = tokenizer(notes, max_length=full_encoded_notes_length - truncation_size, truncation=True)[
            "token_ids"
        ]
        self.assertEqual(full_encoded_notes_length, len(truncated_notes) + truncation_size)
        self.assertEqual(full_encoded_notes[truncation_size:], truncated_notes)

        # RIGHT & LEFT TRUNCATION - Check that nothing is done for 'longest_first' and for truncation=True without a
        # max_length
        tokenizer.truncation_side = "right"
        truncated_notes_right = tokenizer(notes, truncation=True)["token_ids"]
        self.assertEqual(full_encoded_notes_length, len(truncated_notes_right))
        self.assertEqual(full_encoded_notes, truncated_notes_right)

        tokenizer.truncation_side = "left"
        truncated_notes_left = tokenizer(notes, truncation="longest_first")["token_ids"]
        self.assertEqual(len(truncated_notes_left), full_encoded_notes_length)
        self.assertEqual(truncated_notes_left, full_encoded_notes)

        tokenizer.truncation_side = "right"
        truncated_notes_right = tokenizer(notes, truncation="longest_first")["token_ids"]
        self.assertEqual(len(truncated_notes_right), full_encoded_notes_length)
        self.assertEqual(truncated_notes_right, full_encoded_notes)

        tokenizer.truncation_side = "left"
        truncated_notes_left = tokenizer(notes, truncation=True)["token_ids"]
        self.assertEqual(len(truncated_notes_left), full_encoded_notes_length)
        self.assertEqual(truncated_notes_left, full_encoded_notes)

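    # Concretely, truncating a 5-token sequence by 3 (hypothetical ids, illustration only):
    #
    #     right truncation: [7, 8, 9, 10, 11] -> [7, 8]
    #     left truncation:  [7, 8, 9, 10, 11] -> [10, 11]
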
    def test_padding_to_multiple_of(self):
        notes = self.get_input_notes()

        if self.tokenizer.pad_token is None:
            self.skipTest(reason="No padding token.")
        else:
            normal_tokens = self.tokenizer(notes[0], padding=True, pad_to_multiple_of=8)
            for key, value in normal_tokens.items():
                self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not a multiple of 8")

            normal_tokens = self.tokenizer(notes[0], pad_to_multiple_of=8)
            for key, value in normal_tokens.items():
                self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is unexpectedly a multiple of 8")

            # Should also work with truncation
            normal_tokens = self.tokenizer(notes[0], padding=True, truncation=True, pad_to_multiple_of=8)
            for key, value in normal_tokens.items():
                self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not a multiple of 8")

            # truncation to something which is not a multiple of pad_to_multiple_of raises an error
            self.assertRaises(
                ValueError,
                self.tokenizer.__call__,
                notes[0],
                padding=True,
                truncation=True,
                max_length=12,
                pad_to_multiple_of=8,
            )

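    # pad_to_multiple_of rounds the padded length up to the next multiple (useful e.g. for
    # tensor-core friendly shapes); in Python terms:
    #
    #     padded_len = -(-seq_len // 8) * 8  # ceil(seq_len / 8) * 8, so 13 -> 16
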
    def test_padding_with_attention_mask(self):
        if self.tokenizer.pad_token is None:
            self.skipTest(reason="No padding token.")
        if "attention_mask" not in self.tokenizer.model_input_names:
            self.skipTest(reason="This model does not use attention mask.")

        features = [
            {"token_ids": [1, 2, 3, 4, 5, 6], "attention_mask": [1, 1, 1, 1, 1, 0]},
            {"token_ids": [1, 2, 3], "attention_mask": [1, 1, 0]},
        ]
        padded_features = self.tokenizer.pad(features)
        if self.tokenizer.padding_side == "right":
            self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [1, 1, 0, 0, 0, 0]])
        else:
            self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 0]])
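
    # The second feature above has length 3, so padding to the batch max of 6 adds three mask zeros:
    # appended for right padding ([1, 1, 0] -> [1, 1, 0, 0, 0, 0]) and prepended for left padding
    # ([1, 1, 0] -> [0, 0, 0, 1, 1, 0]).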