139 lines
4.6 KiB
Python
139 lines
4.6 KiB
Python
import gzip
import logging
import os
import shutil
import tarfile
import zipfile
from pathlib import Path
from typing import Any, List, Optional, Tuple

import yaml

from utils.model import AudioItem
|
||
|
||
|
||
def read_data(
    dataset_filepath: str, extract_dir: str = "/tmp/datas"
) -> Tuple[str, List[AudioItem]]:
    """
    Read a dataset and return its language and its audio items.

    ``dataset_filepath`` is normally a ZIP archive (whatever its extension).
    It is extracted into ``extract_dir`` with ``extract_file``, and the
    directory holding the extracted ``data.yaml`` becomes the dataset root.
    When ``dataset_filepath`` is already a directory, it is used as the root
    directly. If extraction fails, or the archive has no ``data.yaml``, the
    problem is logged and ``dataset_filepath`` itself is tried as the root.

    Args:
        dataset_filepath: Path to the dataset archive or dataset directory.
        extract_dir: Directory the archive is extracted into.

    Returns:
        Tuple[str, List[AudioItem]]:
            - language: ``global.lang`` from data.yaml (``"zh"`` if missing).
            - audios: one ``AudioItem`` per ``query_data`` entry, with
              ``absolute_path`` set to ``"<dataset root>/<item.file>"``.

    Raises:
        OSError: if ``<dataset root>/data.yaml`` cannot be opened.
        ValueError: if the top level of data.yaml is not a mapping.
        Errors raised by ``yaml.safe_load`` or ``AudioItem.model_validate``
        on malformed content propagate unchanged.
    """
    dataset_root = dataset_filepath
    if not os.path.isdir(dataset_filepath):
        # Inputs are treated as ZIP archives regardless of their extension.
        try:
            data_yaml_path = extract_file(dataset_filepath, extract_dir)
        except Exception:
            # Best effort: fall back to reading dataset_filepath as a directory.
            logging.exception("解压数据集失败,尝试直接读取: %s", dataset_filepath)
        else:
            if data_yaml_path:
                dataset_root = str(Path(data_yaml_path).parent.resolve())
            else:
                logging.error("未找到数据集data.yaml文件。%s", dataset_filepath)

    # data.yaml contains non-ASCII text (e.g. Chinese answers), so do not
    # depend on the platform's locale encoding.
    with open(os.path.join(dataset_root, "data.yaml"), encoding="utf-8") as f:
        # An empty YAML document loads as None; treat it as an empty mapping.
        datas = yaml.safe_load(f) or {}
    if not isinstance(datas, dict):
        raise ValueError(f"data.yaml 顶层必须是映射: {dataset_root}")

    # `or` also covers keys that are present but null in the YAML.
    language = (datas.get("global") or {}).get("lang", "zh")
    query_data = datas.get("query_data") or []

    # Example query_data entry:
    #   - audio_length: 1.0099999999997635
    #     duration: 1.0099999999997635
    #     file: zh/0.wav
    #     orig_file: ./112801_1/112801_1-631-772.wav
    #     voice:
    #       - answer: 好吧。
    #         end: 1.0099999999997635
    #         start: 0
    audios = []
    for item in query_data:
        audio = AudioItem.model_validate(item)
        # `file` is relative to the dataset root.
        audio.absolute_path = f"{dataset_root}/{audio.file}"
        audios.append(audio)

    return language, audios
|
||
|
||
|
||
def extract_file(filepath: str, output_dir: str = ".") -> Optional[str]:
    """
    Extract a dataset ZIP archive into ``output_dir`` and locate its data.yaml.

    The archive is always opened as ZIP, whatever its file extension. Existing
    datasets (e.g. ``leaderboard_data_samples``) are not named ``*.zip``.
    The directory prefix shared by every file in the archive (for example a
    single wrapping top-level folder) is stripped, so the contents land
    directly under ``output_dir``.

    Args:
        filepath: Path to the ZIP archive.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        The resolved path of the extracted file named ``data.yaml`` (the last
        one in archive order if there are several), or ``None`` if the archive
        contains no such file.

    Raises:
        ValueError: if the archive contains no files, or if any entry would be
            written outside ``output_dir`` (e.g. ``../x`` or an absolute path).
            In the latter case nothing is written.
        zipfile.BadZipFile: if ``filepath`` is not a valid ZIP archive.
    """
    os.makedirs(output_dir, exist_ok=True)
    root = Path(output_dir).resolve()

    # TODO: tar/gzip archives are not handled. Every input is forced through
    # ZIP extraction because the existing datasets lack a .zip extension.
    with zipfile.ZipFile(filepath, "r") as zf:
        # Regular files only; directory entries end with '/'.
        all_files = [name for name in zf.namelist() if not name.endswith("/")]
        if not all_files:
            raise ValueError(f"数据集文件为空。{filepath}")

        # Strip the directory prefix shared by every file. Only directory
        # parts are compared: comparing full paths would, for a single-file
        # archive, strip the file name too and target output_dir itself.
        common_dirs = os.path.commonprefix(
            [Path(name).parts[:-1] for name in all_files]
        )
        strip_len = len(common_dirs)

        # Map every entry to its destination and reject any that would escape
        # output_dir ("zip slip") before anything is written.
        targets = []
        for name in all_files:
            dest = root.joinpath(*Path(name).parts[strip_len:]).resolve()
            if root not in dest.parents:
                raise ValueError(f"数据集文件路径非法: {name}")
            targets.append((name, dest))

        data_yaml_path = None
        for name, dest in targets:
            dest.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(name) as source, open(dest, "wb") as target:
                shutil.copyfileobj(source, target)
            if Path(name).name == "data.yaml":
                data_yaml_path = str(dest)

    logging.info("数据集解压成功。")
    return data_yaml_path
|