Files
enginex-mr_series-asr/utils/reader.py
2025-08-20 14:29:42 +08:00

139 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gzip
import logging
import os
import shutil
import tarfile
import zipfile
from pathlib import Path
from typing import (
    Any,
    List,
    Optional,
    Tuple,
)

import yaml

from utils.model import AudioItem


def read_data(dataset_filepath: str) -> Tuple[str, List[AudioItem]]:
    """
    Read a dataset and return its language and its audio items.

    The dataset is first treated as a zip archive and extracted to
    ``/tmp/datas``. The directory containing the extracted ``data.yaml``
    then becomes the dataset root. If extraction fails for any reason, the
    error is logged and ``dataset_filepath`` itself is used as the dataset
    root. Such failures include the input not being a zip archive (e.g. it
    is already a directory) and the archive having no ``data.yaml``.

    Args:
        dataset_filepath: Path to a zip archive, or to a directory that
            contains ``data.yaml``.

    Returns:
        Tuple[str, List[AudioItem]]:
            - language: ``global.lang`` from data.yaml (default ``"zh"``).
            - audios: one AudioItem per ``query_data`` entry, with
              ``absolute_path`` set to ``<dataset root>/<file>``.

    Raises:
        OSError: if ``<dataset root>/data.yaml`` cannot be opened.
        yaml.YAMLError: if data.yaml is not valid YAML.
        Whatever ``AudioItem.model_validate`` raises for an invalid entry.
    """
    try:
        # Every dataset is assumed to be an archive; extract it first.
        data_extract_path = "/tmp/datas"
        data_yaml_path = extract_file(dataset_filepath, data_extract_path)
        if not data_yaml_path:
            raise ValueError("未找到数据集data.yaml文件。")
        dataset_filepath = str(Path(data_yaml_path).parent.resolve())
    except Exception as e:
        # Best-effort: fall back to treating dataset_filepath as a directory.
        logging.exception(e)

    # data.yaml contains non-ASCII text (e.g. Chinese transcripts), so do not
    # depend on the platform's default locale encoding.
    with open(f"{dataset_filepath}/data.yaml", encoding="utf-8") as f:
        # An empty YAML document loads as None; treat it as an empty mapping.
        datas = yaml.safe_load(f) or {}
    # An explicit null section behaves the same as a missing one.
    language = (datas.get("global") or {}).get("lang", "zh")
    query_data = datas.get("query_data") or []

    # Example query_data entry:
    # - audio_length: 1.0099999999997635
    #   duration: 1.0099999999997635
    #   file: zh/0.wav
    #   orig_file: ./112801_1/112801_1-631-772.wav
    #   voice:
    #   - answer: 好吧。
    #     end: 1.0099999999997635
    #     start: 0
    audios = []
    for item in query_data:
        audio = AudioItem.model_validate(item)
        # `file` is relative to the dataset root.
        audio.absolute_path = f"{dataset_filepath}/{audio.file}"
        audios.append(audio)
    return language, audios
def extract_file(filepath: str, output_dir: str = ".") -> Optional[str]:
    """
    Extract a dataset zip archive into ``output_dir`` and locate data.yaml.

    The archive is always opened as a zip, regardless of its file name. The
    existing datasets (e.g. ``leaderboard_data_samples``) are zip files that
    do not carry a ``.zip`` extension. The directory prefix shared by every
    file in the archive is stripped. For example, an archive whose files all
    live under ``root/`` extracts ``root/data.yaml`` to
    ``<output_dir>/data.yaml``. Files already present in ``output_dir`` at
    the same relative path are overwritten.

    Args:
        filepath: Path to the zip archive.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        The resolved path of the extracted ``data.yaml``, or ``None`` if the
        archive contains no file named ``data.yaml``. If there are several,
        the last one extracted wins.

    Raises:
        zipfile.BadZipFile: if ``filepath`` is not a zip archive.
        OSError: if ``filepath`` cannot be opened or files cannot be written.
        ValueError: if the archive contains no files, or an entry would be
            written outside ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)
    root = Path(output_dir).resolve()
    data_yaml_path = None
    # TODO: only zip is supported. Existing datasets are not named with a
    # ".zip" suffix, so the format is not inferred from the file name.
    with zipfile.ZipFile(filepath, "r") as zf:
        # Regular files only; directory entries end with '/'.
        all_files = [f for f in zf.namelist() if not f.endswith("/")]
        if not all_files:
            raise ValueError(f"数据集文件为空。{filepath}")

        # Strip the longest directory prefix common to all files. Only the
        # directory components are compared: including the file name would
        # strip the file itself from a single-file archive.
        dir_parts = [Path(f).parts[:-1] for f in all_files]
        strip_prefix_len = len(os.path.commonprefix(dir_parts))

        for file in all_files:
            relative_parts = Path(file).parts[strip_prefix_len:]
            dest_path = root.joinpath(*relative_parts).resolve()
            # Reject absolute names and ".." components that would escape
            # output_dir ("zip slip").
            if root not in dest_path.parents:
                raise ValueError(f"数据集包含非法路径。{file}")

            dest_path.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(file) as source, open(dest_path, "wb") as target:
                shutil.copyfileobj(source, target)

            if Path(file).name == "data.yaml":
                data_yaml_path = str(dest_path)
    logging.info("数据集解压成功。")
    return data_yaml_path