forked from EngineX-Cambricon/enginex-mlu370-vllm
Add DeepSeek-V3 and Llama 4
This commit is contained in:
@@ -26,9 +26,12 @@ def vllm__module_executor__layers__linear__UnquantizedLinearMethod__apply(
|
|||||||
beta = 1.0
|
beta = 1.0
|
||||||
residual = residual.view(-1, residual.shape[-1])
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
|
res_shape = x.shape[0:-1] + (layer.weight.shape[0], )
|
||||||
# MLU matmul requires matching dtypes; cast input to weight dtype
|
# MLU matmul requires all tensors to have matching dtypes
|
||||||
if x.dtype != layer.weight.dtype:
|
target_dtype = layer.weight.dtype
|
||||||
x = x.to(layer.weight.dtype)
|
if x.dtype != target_dtype:
|
||||||
|
x = x.to(target_dtype)
|
||||||
|
if residual is not None and residual.dtype != target_dtype:
|
||||||
|
residual = residual.to(target_dtype)
|
||||||
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)
|
return mlu_ops.matmul(x.view(-1, x.shape[-1]), layer.weight, bias, residual, 'none', 1.0, beta).view(res_shape)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user