Support twoshot kernel (#2688)

This commit is contained in:
yizhang2077
2025-01-06 00:47:16 +08:00
committed by GitHub
parent ded9fcd09a
commit 3900a94afe
5 changed files with 216 additions and 21 deletions

View File

@@ -55,13 +55,8 @@ class TestCustomAllReduce(unittest.TestCase):
@classmethod
def setUpClass(cls):
random.seed(42)
-cls.test_sizes = {
-2: [512, 4096, 32768, 262144, 2097152],
-4: [512, 4096, 32768, 131072],
-6: [512, 4096, 32768, 65536],
-8: [512, 4096, 32768, 65536],
-}
-cls.world_sizes = [2, 4, 6, 8]
+cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+cls.world_sizes = [2, 4, 8]
@staticmethod
def create_shared_buffer(
@@ -194,7 +189,7 @@ class TestCustomAllReduce(unittest.TestCase):
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
test_loop = 10
-for sz in self.test_sizes[world_size]:
+for sz in self.test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(test_loop):
inp1 = torch.randint(
@@ -216,7 +211,7 @@ class TestCustomAllReduce(unittest.TestCase):
self.init_vllm_allreduce(rank, group)
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
-for sz in self.test_sizes[world_size]:
+for sz in self.test_sizes:
inp1 = torch.randint(
1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
)