[Misc] clean up vllm in sgl-kernel test (#5189)
This commit is contained in:
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import awq_dequantize
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor):
|
||||
@@ -58,12 +57,6 @@ def awq_dequantize_torch(
|
||||
return (iweights - zeros) * scales
|
||||
|
||||
|
||||
def vllm_awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
|
||||
|
||||
def sglang_awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
@@ -110,7 +103,6 @@ def test_awq_dequant_compare_implementations(
|
||||
)
|
||||
|
||||
# Run both implementations
|
||||
vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
|
||||
torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
|
||||
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
@@ -118,13 +110,6 @@ def test_awq_dequant_compare_implementations(
|
||||
torch.testing.assert_close(
|
||||
torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
)
|
||||
if not is_bf16_act:
|
||||
torch.testing.assert_close(
|
||||
vllm_out.to(torch.float32),
|
||||
sglang_out.to(torch.float32),
|
||||
rtol=1e-3,
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user