Files
enginex-mlu370-vllm/vllm-v0.6.2/tests/kernels/bt_torch_ops/test_ffn.py
2026-02-04 17:22:39 +08:00

100 lines
2.9 KiB
Python

import pytest
import torch
import torch.nn.functional as F
import torch_mlu
from vllm import _mlu_ops as mlu_ops
# Maps the string `act_mode` accepted by ref_ffn / mlu_ops.ffn to the
# corresponding eager activation from torch.nn.functional.
act_dict = dict(
    relu=F.relu,
    gelu=F.gelu,
    silu=F.silu,
)
def ref_ffn(
        hidden_states,
        up_fc_weight,
        up_fc_bias,
        down_proj_weight,
        down_proj_bias,
        gate_up_proj_weight,
        gate_up_proj_bias,
        layernorm_weight,
        layernorm_bias,
        act_mode):
    """Reference (eager PyTorch) FFN used to validate the fused MLU kernel.

    Computes ``down(act(up(x)))`` or, when a gate projection is supplied,
    ``down(act(up(x)) * gate(x))`` — i.e. the gate output is multiplied
    element-wise into the *activated* up-projection output.

    Args:
        hidden_states: input tensor; last dimension is the hidden size.
        up_fc_weight, up_fc_bias: up-projection parameters for ``F.linear``.
        down_proj_weight, down_proj_bias: down-projection parameters.
        gate_up_proj_weight, gate_up_proj_bias: optional gate projection;
            the gated path is taken iff the weight is not ``None``.
        layernorm_weight, layernorm_bias: accepted for signature parity with
            the fused op but unused here.
            # NOTE(review): confirm the fused kernel ignores them too.
        act_mode: key into ``act_dict`` ("relu", "gelu", or "silu").

    Returns:
        Tensor with the same leading shape as ``hidden_states`` and last
        dimension equal to ``down_proj_weight.shape[0]``.
    """
    up_output = F.linear(hidden_states, up_fc_weight, bias=up_fc_bias)
    act_output = act_dict[act_mode](up_output)
    # Idiomatic "is not None" (was "not ... is None").
    if gate_up_proj_weight is not None:
        gate_output = F.linear(hidden_states, gate_up_proj_weight,
                               bias=gate_up_proj_bias)
        return F.linear(act_output * gate_output, down_proj_weight,
                        bias=down_proj_bias)
    return F.linear(act_output, down_proj_weight, bias=down_proj_bias)
# Parametrization spaces for the FFN kernel test below.
BATCH_SIZE = [1]
SEQ_LENS = [1, 64, 1024]
HIDDEN_SIZE = [16, 24]
INTER_SIZE = [32]
DTYPES = [torch.half, torch.float]
# NOTE(review): this branch reassigns the exact same dtype list, so it is a
# no-op as written. Presumably it was meant to add or drop a dtype (e.g.
# bfloat16) depending on the MLU device generation ("3" in the device name)
# — confirm the intended per-device dtype set.
if "3" not in torch.mlu.get_device_name(0):
    DTYPES = [torch.half, torch.float]
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZE)
@pytest.mark.parametrize("inter_size", INTER_SIZE)
@pytest.mark.parametrize("act_name", ["relu", "silu"]) # gelu
@pytest.mark.parametrize("use_gate", [True])
@pytest.mark.parametrize("use_gate_bias", [False])
@pytest.mark.parametrize("use_up_bias", [False])
@pytest.mark.parametrize("use_down_bias", [False])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", [0])
def test_attention_project(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    inter_size: int,
    act_name: str,
    use_gate: bool,
    use_gate_bias: bool,
    use_up_bias: bool,
    use_down_bias: bool,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the fused mlu_ops.ffn against the eager reference ref_ffn.

    All bias inputs are passed as None (the *_bias parametrizations are
    currently pinned to False).
    NOTE(review): the name says "attention_project" but this exercises the
    FFN op — consider renaming in a follow-up.
    """
    device = "mlu:0"
    torch.random.manual_seed(seed)
    torch.mlu.manual_seed(seed)

    def rand(*shape):
        # All tensors share the parametrized dtype and live on the MLU device.
        return torch.randn(*shape, dtype=dtype, device=device)

    hidden_states = rand(batch_size, seq_len, hidden_size)
    up_proj_weight = rand(inter_size, hidden_size)
    gate_proj_weight = rand(inter_size, hidden_size) if use_gate else None
    down_proj_weight = rand(hidden_size, inter_size)

    out = mlu_ops.ffn(hidden_states, up_proj_weight, None, down_proj_weight,
                      None, gate_proj_weight, None, act_name)
    ref_out = ref_ffn(hidden_states, up_proj_weight, None, down_proj_weight,
                      None, gate_proj_weight, None, None, None, act_name)

    # Loose tolerances: half-precision fused kernel vs. eager reference.
    assert torch.allclose(out, ref_out, atol=1e-1, rtol=1e-1)