Files
enginex-c_series-asr/utils/reader.py

139 lines
4.6 KiB
Python
Raw Permalink Normal View History

2025-08-28 18:46:56 +08:00
import gzip
import logging
import os
import shutil
import tarfile
import zipfile
from pathlib import Path
from typing import Any, List, Optional, Tuple

import yaml

from utils.model import AudioItem


def read_data(
    dataset_filepath: str,
    extract_dir: str = "/tmp/datas",
) -> Tuple[str, List[AudioItem]]:
    """Load a dataset and return its language and audio items.

    The dataset is treated as a zip archive and extracted into
    ``extract_dir``. The directory containing the extracted ``data.yaml``
    becomes the dataset root. If extraction fails for any reason, the error
    is logged and ``dataset_filepath`` itself is used as the dataset root
    directory.

    Args:
        dataset_filepath: Path to the dataset archive, or to a directory
            that already contains ``data.yaml``.
        extract_dir: Directory the archive is extracted into. Defaults to
            ``/tmp/datas``. Pass distinct directories when several datasets
            are loaded concurrently so they do not overwrite each other.

    Returns:
        A ``(language, audios)`` tuple:
        - language: ``global.lang`` from ``data.yaml``, defaulting to ``"zh"``.
        - audios: one ``AudioItem`` per entry of ``query_data``, each with
          ``absolute_path`` set to ``"<dataset root>/<item.file>"``.

    Raises:
        OSError: (e.g. FileNotFoundError) if ``data.yaml`` cannot be opened
            in the dataset root.
        Whatever ``AudioItem.model_validate`` raises for a malformed
        ``query_data`` entry.
    """
    try:
        # Every dataset is assumed to be an archive; extract it first.
        data_yaml_path = extract_file(dataset_filepath, extract_dir)
        if not data_yaml_path:
            raise ValueError("未找到数据集data.yaml文件。")
        dataset_filepath = str(Path(data_yaml_path).parent.resolve())
    except Exception as e:
        # Deliberate best effort: fall back to treating dataset_filepath as
        # an already-extracted directory.
        logging.exception(e)

    # data.yaml holds non-ASCII text (e.g. Chinese transcripts), so decode
    # explicitly rather than relying on the platform default encoding.
    with open(f"{dataset_filepath}/data.yaml", encoding="utf-8") as f:
        # An empty file loads as None; treat it as an empty mapping.
        datas = yaml.safe_load(f) or {}

    # `or {}` / `or []` also cover keys that are present but null.
    language = (datas.get("global") or {}).get("lang", "zh")
    query_data = datas.get("query_data") or []

    # Example query_data entry:
    #   - audio_length: 1.0099999999997635
    #     duration: 1.0099999999997635
    #     file: zh/0.wav
    #     orig_file: ./112801_1/112801_1-631-772.wav
    #     voice:
    #       - answer: 好吧
    #         end: 1.0099999999997635
    #         start: 0
    audios = []
    for item in query_data:
        audio = AudioItem.model_validate(item)
        audio.absolute_path = f"{dataset_filepath}/{audio.file}"
        audios.append(audio)
    return language, audios


def extract_file(filepath: str, output_dir: str = ".") -> Optional[str]:
    """Extract a dataset zip archive into ``output_dir``.

    The archive is always opened as a zip file, whatever its extension.
    Existing datasets (e.g. ``leaderboard_data_samples``) are not named
    ``*.zip``. The directory prefix shared by every member is stripped, so an
    archive whose files all live under ``ds/`` extracts directly into
    ``output_dir``.

    Args:
        filepath: Path to the zip archive.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        Absolute path of the extracted ``data.yaml``, or ``None`` if the
        archive contains none. If there are several, the last one wins.

    Raises:
        ValueError: if the archive contains no files, or if a member would be
            written outside ``output_dir`` (``..`` components or absolute
            member names). Nothing is written in that case.
        Whatever ``zipfile.ZipFile`` raises for unreadable or non-zip input
        (e.g. ``zipfile.BadZipFile``).
    """
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)
    root = root.resolve()

    # TODO: detect the format (e.g. zipfile.is_zipfile) if non-zip datasets
    # ever need support; for now everything is forced through zipfile.
    with zipfile.ZipFile(filepath, "r") as zf:
        # Regular files only; directory entries end with "/".
        all_files = [f for f in zf.namelist() if not f.endswith("/")]
        if not all_files:
            raise ValueError(f"数据集文件为空。{filepath}")

        # Strip the longest directory prefix shared by all members. Compare
        # parent directories only: including the file name would strip the
        # whole path of a single-member archive and target output_dir itself.
        common_parts = os.path.commonprefix(
            [Path(f).parts[:-1] for f in all_files]
        )
        strip_prefix_len = len(common_parts)

        # Resolve and validate every destination before writing anything, so
        # a hostile member cannot escape output_dir (zip slip) or leave a
        # partially extracted tree behind.
        targets = []
        for file in all_files:
            relative_parts = Path(file).parts[strip_prefix_len:]
            dest_path = root.joinpath(*relative_parts).resolve()
            if root not in dest_path.parents:
                raise ValueError(f"压缩包中存在越界路径,拒绝解压。{file}")
            targets.append((file, dest_path))

        data_yaml_path = None
        for file, dest_path in targets:
            dest_path.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(file) as source, open(dest_path, "wb") as target:
                shutil.copyfileobj(source, target)
            if Path(file).name == "data.yaml":
                data_yaml_path = str(dest_path)

    logging.info("数据集解压成功。")
    return data_yaml_path