Files
2026-02-04 17:39:32 +08:00

66 lines
2.5 KiB
Python

import torch
import torch_mlu
import time
from pathlib import Path
import csv
import os
import subprocess
from itertools import product
def benchmark_forward(fn, *inputs, repeats=1, **kwinputs):
    """Benchmark ``fn(*inputs, **kwinputs)`` over ``repeats`` calls on an MLU.

    Returns a ``(average_hardware_time, average_e2e_time)`` pair:
    the per-call device time reported by the MLU event pair, and the
    per-call wall-clock time in microseconds.  The hardware figure is
    whatever unit ``Event.hardware_time`` reports — presumably
    microseconds as well; confirm against the torch_mlu docs.
    """
    start_evt = torch.mlu.Event(enable_timing=True)
    end_evt = torch.mlu.Event(enable_timing=True)

    start_evt.record()
    wall_start = time.perf_counter()
    for _ in range(repeats):
        fn(*inputs, **kwinputs)
    end_evt.record()
    # Block until all recorded device work has drained before reading
    # either clock, otherwise the wall-clock figure undercounts.
    end_evt.synchronize()
    wall_elapsed = time.perf_counter() - wall_start

    avg_e2e_us = wall_elapsed / repeats * 1e6
    avg_hw = start_evt.hardware_time(end_evt) / repeats
    return avg_hw, avg_e2e_us
def save_to_csv(table, file_path, file_name):
file_name_without_ext, _ = os.path.splitext(file_name)
new_file_name = file_name_without_ext + '.csv'
if file_path is None:
file_path = './'
path = Path(file_path)
if path.suffix:
directory = path.parent
filename = path.name
else:
directory = path
filename = new_file_name
if not directory.exists():
directory.mkdir(parents=True, exist_ok=True)
full_path = directory / filename
if not full_path.exists():
full_path.touch()
with open(full_path, mode="w", newline="") as file:
writer = csv.writer(file)
writer.writerows(table)
print(f"output saved at: {full_path}")
def get_band_width(card_id: int = 0) -> int:
    """Query the memory bandwidth of MLU card *card_id* via ``cnmon``.

    Parses the 'MEM BandWidth' field out of ``cnmon info`` with a shell
    pipeline and returns it as an int.

    Raises:
        RuntimeError: if the cnmon pipeline exits non-zero.  (This was
            previously an ``assert``, which is stripped under
            ``python -O`` and would have let a failure pass silently.)
    """
    cmd = "cnmon info -c " + str(card_id) + " | grep 'MEM BandWidth'| cut -d ':' -f2 | cut -d ' ' -f 2"
    # shell=True is required for the grep/cut pipeline; card_id is an
    # int so no untrusted text reaches the shell string.
    res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if res.returncode != 0:
        raise RuntimeError("Failed to get BandWidth.")
    return int(res.stdout.decode().strip())
def generate_token_count(num_expert,
                         total_token_count):
    """Randomly split ``total_token_count`` tokens across ``num_expert`` experts.

    Draws a random per-expert count, rescales so the total is (close to)
    ``total_token_count``, then repairs integer-truncation error by
    clamping the cumulative sums and pinning the final entry.

    Uses torch's global RNG (not seeded here), so results vary per call.

    Returns:
        (cumulative, per_expert): ``cumulative`` is an int32 tensor of
        ``num_expert + 1`` entries with ``cumulative[0] == 0`` and
        ``cumulative[-1] == total_token_count``; ``per_expert`` is its
        first difference and sums exactly to ``total_token_count``.
    """
    token_count = torch.randint(low=1, high=1024, size=(num_expert, ), dtype=torch.int32).to(dtype=torch.float32)
    # Use a named local rather than shadowing the builtin `sum`.
    total = torch.sum(token_count, dim=-1)
    token_count *= total_token_count / total.item()
    token_count = token_count.to(dtype=torch.int32)
    cusum_token_count = torch.zeros(num_expert + 1, dtype=torch.int32)
    cusum_token_count[1:] = torch.cumsum(token_count, dim=-1)
    # Truncation above can push prefix sums past the target total:
    # clamp every prefix, then force the last entry to land exactly on
    # total_token_count.
    cap = torch.ones(num_expert + 1, dtype=torch.int32) * total_token_count
    cusum_token_count = torch.minimum(cusum_token_count, cap)
    cusum_token_count[-1] = total_token_count
    return cusum_token_count, cusum_token_count[1:] - cusum_token_count[:-1]