import gzip
import logging
import os
import shutil
import tarfile
import zipfile
from pathlib import Path
from typing import (
    Any,
    List,
    Optional,
    Tuple,
)

import yaml

from utils.model import AudioItem


# Default directory that dataset archives are extracted into.
DEFAULT_EXTRACT_DIR = "/tmp/datas"


def read_data(
    dataset_filepath: str,
    data_extract_path: str = DEFAULT_EXTRACT_DIR,
) -> Tuple[str, List[AudioItem]]:
    """Load a dataset and return its language and audio items.

    Unless ``dataset_filepath`` is a directory, it is treated as a zip archive
    (whatever its extension) and extracted into ``data_extract_path``; the
    directory containing the extracted ``data.yaml`` becomes the dataset root.
    If ``dataset_filepath`` is a directory, or extraction fails (the failure is
    logged), ``dataset_filepath`` itself is used as the dataset root.

    Args:
        dataset_filepath: Path to a zip archive, or to an already-extracted
            dataset directory that contains ``data.yaml``.
        data_extract_path: Directory archives are extracted into. Its existing
            contents are not cleared before extraction.

    Returns:
        ``(language, audios)``: ``language`` is ``global.lang`` from
        ``data.yaml`` (``"zh"`` when absent) and ``audios`` holds one
        ``AudioItem`` per entry of ``query_data``, with ``absolute_path`` set
        to ``<dataset root>/<item.file>``.

    Raises:
        OSError: If ``data.yaml`` cannot be opened in the dataset root.
        yaml.YAMLError: If ``data.yaml`` is not valid YAML.
        Whatever ``AudioItem.model_validate`` raises for an invalid entry.
    """
    # A directory is used as-is; anything else is assumed to be a zip archive.
    if not os.path.isdir(dataset_filepath):
        try:
            data_yaml_path = extract_file(dataset_filepath, data_extract_path)
            if not data_yaml_path:
                raise ValueError("未找到数据集data.yaml文件。")
            dataset_filepath = str(Path(data_yaml_path).parent.resolve())
        except Exception as e:
            # Best effort: log and fall back to treating the input path as the
            # dataset directory; the open() below reports the real problem.
            logging.exception(e)

    with open(f"{dataset_filepath}/data.yaml", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        datas = yaml.safe_load(f) or {}

    language = (datas.get("global") or {}).get("lang", "zh")
    query_data = datas.get("query_data") or []

    # Example query_data entry:
    #   - audio_length: 1.0099999999997635
    #     duration: 1.0099999999997635
    #     file: zh/0.wav                # relative to the dataset root
    #     orig_file: ./112801_1/112801_1-631-772.wav
    #     voice:
    #       - answer: 好吧。
    #         end: 1.0099999999997635
    #         start: 0
    audios = []
    for item in query_data:
        audio = AudioItem.model_validate(item)
        audio.absolute_path = f"{dataset_filepath}/{audio.file}"
        audios.append(audio)

    return language, audios


def extract_file(filepath: str, output_dir: str = ".") -> Optional[str]:
    """Extract a zip dataset archive into ``output_dir``.

    The input is always opened as a zip file regardless of its extension,
    because the existing datasets (e.g. ``leaderboard_data_samples``) are zip
    archives that are not named ``*.zip``. The directory prefix shared by all
    files in the archive is stripped, so a dataset zipped together with its
    top-level folder lands directly in ``output_dir``.

    Args:
        filepath: Path to the zip archive.
        output_dir: Destination directory; created if missing. Files already
            present at the same paths are overwritten.

    Returns:
        Absolute path of the extracted ``data.yaml`` (the last one if the
        archive contains several), or ``None`` if the archive has none.

    Raises:
        ValueError: If the archive contains no files, or an entry would be
            written outside ``output_dir``.
        zipfile.BadZipFile: If ``filepath`` is not a zip archive.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_root = Path(output_dir).resolve()
    data_yaml_path = None

    # TODO: only zip is supported; dispatch on the archive format here if
    # other formats are ever needed.
    with zipfile.ZipFile(filepath, "r") as zf:
        # Regular files only; directory entries end with "/".
        all_files = [f for f in zf.namelist() if not f.endswith("/")]
        if not all_files:
            raise ValueError(f"数据集文件为空。{filepath}")

        # Common *directory* prefix. Excluding the file name keeps a
        # single-file archive from stripping its own name, which would make
        # the destination collapse onto output_dir itself.
        common_dir = os.path.commonprefix(
            [Path(f).parts[:-1] for f in all_files]
        )
        strip_prefix_len = len(common_dir)

        for file in all_files:
            relative_parts = Path(file).parts[strip_prefix_len:]
            dest_path = output_root.joinpath(*relative_parts).resolve()
            # Zip-slip guard: entries containing ".." or absolute paths must
            # not be written outside output_dir.
            if output_root not in dest_path.parents:
                raise ValueError(f"数据集文件包含非法路径。{file}")

            dest_path.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(file) as source, open(dest_path, "wb") as target:
                shutil.copyfileobj(source, target)

            if Path(file).name == "data.yaml":
                data_yaml_path = str(dest_path)

    logging.info("数据集解压成功。")
    return data_yaml_path