[CI] Fix test cases (#2137)
This commit is contained in:
@@ -24,6 +24,8 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tanh(x):
|
||||
@@ -501,7 +503,7 @@ def _decode_grouped_att_m_fwd(
|
||||
num_warps = 4
|
||||
|
||||
extra_kargs = {}
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||
@@ -557,7 +559,7 @@ def _decode_grouped_softmax_reducev_fwd(
|
||||
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
||||
|
||||
extra_kargs = {}
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||
|
||||
@@ -29,6 +29,8 @@ is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tanh(x):
|
||||
@@ -311,7 +313,7 @@ def extend_attention_fwd(
|
||||
num_stages = 1
|
||||
|
||||
extra_kargs = {}
|
||||
if is_hip():
|
||||
if is_hip_:
|
||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||
|
||||
_fwd_kernel[grid](
|
||||
|
||||
@@ -242,15 +242,17 @@ class ModelRunner:
|
||||
)
|
||||
return get_model(vllm_config=vllm_config)
|
||||
except ImportError:
|
||||
return get_model(
|
||||
model_config=self.vllm_model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=DeviceConfig(self.device),
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
lora_config=None,
|
||||
cache_config=None,
|
||||
)
|
||||
pass
|
||||
|
||||
return get_model(
|
||||
model_config=self.vllm_model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=DeviceConfig(self.device),
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
lora_config=None,
|
||||
cache_config=None,
|
||||
)
|
||||
|
||||
def get_model_config_params(self):
|
||||
sig = inspect.signature(VllmModelConfig.__init__)
|
||||
|
||||
Reference in New Issue
Block a user