Support twoshot kernel (#2688)
This commit is contained in:
@@ -55,13 +55,8 @@ class TestCustomAllReduce(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
random.seed(42)
|
||||
cls.test_sizes = {
|
||||
2: [512, 4096, 32768, 262144, 2097152],
|
||||
4: [512, 4096, 32768, 131072],
|
||||
6: [512, 4096, 32768, 65536],
|
||||
8: [512, 4096, 32768, 65536],
|
||||
}
|
||||
cls.world_sizes = [2, 4, 6, 8]
|
||||
cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
|
||||
cls.world_sizes = [2, 4, 8]
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
@@ -194,7 +189,7 @@ class TestCustomAllReduce(unittest.TestCase):
|
||||
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
|
||||
|
||||
test_loop = 10
|
||||
for sz in self.test_sizes[world_size]:
|
||||
for sz in self.test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
for _ in range(test_loop):
|
||||
inp1 = torch.randint(
|
||||
@@ -216,7 +211,7 @@ class TestCustomAllReduce(unittest.TestCase):
|
||||
self.init_vllm_allreduce(rank, group)
|
||||
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
|
||||
|
||||
for sz in self.test_sizes[world_size]:
|
||||
for sz in self.test_sizes:
|
||||
inp1 = torch.randint(
|
||||
1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user