feat: use FlashInfer rmsnorm and silu (#907)
This commit is contained in:
29
python/sglang/srt/layers/activation.py
Normal file
29
python/sglang/srt/layers/activation.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from flashinfer.activation import silu_and_mul
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
silu_and_mul(x, out)
|
||||
return out
|
||||
Reference in New Issue
Block a user