Fix runtime error in dsv3-fp8 model on mi35x (#10104)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -249,7 +249,11 @@ class DeepseekV2MLP(nn.Module):
|
||||
if (self.tp_size == 1) and x.shape[0] == 0:
|
||||
return x
|
||||
|
||||
if gemm_output_zero_allocator != None and x.shape[0] <= 256:
|
||||
if (
|
||||
gemm_output_zero_allocator is not None
|
||||
and x.shape[0] <= 256
|
||||
and self.gate_up_proj.weight.dtype == torch.uint8
|
||||
):
|
||||
y = gemm_output_zero_allocator.allocate(
|
||||
x.shape[0] * self.gate_up_proj.output_size_per_partition
|
||||
).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
|
||||
|
||||
Reference in New Issue
Block a user