139 lines
4.6 KiB
Python
139 lines
4.6 KiB
Python
|
|
import gzip
import logging
import os
import shutil
import tarfile
import zipfile
from pathlib import Path
from typing import Any, List, Optional, Tuple

import yaml

from utils.model import AudioItem
|
|||
|
|
|
|||
|
|
|
|||
|
|
def read_data(dataset_filepath: str, extract_dir: str = "/tmp/datas") -> Tuple[str, List[AudioItem]]:
    """
    Load a dataset and return its language and audio items.

    The input is first treated as a zip archive (whatever its extension) and
    extracted into ``extract_dir``; the directory containing the extracted
    ``data.yaml`` then becomes the dataset root. If that step fails for any
    reason (including "no data.yaml in the archive"), the error is logged and
    ``dataset_filepath`` itself is used as the dataset root, e.g. when it is
    an already-extracted directory.

    Args:
        dataset_filepath: Path to the dataset archive, or to a directory that
            already contains ``data.yaml``.
        extract_dir: Directory the archive is extracted into. Defaults to the
            previously hard-coded ``/tmp/datas``.

    Returns:
        A ``(language, audios)`` tuple:
        - language: ``global.lang`` from ``data.yaml``; ``"zh"`` when absent.
        - audios: one ``AudioItem`` per entry of ``query_data`` (empty when
          absent), each with ``absolute_path`` set to
          ``<dataset root>/<item.file>``.

    Raises:
        OSError: if ``data.yaml`` cannot be opened in the resolved dataset root.
        Errors from ``yaml.safe_load`` or ``AudioItem.model_validate`` on
        malformed content propagate unchanged.
    """
    try:
        # Treat every input as an archive and extract it first.
        data_yaml_path = extract_file(dataset_filepath, extract_dir)
        if not data_yaml_path:
            raise ValueError(f"未找到数据集data.yaml文件。")
        dataset_filepath = str(Path(data_yaml_path).parent.resolve())
    except Exception as e:
        # Deliberate best-effort: fall back to reading dataset_filepath as a
        # directory that already contains data.yaml.
        logging.exception(e)

    # Decode explicitly as UTF-8: the file carries non-ASCII text (e.g. Chinese
    # answers) and the platform default encoding is not guaranteed to be UTF-8.
    with open(f"{dataset_filepath}/data.yaml", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        datas = yaml.safe_load(f) or {}

    # `or` also covers keys that are present but explicitly null.
    language = (datas.get("global") or {}).get("lang", "zh")
    query_data = datas.get("query_data") or []

    # Example query_data entry:
    #   - audio_length: 1.0099999999997635
    #     duration: 1.0099999999997635
    #     file: zh/0.wav
    #     orig_file: ./112801_1/112801_1-631-772.wav
    #     voice:
    #       - answer: 好吧。
    #         end: 1.0099999999997635
    #         start: 0
    audios = []
    for item in query_data:
        audio = AudioItem.model_validate(item)
        # `file` is relative to the dataset root (the data.yaml directory).
        audio.absolute_path = f"{dataset_filepath}/{audio.file}"
        audios.append(audio)

    return language, audios
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_file(filepath: str, output_dir: str = ".") -> Optional[str]:
    """
    Extract a dataset zip archive into ``output_dir`` and locate ``data.yaml``.

    The archive is always opened as a zip file regardless of its extension.
    Leading directory components shared by every member are stripped, so an
    archive whose contents all sit under one top-level folder is extracted
    directly into ``output_dir``. Files already at a destination path are
    overwritten; other files already in ``output_dir`` are left untouched.

    Args:
        filepath: Path to the zip archive.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        The resolved path of the extracted ``data.yaml`` (the last one in
        archive order if there are several), or ``None`` if the archive
        contains no ``data.yaml``.

    Raises:
        ValueError: if the archive contains no files, or if any member would
            be written outside ``output_dir`` (``..`` segments or absolute
            member names). Nothing is written in the latter case.
        zipfile.BadZipFile: if ``filepath`` is not a zip archive.
    """
    os.makedirs(output_dir, exist_ok=True)
    out_root = Path(output_dir).resolve()

    # TODO: the existing datasets are not named with a .zip suffix, so every
    # input is force-extracted as zip; other archive formats are unsupported.
    with zipfile.ZipFile(filepath, "r") as zf:
        # Regular file members only; directory entries end with "/".
        all_files = [f for f in zf.namelist() if not f.endswith("/")]
        if not all_files:
            raise ValueError(f"数据集文件为空。{filepath}")

        # Shared leading directories. Only the parent parts are compared so the
        # file name itself is never stripped -- otherwise a single-file archive
        # would map its only member onto output_dir itself.
        common_parts = os.path.commonprefix(
            [Path(f).parts[:-1] for f in all_files]
        )
        strip_len = len(common_parts)

        # Compute and validate every destination before writing anything, so a
        # crafted archive ("zip slip") cannot leave partial output behind.
        targets = []
        for name in all_files:
            dest_path = Path(output_dir).joinpath(*Path(name).parts[strip_len:])
            if out_root not in dest_path.resolve().parents:
                raise ValueError(f"压缩包内文件路径非法,超出解压目录。{name}")
            targets.append((name, dest_path))

        data_yaml_path = None
        for name, dest_path in targets:
            dest_path.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(name) as source, open(dest_path, "wb") as target:
                shutil.copyfileobj(source, target)

            if Path(name).name == "data.yaml":
                data_yaml_path = str(dest_path.resolve())

        logging.info(f"数据集解压成功。")

    return data_yaml_path
|