adapt to ds3.2
This commit is contained in:
@@ -127,21 +127,45 @@ class RMSNorm(CustomOp):
|
||||
return output, residual_out
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
# def forward_hip(
|
||||
# self,
|
||||
# x: torch.Tensor,
|
||||
# residual: Optional[torch.Tensor] = None,
|
||||
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# if not x.is_contiguous():
|
||||
# # NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
# x = x.contiguous()
|
||||
# if residual is not None:
|
||||
# if _vllm_version < Version("0.9"):
|
||||
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
# return x, residual
|
||||
# else:
|
||||
# residual_out = torch.empty_like(x)
|
||||
# output = torch.empty_like(x)
|
||||
# fused_add_rms_norm(
|
||||
# output,
|
||||
# x,
|
||||
# residual_out,
|
||||
# residual,
|
||||
# self.weight.data,
|
||||
# self.variance_epsilon,
|
||||
# )
|
||||
# return output, residual_out
|
||||
# out = torch.empty_like(x)
|
||||
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
# return out
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if not x.is_contiguous():
|
||||
# NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
x = x.contiguous()
|
||||
|
||||
if residual is not None:
|
||||
if _vllm_version < Version("0.9"):
|
||||
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
return x, residual
|
||||
else:
|
||||
residual_out = torch.empty_like(x)
|
||||
try:
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
@@ -151,10 +175,21 @@ class RMSNorm(CustomOp):
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user