add qwen3
This commit is contained in:
99
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_ffn.py
Normal file
99
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_ffn.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_mlu
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
# Supported activation names -> their torch.nn.functional callables.
# Keys must match the `act_name` strings parametrized in the tests below.
act_dict = dict(
    relu=F.relu,
    gelu=F.gelu,
    silu=F.silu,
)
|
||||
|
||||
def ref_ffn(
        hidden_states,
        up_fc_weight,
        up_fc_bias,
        down_proj_weight,
        down_proj_bias,
        gate_up_proj_weight,
        gate_up_proj_bias,
        layernorm_weight,
        layernorm_bias,
        act_mode):
    """Pure-PyTorch reference FFN used to validate the fused MLU kernel.

    Computes ``down(act(up(x)) [* gate(x)])``:

    * ``up_output = F.linear(hidden_states, up_fc_weight, up_fc_bias)``
    * the activation named by ``act_mode`` (a key of ``act_dict``) is applied
    * if ``gate_up_proj_weight`` is not ``None``, a parallel gate projection
      of ``hidden_states`` multiplies the activated branch elementwise
      (gated-FFN / SwiGLU-style path); otherwise the plain two-layer path
      is used
    * the result is projected back via ``down_proj_weight``/``down_proj_bias``

    ``layernorm_weight`` and ``layernorm_bias`` are accepted for signature
    parity with the fused op but are currently unused.

    Raises:
        KeyError: if ``act_mode`` is not a key of ``act_dict``.
    """
    up_output = F.linear(hidden_states, up_fc_weight, bias=up_fc_bias)
    act_output = act_dict[act_mode](up_output)
    # PEP 8 idiom fix: `x is not None`, not `not x is None`.
    if gate_up_proj_weight is not None:
        gate_output = F.linear(hidden_states, gate_up_proj_weight,
                               bias=gate_up_proj_bias)
        out = F.linear(act_output * gate_output, down_proj_weight,
                       bias=down_proj_bias)
    else:
        out = F.linear(act_output, down_proj_weight, bias=down_proj_bias)
    return out
|
||||
|
||||
# Parametrization grids for the FFN kernel test below.
BATCH_SIZE = [1]
SEQ_LENS = [1, 64, 1024]
HIDDEN_SIZE = [16, 24]
INTER_SIZE = [32]
DTYPES = [torch.half, torch.float]
# NOTE(review): this branch assigns the exact same list as above, so it is
# currently a no-op.  Presumably a reduced dtype set was intended for
# devices whose name does not contain "3" — confirm with the kernel owners.
if "3" not in torch.mlu.get_device_name(0):
    DTYPES = [torch.half, torch.float]
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZE)
@pytest.mark.parametrize("inter_size", INTER_SIZE)
@pytest.mark.parametrize("act_name", ["relu", "silu"])  # gelu
@pytest.mark.parametrize("use_gate", [True])
@pytest.mark.parametrize("use_gate_bias", [False])
@pytest.mark.parametrize("use_up_bias", [False])
@pytest.mark.parametrize("use_down_bias", [False])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", [0])
def test_attention_project(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    inter_size: int,
    act_name: str,
    use_gate: bool,
    use_gate_bias: bool,
    use_up_bias: bool,
    use_down_bias: bool,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Compare the fused ``mlu_ops.ffn`` kernel against ``ref_ffn``.

    Fix: ``use_up_bias`` / ``use_gate_bias`` / ``use_down_bias`` were
    previously accepted but ignored — every bias was hard-coded to ``None``.
    They now control whether each bias tensor is created.  With the current
    all-``False`` parametrization the generated cases are unchanged, so this
    is backward-compatible.
    """
    device_id = "mlu:0"
    torch.random.manual_seed(seed)
    torch.mlu.manual_seed(seed)

    def _maybe_randn(*shape, enabled):
        # Allocate a random tensor only when the corresponding flag is set.
        if not enabled:
            return None
        return torch.randn(*shape, dtype=dtype, device=device_id)

    hidden_states = torch.randn(batch_size, seq_len, hidden_size,
                                dtype=dtype, device=device_id)
    up_proj_weight = torch.randn(inter_size, hidden_size,
                                 dtype=dtype, device=device_id)
    up_proj_bias = _maybe_randn(inter_size, enabled=use_up_bias)
    gate_proj_weight = _maybe_randn(inter_size, hidden_size, enabled=use_gate)
    # A gate bias only makes sense when the gate projection itself exists.
    gate_proj_bias = _maybe_randn(inter_size, enabled=use_gate and use_gate_bias)
    down_proj_weight = torch.randn(hidden_size, inter_size,
                                   dtype=dtype, device=device_id)
    down_proj_bias = _maybe_randn(hidden_size, enabled=use_down_bias)

    out = mlu_ops.ffn(hidden_states,
                      up_proj_weight,
                      up_proj_bias,
                      down_proj_weight,
                      down_proj_bias,
                      gate_proj_weight,
                      gate_proj_bias,
                      act_name)

    ref_out = ref_ffn(
        hidden_states,
        up_proj_weight,
        up_proj_bias,
        down_proj_weight,
        down_proj_bias,
        gate_proj_weight,
        gate_proj_bias,
        None,  # layernorm_weight: unused by ref_ffn
        None,  # layernorm_bias: unused by ref_ffn
        act_name,
    )

    # Loose tolerances: the fused kernel's fp16 accumulation order differs
    # from the PyTorch reference path.
    assert torch.allclose(out, ref_out, atol=1e-1, rtol=1e-1)
|
||||
|
||||
Reference in New Issue
Block a user