Tiny fix: ep_gather behaves differently in CI (#11130)
This commit is contained in:
@@ -1104,10 +1104,10 @@ def ep_gather(
|
|||||||
input_index: torch.Tensor,
|
input_index: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
):
|
):
|
||||||
BLOCK_D = 1024 if not is_in_ci() else 128 # block size of quantization
|
|
||||||
num_warps = 2
|
num_warps = 2
|
||||||
num_tokens = output_tensor.shape[0]
|
num_tokens = output_tensor.shape[0]
|
||||||
hidden_size = input_tensor.shape[1]
|
hidden_size = input_tensor.shape[1]
|
||||||
|
BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024 # block size of quantization
|
||||||
assert hidden_size % BLOCK_D == 0
|
assert hidden_size % BLOCK_D == 0
|
||||||
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
|
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
|
||||||
_fwd_kernel_ep_gather[grid](
|
_fwd_kernel_ep_gather[grid](
|
||||||
|
|||||||
Reference in New Issue
Block a user