feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com> Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com> Co-authored-by: Qiaolin Yu <liin1211@outlook.com> Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -238,6 +238,10 @@ def _dp_gather(
|
||||
assert (
|
||||
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||
), "aliasing between global_tokens and local_tokens not allowed"
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
|
||||
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
|
||||
|
||||
memcpy_triton(
|
||||
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||
)
|
||||
@@ -288,6 +292,10 @@ def dp_scatter(
|
||||
assert (
|
||||
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||
), "aliasing between local_tokens and global_tokens not allowed"
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
|
||||
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
|
||||
|
||||
memcpy_triton(
|
||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user