Add tests to AMD CI for MI35x (#9662)

Co-authored-by: Sai Enduri <saimanas.enduri@amd.com>
This commit is contained in:
Hubert Lu
2025-09-10 12:50:05 -07:00
committed by GitHub
parent 9e2f7252db
commit 91b3555d2d
7 changed files with 159 additions and 122 deletions

View File

@@ -2027,7 +2027,10 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_format = (
"mxfp4"
if _is_gfx95_supported
- and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
+ and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+ and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+ is not None
+ and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
else ""
)
@@ -2582,7 +2585,11 @@ class DeepseekV2ForCausalLM(nn.Module):
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
- if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
+ if (
+ _use_aiter_gfx95
+ and self.quant_config is not None
+ and self.quant_config.get_name() == "quark"
+ ):
w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
quark_post_load_weights(self_attn, w, "mxfp4")
)