[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)
This commit is contained in:
36
.github/workflows/pr-test-npu.yml
vendored
36
.github/workflows/pr-test-npu.yml
vendored
@@ -127,12 +127,48 @@ jobs:
|
|||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite per-commit-4-ascend-npu --timeout-per-file 3600
|
python3 run_suite.py --suite per-commit-4-ascend-npu --timeout-per-file 3600
|
||||||
|
|
||||||
|
per-commit-16-ascend-a3:
|
||||||
|
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||||
|
github.event.pull_request.draft == false
|
||||||
|
runs-on: linux-aarch64-a3-16
|
||||||
|
container:
|
||||||
|
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-a3-ubuntu22.04-py3.11
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
# speed up by using infra cache services
|
||||||
|
CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local"
|
||||||
|
sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list
|
||||||
|
pip config set global.index-url http://${CACHING_URL}/pypi/simple
|
||||||
|
pip config set global.trusted-host ${CACHING_URL}
|
||||||
|
|
||||||
|
bash scripts/ci/npu_ci_install_dependency.sh
|
||||||
|
# copy required file from our daily cache
|
||||||
|
cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp
|
||||||
|
# copy download through proxy
|
||||||
|
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||||
|
|
||||||
|
- name: Run test
|
||||||
|
timeout-minutes: 90
|
||||||
|
env:
|
||||||
|
SGLANG_USE_MODELSCOPE: true
|
||||||
|
SGLANG_IS_IN_CI: true
|
||||||
|
HF_ENDPOINT: https://hf-mirror.com
|
||||||
|
TORCH_EXTENSIONS_DIR: /tmp/torch_extensions
|
||||||
|
run: |
|
||||||
|
cd test/srt
|
||||||
|
python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 5400
|
||||||
|
|
||||||
pr-test-npu-finish:
|
pr-test-npu-finish:
|
||||||
if: always()
|
if: always()
|
||||||
needs:
|
needs:
|
||||||
- per-commit-1-ascend-npu
|
- per-commit-1-ascend-npu
|
||||||
- per-commit-2-ascend-npu
|
- per-commit-2-ascend-npu
|
||||||
- per-commit-4-ascend-npu
|
- per-commit-4-ascend-npu
|
||||||
|
- per-commit-16-ascend-a3
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check all dependent job statuses
|
- name: Check all dependent job statuses
|
||||||
|
|||||||
@@ -72,5 +72,6 @@ jobs:
|
|||||||
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
|
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
|
||||||
provenance: false
|
provenance: false
|
||||||
build-args: |
|
build-args: |
|
||||||
|
SGLANG_KERNEL_NPU_TAG=20250901
|
||||||
CANN_VERSION=${{ matrix.cann_version }}
|
CANN_VERSION=${{ matrix.cann_version }}
|
||||||
DEVICE_TYPE=${{ matrix.device_type }}
|
DEVICE_TYPE=${{ matrix.device_type }}
|
||||||
|
|||||||
4
.github/workflows/release-docker-npu.yml
vendored
4
.github/workflows/release-docker-npu.yml
vendored
@@ -54,8 +54,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
version=$(cat python/sglang/version.py | cut -d'"' -f2)
|
version=$(cat python/sglang/version.py | cut -d'"' -f2)
|
||||||
echo "TAG=lmsysorg/sglang:v$version-cann${{ matrix.cann_version }}-${{ matrix.device_type }}" >> $GITHUB_OUTPUT
|
echo "TAG=lmsysorg/sglang:v$version-cann${{ matrix.cann_version }}-${{ matrix.device_type }}" >> $GITHUB_OUTPUT
|
||||||
kernel_tag=$(curl -s https://api.github.com/repos/sgl-project/sgl-kernel-npu/tags | jq -r '.[0].name')
|
|
||||||
echo "KERNEL_NPU_TAG=${kernel_tag}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Build and push Docker image
|
- name: Build and push Docker image
|
||||||
id: build-and-push
|
id: build-and-push
|
||||||
@@ -70,6 +68,6 @@ jobs:
|
|||||||
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
|
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
|
||||||
provenance: false
|
provenance: false
|
||||||
build-args: |
|
build-args: |
|
||||||
SGLANG_KERNEL_NPU_TAG=${{ steps.get_version.outputs.KERNEL_NPU_TAG }}
|
SGLANG_KERNEL_NPU_TAG=20250901
|
||||||
CANN_VERSION=${{ matrix.cann_version }}
|
CANN_VERSION=${{ matrix.cann_version }}
|
||||||
DEVICE_TYPE=${{ matrix.device_type }}
|
DEVICE_TYPE=${{ matrix.device_type }}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class EPLBManager:
|
|||||||
enable_timing = self._rebalance_layers_per_chunk is None
|
enable_timing = self._rebalance_layers_per_chunk is None
|
||||||
|
|
||||||
if enable_timing:
|
if enable_timing:
|
||||||
torch.cuda.synchronize()
|
torch.get_device_module().synchronize()
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
|
|
||||||
dump_record_output = get_global_expert_distribution_recorder().dump_record(
|
dump_record_output = get_global_expert_distribution_recorder().dump_record(
|
||||||
@@ -85,7 +85,7 @@ class EPLBManager:
|
|||||||
|
|
||||||
msg = f"[EPLBManager] rebalance end"
|
msg = f"[EPLBManager] rebalance end"
|
||||||
if enable_timing:
|
if enable_timing:
|
||||||
torch.cuda.synchronize()
|
torch.get_device_module().synchronize()
|
||||||
time_end = time.time()
|
time_end = time.time()
|
||||||
msg += f" time={time_end - time_start:.3f}s"
|
msg += f" time={time_end - time_start:.3f}s"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ import torch.distributed
|
|||||||
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import Withable, get_bool_env_var
|
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
|
||||||
|
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||||
@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
|||||||
def _on_hook(self, hook_name: str, **kwargs):
|
def _on_hook(self, hook_name: str, **kwargs):
|
||||||
if self._disable_all:
|
if self._disable_all:
|
||||||
return
|
return
|
||||||
if not (self._recording or torch.cuda.is_current_stream_capturing()):
|
if not (
|
||||||
|
self._recording or torch.get_device_module().is_current_stream_capturing()
|
||||||
|
):
|
||||||
return
|
return
|
||||||
gatherer = self._single_pass_gatherers[
|
gatherer = self._single_pass_gatherers[
|
||||||
self._accumulator.get_single_pass_gatherer_key(
|
self._accumulator.get_single_pass_gatherer_key(
|
||||||
@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
|
|||||||
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||||
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
|
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
if not _is_npu:
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "npu"
|
||||||
self._enable_global_physical_experts = enable_global_physical_experts
|
self._enable_global_physical_experts = enable_global_physical_experts
|
||||||
self._data = torch.zeros(
|
self._data = torch.zeros(
|
||||||
(
|
(
|
||||||
@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
device="cuda",
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
|||||||
|
|
||||||
if self._first_dump:
|
if self._first_dump:
|
||||||
self._first_dump = False
|
self._first_dump = False
|
||||||
torch.cuda.empty_cache()
|
torch.get_device_module().empty_cache()
|
||||||
|
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
|
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
|
|||||||
):
|
):
|
||||||
if self._first_execution:
|
if self._first_execution:
|
||||||
self._first_execution = False
|
self._first_execution = False
|
||||||
torch.cuda.empty_cache()
|
torch.get_device_module().empty_cache()
|
||||||
|
|
||||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||||
assert old_expert_location_metadata is not None
|
assert old_expert_location_metadata is not None
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
|
|||||||
from sglang.srt.configs.model_config import AttentionArch
|
from sglang.srt.configs.model_config import AttentionArch
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.layers.radix_attention import AttentionType
|
from sglang.srt.layers.radix_attention import AttentionType
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import get_bool_env_var
|
from sglang.srt.utils import get_bool_env_var
|
||||||
@@ -33,6 +34,7 @@ class ForwardMetadata:
|
|||||||
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
|
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||||
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||||
seq_lens_cpu_list: Optional[List[int]] = None
|
seq_lens_cpu_list: Optional[List[int]] = None
|
||||||
|
seq_lens_list_cumsum: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
class AscendAttnBackend(AttentionBackend):
|
class AscendAttnBackend(AttentionBackend):
|
||||||
@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init the metadata for a forward pass."""
|
"""Init the metadata for a forward pass."""
|
||||||
|
tp_size = get_attention_tp_size()
|
||||||
self.forward_metadata = ForwardMetadata()
|
self.forward_metadata = ForwardMetadata()
|
||||||
|
|
||||||
self.forward_metadata.block_tables = (
|
self.forward_metadata.block_tables = (
|
||||||
@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
forward_batch.extend_seq_lens.cpu().int()
|
forward_batch.extend_seq_lens.cpu().int()
|
||||||
)
|
)
|
||||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||||
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
|
|
||||||
forward_batch.extend_seq_lens_cpu
|
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
|
||||||
)
|
if forward_batch.is_extend_in_batch:
|
||||||
|
seq_lens_list_cumsum[-1] = (
|
||||||
|
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
|
||||||
|
) * tp_size
|
||||||
|
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
|
||||||
|
|
||||||
self.graph_mode = False
|
self.graph_mode = False
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
AscendDeepEPLLOutput,
|
|
||||||
DeepEPLLOutput,
|
DeepEPLLOutput,
|
||||||
DeepEPNormalOutput,
|
DeepEPNormalOutput,
|
||||||
DispatchOutput,
|
DispatchOutput,
|
||||||
@@ -454,7 +453,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||||
return self.forward_aiter(dispatch_output)
|
return self.forward_aiter(dispatch_output)
|
||||||
if _is_npu:
|
if _is_npu:
|
||||||
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
|
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||||
return self.forward_npu(dispatch_output)
|
return self.forward_npu(dispatch_output)
|
||||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
@@ -718,63 +717,124 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
def forward_npu(
|
def forward_npu(
|
||||||
self,
|
self,
|
||||||
dispatch_output: DeepEPLLOutput,
|
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
|
||||||
):
|
):
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert isinstance(dispatch_output, AscendDeepEPLLOutput)
|
|
||||||
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.moe_runner_config.activation == "silu"
|
assert self.moe_runner_config.activation == "silu"
|
||||||
|
|
||||||
# NOTE: Ascend's Dispatch & Combine does not support FP16
|
|
||||||
output_dtype = torch.bfloat16
|
|
||||||
|
|
||||||
pertoken_scale = hidden_states[1]
|
|
||||||
hidden_states = hidden_states[0]
|
|
||||||
|
|
||||||
group_list_type = 1
|
|
||||||
seg_indptr = seg_indptr.to(torch.int64)
|
|
||||||
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
# gmm1: gate_up_proj
|
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
|
||||||
x=[hidden_states],
|
|
||||||
weight=[self.w13_weight],
|
|
||||||
split_item=2,
|
|
||||||
group_list_type=group_list_type,
|
|
||||||
group_type=0,
|
|
||||||
group_list=seg_indptr,
|
|
||||||
output_dtype=torch.int32,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# act_fn: swiglu
|
# NOTE: Ascend's Dispatch & Combine does not support FP16
|
||||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
output_dtype = torch.bfloat16
|
||||||
x=hidden_states,
|
group_list_type = 1
|
||||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
|
||||||
activation_scale=pertoken_scale,
|
|
||||||
bias=None,
|
|
||||||
quant_scale=None,
|
|
||||||
quant_offset=None,
|
|
||||||
group_index=seg_indptr,
|
|
||||||
activate_left=True,
|
|
||||||
quant_mode=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# gmm2: down_proj
|
def _forward_normal(dispatch_output: DeepEPNormalOutput):
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
if TYPE_CHECKING:
|
||||||
x=[hidden_states],
|
assert isinstance(dispatch_output, DeepEPNormalOutput)
|
||||||
weight=[self.w2_weight],
|
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
|
||||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
|
||||||
per_token_scale=[swiglu_out_scale],
|
|
||||||
split_item=2,
|
|
||||||
group_list_type=group_list_type,
|
|
||||||
group_type=0,
|
|
||||||
group_list=seg_indptr,
|
|
||||||
output_dtype=output_dtype,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
return hidden_states
|
if isinstance(hidden_states, tuple):
|
||||||
|
per_token_scale = hidden_states[1]
|
||||||
|
hidden_states = hidden_states[0]
|
||||||
|
else:
|
||||||
|
# dynamic quant
|
||||||
|
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
||||||
|
hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
||||||
|
hidden_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# gmm1: gate_up_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w13_weight],
|
||||||
|
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||||
|
per_token_scale=[per_token_scale],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# act_fn: swiglu
|
||||||
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||||
|
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||||
|
|
||||||
|
# gmm2: down_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w2_weight],
|
||||||
|
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||||
|
per_token_scale=[swiglu_out_scale],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def _forward_ll(dispatch_output: DeepEPLLOutput):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert isinstance(dispatch_output, DeepEPLLOutput)
|
||||||
|
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
|
||||||
|
|
||||||
|
per_token_scale = hidden_states[1]
|
||||||
|
hidden_states = hidden_states[0]
|
||||||
|
|
||||||
|
group_list = group_list.to(torch.int64)
|
||||||
|
|
||||||
|
# gmm1: gate_up_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w13_weight],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=torch.int32,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# act_fn: swiglu
|
||||||
|
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
|
x=hidden_states,
|
||||||
|
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||||
|
activation_scale=per_token_scale,
|
||||||
|
bias=None,
|
||||||
|
quant_scale=None,
|
||||||
|
quant_offset=None,
|
||||||
|
group_index=group_list,
|
||||||
|
activate_left=True,
|
||||||
|
quant_mode=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# gmm2: down_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w2_weight],
|
||||||
|
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||||
|
per_token_scale=[swiglu_out_scale],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||||
|
return _forward_normal(dispatch_output)
|
||||||
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||||
|
return _forward_ll(dispatch_output)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
|
||||||
|
|
||||||
|
|
||||||
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
|||||||
DispatchOutputFormat,
|
DispatchOutputFormat,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
||||||
AscendDeepEPLLOutput,
|
|
||||||
DeepEPConfig,
|
DeepEPConfig,
|
||||||
DeepEPDispatcher,
|
DeepEPDispatcher,
|
||||||
DeepEPLLCombineInput,
|
DeepEPLLCombineInput,
|
||||||
@@ -23,7 +22,6 @@ from sglang.srt.layers.moe.token_dispatcher.standard import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AscendDeepEPLLOutput",
|
|
||||||
"BaseDispatcher",
|
"BaseDispatcher",
|
||||||
"BaseDispatcherConfig",
|
"BaseDispatcherConfig",
|
||||||
"CombineInput",
|
"CombineInput",
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import torch
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
AscendDeepEPLLOutput,
|
|
||||||
DeepEPLLCombineInput,
|
DeepEPLLCombineInput,
|
||||||
DeepEPLLOutput,
|
DeepEPLLOutput,
|
||||||
DeepEPNormalCombineInput,
|
DeepEPNormalCombineInput,
|
||||||
@@ -47,19 +46,12 @@ class DispatchOutputChecker:
|
|||||||
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
|
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
|
||||||
return dispatch_output.format.is_deepep()
|
return dispatch_output.format.is_deepep()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def format_is_ascent_ll(
|
|
||||||
dispatch_output: DispatchOutput,
|
|
||||||
) -> TypeGuard[AscendDeepEPLLOutput]:
|
|
||||||
return dispatch_output.format.is_ascent_ll()
|
|
||||||
|
|
||||||
|
|
||||||
class DispatchOutputFormat(Enum):
|
class DispatchOutputFormat(Enum):
|
||||||
|
|
||||||
STANDARD = "standard"
|
STANDARD = "standard"
|
||||||
DEEPEP_NORMAL = "deepep_normal"
|
DEEPEP_NORMAL = "deepep_normal"
|
||||||
DEEPEP_LL = "deepep_ll"
|
DEEPEP_LL = "deepep_ll"
|
||||||
ASCENT_LL = "ascent_ll"
|
|
||||||
|
|
||||||
def is_standard(self) -> bool:
|
def is_standard(self) -> bool:
|
||||||
return self == DispatchOutputFormat.STANDARD
|
return self == DispatchOutputFormat.STANDARD
|
||||||
@@ -76,9 +68,6 @@ class DispatchOutputFormat(Enum):
|
|||||||
DispatchOutputFormat.DEEPEP_LL,
|
DispatchOutputFormat.DEEPEP_LL,
|
||||||
]
|
]
|
||||||
|
|
||||||
def is_ascent_ll(self) -> bool:
|
|
||||||
return self == DispatchOutputFormat.ASCENT_LL
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class DispatchOutput(Protocol):
|
class DispatchOutput(Protocol):
|
||||||
|
|||||||
@@ -77,24 +77,8 @@ class DeepEPLLOutput(NamedTuple):
|
|||||||
return DispatchOutputFormat.DEEPEP_LL
|
return DispatchOutputFormat.DEEPEP_LL
|
||||||
|
|
||||||
|
|
||||||
class AscendDeepEPLLOutput(NamedTuple):
|
|
||||||
"""AscendDeepEP low latency dispatch output."""
|
|
||||||
|
|
||||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
|
|
||||||
topk_idx: torch.Tensor
|
|
||||||
topk_weights: torch.Tensor
|
|
||||||
masked_m: torch.Tensor
|
|
||||||
seg_indptr: torch.Tensor
|
|
||||||
expected_m: int
|
|
||||||
|
|
||||||
@property
|
|
||||||
def format(self) -> DispatchOutputFormat:
|
|
||||||
return DispatchOutputFormat.ASCENT_LL
|
|
||||||
|
|
||||||
|
|
||||||
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||||
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
||||||
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
|
|
||||||
|
|
||||||
|
|
||||||
class DeepEPNormalCombineInput(NamedTuple):
|
class DeepEPNormalCombineInput(NamedTuple):
|
||||||
@@ -434,12 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
deepep_post_reorder_triton_kernel,
|
deepep_post_reorder_triton_kernel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||||
output = hidden_states
|
output = hidden_states
|
||||||
else:
|
else:
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
@@ -553,23 +536,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
masked_m
|
masked_m
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_npu:
|
deepep_output = DeepEPLLOutput(
|
||||||
deepep_output = AscendDeepEPLLOutput(
|
hidden_states,
|
||||||
hidden_states,
|
topk_idx,
|
||||||
topk_idx,
|
topk_weights,
|
||||||
topk_weights,
|
masked_m,
|
||||||
masked_m,
|
expected_m,
|
||||||
self.handle[1],
|
)
|
||||||
expected_m,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
deepep_output = DeepEPLLOutput(
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
masked_m,
|
|
||||||
expected_m,
|
|
||||||
)
|
|
||||||
return deepep_output
|
return deepep_output
|
||||||
|
|
||||||
def _dispatch_core(
|
def _dispatch_core(
|
||||||
|
|||||||
@@ -330,6 +330,14 @@ class TopK(CustomOp):
|
|||||||
)
|
)
|
||||||
topk_weights = topk_weights / topk_weights_sum
|
topk_weights = topk_weights / topk_weights_sum
|
||||||
|
|
||||||
|
if expert_location_dispatch_info is not None:
|
||||||
|
topk_ids = topk_ids_logical_to_physical(
|
||||||
|
topk_ids, expert_location_dispatch_info
|
||||||
|
)
|
||||||
|
get_global_expert_distribution_recorder().on_select_experts(
|
||||||
|
topk_ids=topk_ids
|
||||||
|
)
|
||||||
|
|
||||||
return StandardTopKOutput(topk_weights, topk_ids, _)
|
return StandardTopKOutput(topk_weights, topk_ids, _)
|
||||||
else:
|
else:
|
||||||
self.topk_config.torch_native = True
|
self.topk_config.torch_native = True
|
||||||
|
|||||||
@@ -51,5 +51,11 @@ ${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil
|
|||||||
wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"
|
wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"
|
||||||
|
|
||||||
|
|
||||||
|
### Install sgl-kernel-npu
|
||||||
|
SGL_KERNEL_NPU_TAG="20250901"
|
||||||
|
git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}
|
||||||
|
(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)
|
||||||
|
|
||||||
|
|
||||||
### Install SGLang
|
### Install SGLang
|
||||||
${PIP_INSTALL} -v -e "python[srt_npu]"
|
${PIP_INSTALL} -v -e "python[srt_npu]"
|
||||||
|
|||||||
121
test/srt/ascend/test_ascend_deepep.py
Normal file
121
test/srt/ascend/test_ascend_deepep.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
|
popen_launch_server,
|
||||||
|
run_bench_offline_throughput,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_MODEL_MATRIX = {
|
||||||
|
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": {
|
||||||
|
"accuracy": 0.95,
|
||||||
|
"latency": 1000,
|
||||||
|
"output_throughput": 6,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendDeepEP(CustomTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.models = TEST_MODEL_MATRIX.keys()
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||||
|
|
||||||
|
cls.common_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--attention-backend",
|
||||||
|
"ascend",
|
||||||
|
"--quantization",
|
||||||
|
"w8a8_int8",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
0.9,
|
||||||
|
"--max-running-requests",
|
||||||
|
32,
|
||||||
|
"--disable-radix-cache",
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
32768,
|
||||||
|
"--disable-cuda-graph",
|
||||||
|
"--tp-size",
|
||||||
|
16,
|
||||||
|
"--dp-size",
|
||||||
|
1,
|
||||||
|
"--ep-size",
|
||||||
|
16,
|
||||||
|
"--moe-a2a-backend",
|
||||||
|
"deepep",
|
||||||
|
"--deepep-mode",
|
||||||
|
"auto",
|
||||||
|
]
|
||||||
|
|
||||||
|
cls.extra_envs = {
|
||||||
|
"HCCL_BUFFSIZE": "500",
|
||||||
|
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
|
||||||
|
}
|
||||||
|
os.environ.update(cls.extra_envs)
|
||||||
|
|
||||||
|
def test_a_gsm8k(self):
|
||||||
|
for model in self.models:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
print(f"##=== Testing accuracy: {model} ===##")
|
||||||
|
|
||||||
|
process = popen_launch_server(
|
||||||
|
model,
|
||||||
|
self.base_url,
|
||||||
|
timeout=1500,
|
||||||
|
other_args=[
|
||||||
|
*self.common_args,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=1319,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host=f"http://{self.url.hostname}",
|
||||||
|
port=int(self.url.port),
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
self.assertGreaterEqual(
|
||||||
|
metrics["accuracy"],
|
||||||
|
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
def test_b_throughput(self):
|
||||||
|
for model in self.models:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
print(f"##=== Testing throughput: {model} ===##")
|
||||||
|
|
||||||
|
output_throughput = run_bench_offline_throughput(
|
||||||
|
model,
|
||||||
|
[
|
||||||
|
*self.common_args,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
self.assertGreater(
|
||||||
|
output_throughput,
|
||||||
|
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -300,6 +300,9 @@ suite_ascend = {
|
|||||||
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
|
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
|
||||||
TestFile("ascend/test_ascend_tp4_bf16.py", 400),
|
TestFile("ascend/test_ascend_tp4_bf16.py", 400),
|
||||||
],
|
],
|
||||||
|
"per-commit-16-ascend-a3": [
|
||||||
|
TestFile("ascend/test_ascend_deepep.py", 400),
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
suites.update(suite_amd)
|
suites.update(suite_amd)
|
||||||
|
|||||||
Reference in New Issue
Block a user