Reorganize C++ source files in sgl-kernel into multiple folders (#4025)

This commit is contained in:
Lianmin Zheng
2025-03-03 05:32:30 -08:00
committed by GitHub
parent a7000a7650
commit 6b45a21d16
20 changed files with 203 additions and 210 deletions

View File

@@ -1,11 +1,10 @@
import ctypes
import logging
import os
import random
import socket
import time
import unittest
from typing import Any, List, Optional, Union
from typing import Any, List, Optional
import ray
import torch
@@ -115,7 +114,7 @@ class TestCustomAllReduce(unittest.TestCase):
self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
)
self.custom_ptr = custom_ops.init_custom_reduce(
@@ -148,7 +147,7 @@ class TestCustomAllReduce(unittest.TestCase):
self.vllm_max_size, group=group
)
self.vllm_rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
)
self.vllm_ptr = vllm_ops.init_custom_ar(
self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
@@ -171,8 +170,7 @@ class TestCustomAllReduce(unittest.TestCase):
@staticmethod
def init_distributed_env(world_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
device = torch.device("cuda:0")
torch.cuda.set_device(device)
ranks = [i for i in range(world_size)]
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
@@ -234,8 +232,8 @@ class TestCustomAllReduce(unittest.TestCase):
if rank == 0:
logger.warning(
f"test_size = {sz}, world_size = {world_size}, "
f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms,"
f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms"
f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
)
self.free_custom_allreduce(group)