[Feature, Hardware] Enable DeepseekV3 on AMD GPUs (#2601)
Co-authored-by: root <root@banff-cyxtera-s83-5.amd.com> Co-authored-by: HAI <hixiao@gmail.com> Co-authored-by: Bruce Xue <yigex@xilinx.com> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
|
|||||||
Lk = k_buffer.shape[-1]
|
Lk = k_buffer.shape[-1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
|
|
||||||
|
# [TODO] work around shmem limit on MI3xx
|
||||||
|
if is_hip_ and Lk >= 576:
|
||||||
|
BLOCK = 16
|
||||||
|
|
||||||
if Lk == 576:
|
if Lk == 576:
|
||||||
BLOCK_DMODEL = 512
|
BLOCK_DMODEL = 512
|
||||||
BLOCK_DPE = 64
|
BLOCK_DPE = 64
|
||||||
|
|||||||
@@ -477,9 +477,9 @@ def invoke_fused_moe_kernel(
|
|||||||
|
|
||||||
padded_size = 0
|
padded_size = 0
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
padded_size = padding_size
|
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
if block_shape is None:
|
if block_shape is None:
|
||||||
|
padded_size = padding_size
|
||||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||||
else:
|
else:
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
@@ -614,7 +614,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 4,
|
"num_stages": 2 if is_hip_flag else 4,
|
||||||
}
|
}
|
||||||
if M <= E:
|
if M <= E:
|
||||||
config = {
|
config = {
|
||||||
@@ -623,7 +623,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4,
|
"num_stages": 2 if is_hip_flag else 4,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
|
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
|
||||||
@@ -633,7 +633,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": block_shape[1],
|
"BLOCK_SIZE_K": block_shape[1],
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 3,
|
"num_stages": 2 if is_hip_flag else 3,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
config = {
|
config = {
|
||||||
@@ -878,7 +878,7 @@ def fused_experts_impl(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
if not use_fp8_w8a8:
|
if not use_fp8_w8a8 or block_shape is not None:
|
||||||
padded_size = 0
|
padded_size = 0
|
||||||
|
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
|
|||||||
Reference in New Issue
Block a user