[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)
This commit is contained in:
36
.github/workflows/pr-test-npu.yml
vendored
36
.github/workflows/pr-test-npu.yml
vendored
@@ -127,12 +127,48 @@ jobs:
|
||||
cd test/srt
|
||||
python3 run_suite.py --suite per-commit-4-ascend-npu --timeout-per-file 3600
|
||||
|
||||
per-commit-16-ascend-a3:
|
||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||
github.event.pull_request.draft == false
|
||||
runs-on: linux-aarch64-a3-16
|
||||
container:
|
||||
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-a3-ubuntu22.04-py3.11
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# speed up installs by using the infra cache services (apt and PyPI mirrors)
|
||||
CACHING_URL="cache-service.nginx-pypi-cache.svc.cluster.local"
|
||||
sed -Ei "s@(ports|archive).ubuntu.com@${CACHING_URL}:8081@g" /etc/apt/sources.list
|
||||
pip config set global.index-url http://${CACHING_URL}/pypi/simple
|
||||
pip config set global.trusted-host ${CACHING_URL}
|
||||
|
||||
bash scripts/ci/npu_ci_install_dependency.sh
|
||||
# copy the required ShareGPT dataset file from our daily cache
|
||||
cp ~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp
|
||||
# download the GSM8K test set through the GitHub proxy
|
||||
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||
|
||||
- name: Run test
|
||||
timeout-minutes: 90
|
||||
env:
|
||||
SGLANG_USE_MODELSCOPE: true
|
||||
SGLANG_IS_IN_CI: true
|
||||
HF_ENDPOINT: https://hf-mirror.com
|
||||
TORCH_EXTENSIONS_DIR: /tmp/torch_extensions
|
||||
run: |
|
||||
cd test/srt
|
||||
python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 5400
|
||||
|
||||
pr-test-npu-finish:
|
||||
if: always()
|
||||
needs:
|
||||
- per-commit-1-ascend-npu
|
||||
- per-commit-2-ascend-npu
|
||||
- per-commit-4-ascend-npu
|
||||
- per-commit-16-ascend-a3
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check all dependent job statuses
|
||||
|
||||
@@ -72,5 +72,6 @@ jobs:
|
||||
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
|
||||
provenance: false
|
||||
build-args: |
|
||||
SGLANG_KERNEL_NPU_TAG=20250901
|
||||
CANN_VERSION=${{ matrix.cann_version }}
|
||||
DEVICE_TYPE=${{ matrix.device_type }}
|
||||
|
||||
4
.github/workflows/release-docker-npu.yml
vendored
4
.github/workflows/release-docker-npu.yml
vendored
@@ -54,8 +54,6 @@ jobs:
|
||||
run: |
|
||||
version=$(cat python/sglang/version.py | cut -d'"' -f2)
|
||||
echo "TAG=lmsysorg/sglang:v$version-cann${{ matrix.cann_version }}-${{ matrix.device_type }}" >> $GITHUB_OUTPUT
|
||||
kernel_tag=$(curl -s https://api.github.com/repos/sgl-project/sgl-kernel-npu/tags | jq -r '.[0].name')
|
||||
echo "KERNEL_NPU_TAG=${kernel_tag}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build and push Docker image
|
||||
id: build-and-push
|
||||
@@ -70,6 +68,6 @@ jobs:
|
||||
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
|
||||
provenance: false
|
||||
build-args: |
|
||||
SGLANG_KERNEL_NPU_TAG=${{ steps.get_version.outputs.KERNEL_NPU_TAG }}
|
||||
SGLANG_KERNEL_NPU_TAG=20250901
|
||||
CANN_VERSION=${{ matrix.cann_version }}
|
||||
DEVICE_TYPE=${{ matrix.device_type }}
|
||||
|
||||
@@ -55,7 +55,7 @@ class EPLBManager:
|
||||
enable_timing = self._rebalance_layers_per_chunk is None
|
||||
|
||||
if enable_timing:
|
||||
torch.cuda.synchronize()
|
||||
torch.get_device_module().synchronize()
|
||||
time_start = time.time()
|
||||
|
||||
dump_record_output = get_global_expert_distribution_recorder().dump_record(
|
||||
@@ -85,7 +85,7 @@ class EPLBManager:
|
||||
|
||||
msg = f"[EPLBManager] rebalance end"
|
||||
if enable_timing:
|
||||
torch.cuda.synchronize()
|
||||
torch.get_device_module().synchronize()
|
||||
time_end = time.time()
|
||||
msg += f" time={time_end - time_start:.3f}s"
|
||||
logger.info(msg)
|
||||
|
||||
@@ -30,7 +30,9 @@ import torch.distributed
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
||||
def _on_hook(self, hook_name: str, **kwargs):
|
||||
if self._disable_all:
|
||||
return
|
||||
if not (self._recording or torch.cuda.is_current_stream_capturing()):
|
||||
if not (
|
||||
self._recording or torch.get_device_module().is_current_stream_capturing()
|
||||
):
|
||||
return
|
||||
gatherer = self._single_pass_gatherers[
|
||||
self._accumulator.get_single_pass_gatherer_key(
|
||||
@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
|
||||
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if not _is_npu:
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "npu"
|
||||
self._enable_global_physical_experts = enable_global_physical_experts
|
||||
self._data = torch.zeros(
|
||||
(
|
||||
@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||
),
|
||||
),
|
||||
dtype=torch.int,
|
||||
device="cuda",
|
||||
device=device,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
||||
|
||||
if self._first_dump:
|
||||
self._first_dump = False
|
||||
torch.cuda.empty_cache()
|
||||
torch.get_device_module().empty_cache()
|
||||
|
||||
torch.distributed.all_reduce(
|
||||
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
|
||||
|
||||
@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
|
||||
):
|
||||
if self._first_execution:
|
||||
self._first_execution = False
|
||||
torch.cuda.empty_cache()
|
||||
torch.get_device_module().empty_cache()
|
||||
|
||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||
assert old_expert_location_metadata is not None
|
||||
|
||||
@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
@@ -33,6 +34,7 @@ class ForwardMetadata:
|
||||
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||
seq_lens_cpu_list: Optional[List[int]] = None
|
||||
seq_lens_list_cumsum: Optional[List[int]] = None
|
||||
|
||||
|
||||
class AscendAttnBackend(AttentionBackend):
|
||||
@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init the metadata for a forward pass."""
|
||||
tp_size = get_attention_tp_size()
|
||||
self.forward_metadata = ForwardMetadata()
|
||||
|
||||
self.forward_metadata.block_tables = (
|
||||
@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
|
||||
forward_batch.extend_seq_lens.cpu().int()
|
||||
)
|
||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
|
||||
forward_batch.extend_seq_lens_cpu
|
||||
)
|
||||
|
||||
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
|
||||
if forward_batch.is_extend_in_batch:
|
||||
seq_lens_list_cumsum[-1] = (
|
||||
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
|
||||
) * tp_size
|
||||
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
|
||||
|
||||
self.graph_mode = False
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalOutput,
|
||||
DispatchOutput,
|
||||
@@ -454,7 +453,7 @@ class DeepEPMoE(EPMoE):
|
||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||
return self.forward_aiter(dispatch_output)
|
||||
if _is_npu:
|
||||
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
|
||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||
return self.forward_npu(dispatch_output)
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
@@ -718,63 +717,124 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
def forward_npu(
|
||||
self,
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
|
||||
):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, AscendDeepEPLLOutput)
|
||||
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
|
||||
# NOTE: Ascend's Dispatch & Combine does not support FP16
|
||||
output_dtype = torch.bfloat16
|
||||
|
||||
pertoken_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
group_list_type = 1
|
||||
seg_indptr = seg_indptr.to(torch.int64)
|
||||
|
||||
import torch_npu
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=seg_indptr,
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=seg_indptr,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
# NOTE: Ascend's Dispatch & Combine does not support FP16
|
||||
output_dtype = torch.bfloat16
|
||||
group_list_type = 1
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=seg_indptr,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
def _forward_normal(dispatch_output: DeepEPNormalOutput):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, DeepEPNormalOutput)
|
||||
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
|
||||
|
||||
return hidden_states
|
||||
if isinstance(hidden_states, tuple):
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
else:
|
||||
# dynamic quant
|
||||
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
||||
hidden_states.device
|
||||
)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[per_token_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _forward_ll(dispatch_output: DeepEPLLOutput):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, DeepEPLLOutput)
|
||||
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
|
||||
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
group_list = group_list.to(torch.int64)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=per_token_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
return _forward_normal(dispatch_output)
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
return _forward_ll(dispatch_output)
|
||||
else:
|
||||
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
|
||||
|
||||
|
||||
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
||||
|
||||
@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPConfig,
|
||||
DeepEPDispatcher,
|
||||
DeepEPLLCombineInput,
|
||||
@@ -23,7 +22,6 @@ from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AscendDeepEPLLOutput",
|
||||
"BaseDispatcher",
|
||||
"BaseDispatcherConfig",
|
||||
"CombineInput",
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPLLCombineInput,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalCombineInput,
|
||||
@@ -47,19 +46,12 @@ class DispatchOutputChecker:
|
||||
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
|
||||
return dispatch_output.format.is_deepep()
|
||||
|
||||
@staticmethod
|
||||
def format_is_ascent_ll(
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> TypeGuard[AscendDeepEPLLOutput]:
|
||||
return dispatch_output.format.is_ascent_ll()
|
||||
|
||||
|
||||
class DispatchOutputFormat(Enum):
|
||||
|
||||
STANDARD = "standard"
|
||||
DEEPEP_NORMAL = "deepep_normal"
|
||||
DEEPEP_LL = "deepep_ll"
|
||||
ASCENT_LL = "ascent_ll"
|
||||
|
||||
def is_standard(self) -> bool:
|
||||
return self == DispatchOutputFormat.STANDARD
|
||||
@@ -76,9 +68,6 @@ class DispatchOutputFormat(Enum):
|
||||
DispatchOutputFormat.DEEPEP_LL,
|
||||
]
|
||||
|
||||
def is_ascent_ll(self) -> bool:
|
||||
return self == DispatchOutputFormat.ASCENT_LL
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class DispatchOutput(Protocol):
|
||||
|
||||
@@ -77,24 +77,8 @@ class DeepEPLLOutput(NamedTuple):
|
||||
return DispatchOutputFormat.DEEPEP_LL
|
||||
|
||||
|
||||
class AscendDeepEPLLOutput(NamedTuple):
|
||||
"""AscendDeepEP low latency dispatch output."""
|
||||
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
|
||||
topk_idx: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
masked_m: torch.Tensor
|
||||
seg_indptr: torch.Tensor
|
||||
expected_m: int
|
||||
|
||||
@property
|
||||
def format(self) -> DispatchOutputFormat:
|
||||
return DispatchOutputFormat.ASCENT_LL
|
||||
|
||||
|
||||
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
||||
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
|
||||
|
||||
|
||||
class DeepEPNormalCombineInput(NamedTuple):
|
||||
@@ -434,12 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_post_reorder_triton_kernel,
|
||||
)
|
||||
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||
output = hidden_states
|
||||
else:
|
||||
if hidden_states.shape[0] > 0:
|
||||
@@ -553,23 +536,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
masked_m
|
||||
)
|
||||
|
||||
if _is_npu:
|
||||
deepep_output = AscendDeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
self.handle[1],
|
||||
expected_m,
|
||||
)
|
||||
else:
|
||||
deepep_output = DeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
deepep_output = DeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
return deepep_output
|
||||
|
||||
def _dispatch_core(
|
||||
|
||||
@@ -330,6 +330,14 @@ class TopK(CustomOp):
|
||||
)
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
|
||||
if expert_location_dispatch_info is not None:
|
||||
topk_ids = topk_ids_logical_to_physical(
|
||||
topk_ids, expert_location_dispatch_info
|
||||
)
|
||||
get_global_expert_distribution_recorder().on_select_experts(
|
||||
topk_ids=topk_ids
|
||||
)
|
||||
|
||||
return StandardTopKOutput(topk_weights, topk_ids, _)
|
||||
else:
|
||||
self.topk_config.torch_native = True
|
||||
|
||||
@@ -51,5 +51,11 @@ ${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil
|
||||
wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"
|
||||
|
||||
|
||||
### Install sgl-kernel-npu
|
||||
SGL_KERNEL_NPU_TAG="20250901"
|
||||
git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}
|
||||
(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)
|
||||
|
||||
|
||||
### Install SGLang
|
||||
${PIP_INSTALL} -v -e "python[srt_npu]"
|
||||
|
||||
121
test/srt/ascend/test_ascend_deepep.py
Normal file
121
test/srt/ascend/test_ascend_deepep.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
# Per-model pass thresholds, keyed by local model path.
# "accuracy" is the minimum GSM8K accuracy asserted by test_a_gsm8k and
# "output_throughput" is the lower bound asserted (in CI only) by
# test_b_throughput. "latency" is not read by any test in this file.
TEST_MODEL_MATRIX = {
|
||||
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": {
|
||||
"accuracy": 0.95,
|
||||
"latency": 1000,
|
||||
"output_throughput": 6,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendDeepEP(CustomTestCase):
    """End-to-end DeepEP checks on Ascend NPU for each model in TEST_MODEL_MATRIX.

    ``test_a_gsm8k`` launches a server and asserts few-shot GSM8K accuracy;
    ``test_b_throughput`` runs the offline throughput benchmark and, when
    running in CI, asserts a minimum output throughput.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared launch arguments and export the extra env vars."""
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)

        # NOTE(review): several values below are ints/floats rather than
        # strings; this assumes the launch helpers stringify their arguments
        # before spawning the subprocess -- confirm against test_utils.
        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "ascend",
            "--quantization",
            "w8a8_int8",
            "--mem-fraction-static",
            0.9,
            "--max-running-requests",
            32,
            "--disable-radix-cache",
            "--chunked-prefill-size",
            32768,
            "--disable-cuda-graph",
            "--tp-size",
            16,
            "--dp-size",
            1,
            "--ep-size",
            16,
            "--moe-a2a-backend",
            "deepep",
            "--deepep-mode",
            "auto",
        ]

        cls.extra_envs = {
            "HCCL_BUFFSIZE": "500",
            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
        }
        # Remember the previous values (None == unset) so tearDownClass can
        # undo the export and the settings do not leak into other test
        # classes that run in the same interpreter.
        cls._saved_envs = {name: os.environ.get(name) for name in cls.extra_envs}
        # Exported process-wide, presumably so the spawned server/benchmark
        # subprocesses inherit them.
        os.environ.update(cls.extra_envs)

    @classmethod
    def tearDownClass(cls):
        """Restore the environment variables overridden in setUpClass."""
        # getattr guards against setUpClass having failed before the snapshot
        # was taken.
        for name, previous in getattr(cls, "_saved_envs", {}).items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
        super().tearDownClass()

    def test_a_gsm8k(self):
        """Serve each model and require its GSM8K accuracy to meet the threshold."""
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=1500,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )

                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    # Always tear the server down, even when the eval or the
                    # assertion fails.
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        """Benchmark offline throughput; in CI, require it to exceed the threshold."""
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                # The threshold is only enforced in CI.
                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -300,6 +300,9 @@ suite_ascend = {
|
||||
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
|
||||
TestFile("ascend/test_ascend_tp4_bf16.py", 400),
|
||||
],
|
||||
"per-commit-16-ascend-a3": [
|
||||
TestFile("ascend/test_ascend_deepep.py", 400),
|
||||
],
|
||||
}
|
||||
|
||||
suites.update(suite_amd)
|
||||
|
||||
Reference in New Issue
Block a user