Use deepgemm instead of triton for fused_qkv_a_proj_with_mqa (#6890)
This commit is contained in:
@@ -227,8 +227,8 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
||||
output_dtype = input.dtype
|
||||
dtype_supported = output_dtype == torch.bfloat16
|
||||
|
||||
# TODO: add more robust shape check here
|
||||
shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
|
||||
# TODO: https://github.com/sgl-project/sglang/pull/6890#issuecomment-2943395737
|
||||
shape_supported = weight.shape[0] % 64 == 0 and weight.shape[1] % 128 == 0
|
||||
|
||||
if not (shape_supported and dtype_supported):
|
||||
# fall back to triton
|
||||
|
||||
Reference in New Issue
Block a user