Fix run time error in dsv3-fp8 model on mi35x (#10104)

Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
kk
2025-09-08 11:45:17 +08:00
committed by GitHub
parent 37d83c6e6d
commit 400d3b97ae

View File

@@ -249,7 +249,11 @@ class DeepseekV2MLP(nn.Module):
if (self.tp_size == 1) and x.shape[0] == 0:
return x
if gemm_output_zero_allocator != None and x.shape[0] <= 256:
if (
gemm_output_zero_allocator is not None
and x.shape[0] <= 256
and self.gate_up_proj.weight.dtype == torch.uint8
):
y = gemm_output_zero_allocator.allocate(
x.shape[0] * self.gate_up_proj.output_size_per_partition
).view(x.shape[0], self.gate_up_proj.output_size_per_partition)