2
.github/workflows/pr-test-amd.yml
vendored
2
.github/workflows/pr-test-amd.yml
vendored
@@ -45,6 +45,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
docker exec ci_sglang pip install --upgrade pip
|
docker exec ci_sglang pip install --upgrade pip
|
||||||
|
docker exec ci_sglang pip uninstall sgl-kernel -y || true
|
||||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
|
docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
|
||||||
docker exec ci_sglang pip install -e "python[dev_hip]"
|
docker exec ci_sglang pip install -e "python[dev_hip]"
|
||||||
|
|
||||||
@@ -83,6 +84,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
docker exec ci_sglang pip install --upgrade pip
|
docker exec ci_sglang pip install --upgrade pip
|
||||||
|
docker exec ci_sglang pip uninstall sgl-kernel -y || true
|
||||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
|
docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
|
||||||
docker exec ci_sglang pip install -e "python[dev_hip]"
|
docker exec ci_sglang pip install -e "python[dev_hip]"
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
@@ -32,6 +31,8 @@ _is_cuda = is_cuda()
|
|||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||||
|
else:
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
@@ -46,6 +45,8 @@ if _is_cuda:
|
|||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
if _is_cuda or _is_hip:
|
if _is_cuda or _is_hip:
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
@@ -456,38 +457,8 @@ def moe_align_block_size(
|
|||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
if num_experts >= 224:
|
if enable_moe_align_block_size_triton:
|
||||||
if enable_moe_align_block_size_triton:
|
moe_align_block_size_triton(
|
||||||
moe_align_block_size_triton(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_pad,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
token_cnts_buffer = torch.zeros(
|
|
||||||
(num_experts + 1) * num_experts,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=topk_ids.device,
|
|
||||||
)
|
|
||||||
cumsum_buffer = torch.zeros(
|
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
|
||||||
)
|
|
||||||
|
|
||||||
sgl_moe_align_block_size(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_pad,
|
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
vllm_ops.moe_align_block_size(
|
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -495,6 +466,26 @@ def moe_align_block_size(
|
|||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
token_cnts_buffer = torch.zeros(
|
||||||
|
(num_experts + 1) * num_experts,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=topk_ids.device,
|
||||||
|
)
|
||||||
|
cumsum_buffer = torch.zeros(
|
||||||
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
|
sgl_moe_align_block_size(
|
||||||
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_pad,
|
||||||
|
token_cnts_buffer,
|
||||||
|
cumsum_buffer,
|
||||||
|
)
|
||||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import is_cuda_available
|
from sglang.srt.utils import is_cuda_available
|
||||||
@@ -14,6 +13,8 @@ from sglang.srt.utils import is_cuda_available
|
|||||||
_is_cuda_available = is_cuda_available()
|
_is_cuda_available = is_cuda_available()
|
||||||
if _is_cuda_available:
|
if _is_cuda_available:
|
||||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||||
|
else:
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
|
||||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
Reference in New Issue
Block a user