From defede5073fe39e7a79187a1f85e559a99201ef0 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 14 Apr 2025 16:07:58 +0800 Subject: [PATCH] Fix DeepSeek DP Attention + torch compile (#5367) Co-authored-by: ispobock --- python/sglang/srt/layers/dp_attention.py | 6 ++---- test/srt/parse_results.py | 5 +++-- test/srt/test_dp_attention.py | 3 +++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index bf6064119..c1b9e05ec 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -192,8 +192,7 @@ def _dp_gather( if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0): assert ( - global_tokens.untyped_storage().data_ptr() - != local_tokens.untyped_storage().data_ptr() + local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between global_tokens and local_tokens not allowed" memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False @@ -243,8 +242,7 @@ def dp_scatter( assert global_tokens.is_contiguous() if local_tokens.shape[0] > 0: assert ( - local_tokens.untyped_storage().data_ptr() - != global_tokens.untyped_storage().data_ptr() + local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between local_tokens and global_tokens not allowed" memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True diff --git a/test/srt/parse_results.py b/test/srt/parse_results.py index 8389a4b9c..de1d5cf27 100644 --- a/test/srt/parse_results.py +++ b/test/srt/parse_results.py @@ -1,7 +1,8 @@ -import json -import pandas as pd import argparse +import json import os + +import pandas as pd from tabulate import tabulate # Parse command-line arguments diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py index 5fb5e223b..b47fe2c46 100644 --- a/test/srt/test_dp_attention.py +++ b/test/srt/test_dp_attention.py @@ -28,6 +28,9 @@ class TestDPAttentionDP2TP2(CustomTestCase): "--enable-dp-attention", "--dp", "2", + "--enable-torch-compile", + "--torch-compile-max-bs", + "2", ], )