adaptation w4a8 tp
This commit is contained in:
@@ -170,9 +170,7 @@ class RMSNorm(CustomOp):
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
residual_out,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
@@ -180,7 +178,9 @@ class RMSNorm(CustomOp):
|
||||
return output, residual_out
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
residual_out,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
||||
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
|
||||
import torch
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
@@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
dispatch_output,
|
||||
) :
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
@@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
use_int4_w4a8=True,
|
||||
per_channel_quant=True,
|
||||
activation=layer.moe_runner_config.activation,
|
||||
expert_map=layer.expert_map_gpu,
|
||||
# expert_map=layer.expert_map_gpu,
|
||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||
global_num_experts=layer.moe_runner_config.num_experts,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
|
||||
Reference in New Issue
Block a user