import torch
import torch_mlu
import time
from pathlib import Path
import csv
import os
import subprocess
from itertools import product


def benchmark_forward(fn, *inputs, repeats=1, **kwinputs):
    """Benchmark ``fn(*inputs, **kwinputs)`` on the current MLU device.

    Args:
        fn: callable to time.
        *inputs: positional arguments forwarded to ``fn``.
        repeats: number of calls to average over.
        **kwinputs: keyword arguments forwarded to ``fn``.

    Returns:
        Tuple ``(average_hardware_time, average_e2e_time)``: per-call
        averages in microseconds.  Hardware time is measured with MLU
        events; end-to-end time with the host wall clock.
    """
    notify_start = torch.mlu.Event(enable_timing=True)
    notify_end = torch.mlu.Event(enable_timing=True)
    notify_start.record()
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*inputs, **kwinputs)
    notify_end.record()
    # Block until all queued device work has finished so the wall-clock
    # reading below covers the full execution, not just kernel launch.
    notify_end.synchronize()
    total_e2e_time = time.perf_counter() - t0
    average_e2e_time = total_e2e_time / repeats * 1e6  # seconds -> microseconds
    # hardware_time() reports the elapsed device time between the two
    # events in microseconds (torch_mlu event API).
    total_hardware_time = notify_start.hardware_time(notify_end)
    average_hardware_time = total_hardware_time / repeats
    return average_hardware_time, average_e2e_time


def save_to_csv(table, file_path, file_name):
    """Write ``table`` (an iterable of rows) to a CSV file.

    Args:
        table: rows accepted by ``csv.writer.writerows``.
        file_path: None (current directory), a directory path, or a full
            file path.  When it names a file (has a suffix), that exact
            path is used and ``file_name`` is ignored.
        file_name: base name whose extension is replaced with ``.csv``
            when ``file_path`` is a directory or None.
    """
    file_name_without_ext, _ = os.path.splitext(file_name)
    new_file_name = file_name_without_ext + '.csv'
    if file_path is None:
        file_path = './'
    path = Path(file_path)
    if path.suffix:
        # ``file_path`` already names a file; keep its name as-is.
        directory = path.parent
        filename = path.name
    else:
        directory = path
        filename = new_file_name
    # ``exist_ok=True`` makes the pre-existence check unnecessary, and
    # opening in "w" mode creates the file, so no touch() is needed.
    directory.mkdir(parents=True, exist_ok=True)
    full_path = directory / filename
    with open(full_path, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerows(table)
    print(f"output saved at: {full_path}")


def get_band_width(card_id: int = 0):
    """Return the memory bandwidth of MLU card ``card_id`` per ``cnmon``.

    Args:
        card_id: index of the MLU card to query.

    Returns:
        The integer bandwidth value from the 'MEM BandWidth' line of
        ``cnmon info -c <card_id>`` output.

    Raises:
        RuntimeError: if ``cnmon`` exits non-zero or its output contains
            no 'MEM BandWidth' line.
    """
    # Run cnmon directly (no shell) and parse in Python.  This replicates
    # the old pipeline  grep 'MEM BandWidth' | cut -d ':' -f2 | cut -d ' ' -f 2
    # while actually checking cnmon's own exit status (a shell pipeline's
    # return code is that of the final command, so failures went unnoticed).
    res = subprocess.run(
        ["cnmon", "info", "-c", str(card_id)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if res.returncode != 0:
        raise RuntimeError("Failed to get BandWidth.")
    for line in res.stdout.decode().splitlines():
        if 'MEM BandWidth' in line:
            # split(':')[1]  == cut -d ':' -f2  (text after the first colon)
            # split(' ')[1]  == cut -d ' ' -f 2 (second space-delimited
            # field; the field before the leading space is empty)
            return int(line.split(':')[1].split(' ')[1].strip())
    raise RuntimeError("Failed to get BandWidth.")


def generate_token_count(num_expert, total_token_count):
    """Randomly partition ``total_token_count`` tokens across experts.

    Draws a random weight per expert, rescales so the weights sum to
    roughly ``total_token_count``, then builds a clamped cumulative-sum
    boundary vector whose final entry is forced to the exact total.

    Args:
        num_expert: number of experts to split across.
        total_token_count: total number of tokens to distribute.

    Returns:
        Tuple ``(cusum_token_count, token_count)``:
            cusum_token_count: int32 tensor of shape (num_expert + 1,),
                non-decreasing boundaries with ``[0] == 0`` and
                ``[-1] == total_token_count``.
            token_count: int32 tensor of shape (num_expert,), per-expert
                counts (consecutive differences of the boundaries).
    """
    token_count = torch.randint(
        low=1, high=1024, size=(num_expert,), dtype=torch.int32
    ).to(dtype=torch.float32)
    # NOTE: do not name this ``sum`` — that shadows the builtin.
    total_weight = torch.sum(token_count, dim=-1) * 1.0
    token_count *= total_token_count / total_weight.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    end_expert_token_count = (
        torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    )
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Clamp so no boundary exceeds the total, then pin the final boundary
    # exactly (float rounding above can leave a small surplus/shortfall).
    cusum_token_count = torch.minimum(cusum_token_count, end_expert_token_count)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]