# 100 lines, 2.9 KiB, Python
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch_mlu
|
|
from vllm import _mlu_ops as mlu_ops
|
|
|
|
# Map activation-name strings to their torch.nn.functional implementations.
act_dict = {
    "relu": F.relu,
    "gelu": F.gelu,
    "silu": F.silu,
}


def ref_ffn(
        hidden_states,
        up_fc_weight,
        up_fc_bias,
        down_proj_weight,
        down_proj_bias,
        gate_up_proj_weight,
        gate_up_proj_bias,
        layernorm_weight,
        layernorm_bias,
        act_mode):
    """Pure-PyTorch reference FFN used to validate the fused MLU kernel.

    Computes ``down_proj(act(up_fc(x)) * gate_up_proj(x))`` when a gate
    projection is supplied, otherwise ``down_proj(act(up_fc(x)))``.

    Args:
        hidden_states: input tensor; last dim must match ``up_fc_weight``'s
            second dim (F.linear convention).
        up_fc_weight / up_fc_bias: up-projection weight (out, in) and optional
            bias (bias may be None).
        down_proj_weight / down_proj_bias: down-projection weight and optional
            bias.
        gate_up_proj_weight / gate_up_proj_bias: optional gate projection;
            ``None`` weight disables the gated branch.
        layernorm_weight / layernorm_bias: accepted for signature parity with
            the fused op but unused here — NOTE(review): confirm the kernel
            side ignores them as well.
        act_mode: one of the ``act_dict`` keys ("relu", "gelu", "silu").

    Returns:
        The FFN output tensor, same leading shape as ``hidden_states``.
    """
    up_output = F.linear(hidden_states, up_fc_weight, bias=up_fc_bias)
    act_output = act_dict[act_mode](up_output)
    # Idiomatic identity test (was the non-idiomatic `not x is None`).
    if gate_up_proj_weight is not None:
        gate_output = F.linear(hidden_states, gate_up_proj_weight,
                               bias=gate_up_proj_bias)
        out = F.linear(act_output * gate_output, down_proj_weight,
                       bias=down_proj_bias)
    else:
        out = F.linear(act_output, down_proj_weight, bias=down_proj_bias)
    return out
|
|
|
|
# Parameter grids swept by the pytest parametrization below.
BATCH_SIZE = [1]
SEQ_LENS = [1, 64, 1024]
HIDDEN_SIZE = [16, 24]
INTER_SIZE = [32]
DTYPES = [torch.half, torch.float]
# NOTE(review): this branch reassigns DTYPES to the *identical* list, so it is
# currently a no-op. Presumably devices whose name lacks "3" were meant to get
# a different (reduced?) dtype set — confirm the intended list with the author.
if "3" not in torch.mlu.get_device_name(0):
    DTYPES = [torch.half, torch.float]
|
|
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZE)
@pytest.mark.parametrize("inter_size", INTER_SIZE)
@pytest.mark.parametrize("act_name", ["relu", "silu"])  # gelu
@pytest.mark.parametrize("use_gate", [True])
@pytest.mark.parametrize("use_gate_bias", [False])
@pytest.mark.parametrize("use_up_bias", [False])
@pytest.mark.parametrize("use_down_bias", [False])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", [0])
def test_attention_project(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    inter_size: int,
    act_name: str,
    use_gate: bool,
    use_gate_bias: bool,
    use_up_bias: bool,
    use_down_bias: bool,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check mlu_ops.ffn against the pure-PyTorch reference ``ref_ffn``.

    Random weights/activations are generated on the MLU device and the fused
    kernel's output is compared to the reference within loose tolerances
    (half-precision runs dominate the error budget).
    """
    device = "mlu:0"
    torch.random.manual_seed(seed)
    torch.mlu.manual_seed(seed)

    # Random inputs; generation order matters for RNG reproducibility.
    hidden_states = torch.randn(
        batch_size, seq_len, hidden_size, dtype=dtype, device=device)
    up_weight = torch.randn(
        inter_size, hidden_size, dtype=dtype, device=device)
    gate_weight = (
        torch.randn(inter_size, hidden_size, dtype=dtype, device=device)
        if use_gate else None)
    down_weight = torch.randn(
        hidden_size, inter_size, dtype=dtype, device=device)

    # Fused kernel under test (all biases disabled by the parametrization).
    out = mlu_ops.ffn(
        hidden_states,
        up_weight,
        None,
        down_weight,
        None,
        gate_weight,
        None,
        act_name,
    )

    # Pure-PyTorch reference; layernorm args are unused and passed as None.
    ref_out = ref_ffn(
        hidden_states,
        up_weight,
        None,
        down_weight,
        None,
        gate_weight,
        None,
        None,
        None,
        act_name,
    )

    assert torch.allclose(out, ref_out, atol=1e-1, rtol=1e-1)
|
|
|