Fix DeepSeek DP Attention + torch compile (#5367)
Co-authored-by: ispobock <ispobaoke@163.com>
This commit is contained in:
@@ -192,8 +192,7 @@ def _dp_gather(
|
|||||||
|
|
||||||
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
|
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
|
||||||
assert (
|
assert (
|
||||||
global_tokens.untyped_storage().data_ptr()
|
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||||
!= local_tokens.untyped_storage().data_ptr()
|
|
||||||
), "aliasing between global_tokens and local_tokens not allowed"
|
), "aliasing between global_tokens and local_tokens not allowed"
|
||||||
memcpy_triton(
|
memcpy_triton(
|
||||||
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||||
@@ -243,8 +242,7 @@ def dp_scatter(
|
|||||||
assert global_tokens.is_contiguous()
|
assert global_tokens.is_contiguous()
|
||||||
if local_tokens.shape[0] > 0:
|
if local_tokens.shape[0] > 0:
|
||||||
assert (
|
assert (
|
||||||
local_tokens.untyped_storage().data_ptr()
|
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||||
!= global_tokens.untyped_storage().data_ptr()
|
|
||||||
), "aliasing between local_tokens and global_tokens not allowed"
|
), "aliasing between local_tokens and global_tokens not allowed"
|
||||||
memcpy_triton(
|
memcpy_triton(
|
||||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import json
|
|
||||||
import pandas as pd
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
# Parse command-line arguments
|
# Parse command-line arguments
|
||||||
|
|||||||
@@ -28,6 +28,9 @@ class TestDPAttentionDP2TP2(CustomTestCase):
|
|||||||
"--enable-dp-attention",
|
"--enable-dp-attention",
|
||||||
"--dp",
|
"--dp",
|
||||||
"2",
|
"2",
|
||||||
|
"--enable-torch-compile",
|
||||||
|
"--torch-compile-max-bs",
|
||||||
|
"2",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user