[ROCm] Optimal MOE Tuning for AMD Radeon Graphics (#3567)
This commit is contained in:
@@ -175,7 +175,7 @@ def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
|
|||||||
for block_m in [32, 64, 128, 256]:
|
for block_m in [32, 64, 128, 256]:
|
||||||
for block_k in [32, 64, 128, 256]:
|
for block_k in [32, 64, 128, 256]:
|
||||||
for block_n in [16, 32, 64, 128, 256]:
|
for block_n in [16, 32, 64, 128, 256]:
|
||||||
for num_warps in [4, 8]:
|
for num_warps in [1, 2, 4, 8]:
|
||||||
for group_size in [1, 4, 8, 16, 32]:
|
for group_size in [1, 4, 8, 16, 32]:
|
||||||
configs.append(
|
configs.append(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,46 +1,46 @@
|
|||||||
{
|
{
|
||||||
"1": {
|
"1": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 32,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 16,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"2": {
|
"2": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 8,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"4": {
|
"4": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 16,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"8": {
|
"8": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 64,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"16": {
|
"16": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 64,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
@@ -48,17 +48,17 @@
|
|||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 4,
|
"GROUP_SIZE_M": 8,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"32": {
|
"32": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 64,
|
||||||
"GROUP_SIZE_M": 8,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
@@ -66,8 +66,8 @@
|
|||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 4,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
@@ -75,8 +75,8 @@
|
|||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 4,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
@@ -84,77 +84,77 @@
|
|||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 8,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"128": {
|
"128": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 16,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 4,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"256": {
|
"256": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 16,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"512": {
|
"512": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"1024": {
|
"1024": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 4,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 8,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"1536": {
|
"1536": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 64,
|
||||||
"GROUP_SIZE_M": 8,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"2048": {
|
"2048": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 32,
|
||||||
"BLOCK_SIZE_N": 64,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 2,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"3072": {
|
"3072": {
|
||||||
"BLOCK_SIZE_M": 32,
|
"BLOCK_SIZE_M": 128,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2,
|
"num_stages": 2,
|
||||||
"waves_per_eu": 0
|
"waves_per_eu": 0
|
||||||
},
|
},
|
||||||
"4096": {
|
"4096": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 64,
|
"BLOCK_SIZE_K": 64,
|
||||||
"GROUP_SIZE_M": 4,
|
"GROUP_SIZE_M": 4,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
|
|||||||
Reference in New Issue
Block a user