adaptation w4a8 tp
This commit is contained in:
@@ -170,9 +170,7 @@ class RMSNorm(CustomOp):
|
|||||||
output = torch.empty_like(x)
|
output = torch.empty_like(x)
|
||||||
residual_out = torch.empty_like(x)
|
residual_out = torch.empty_like(x)
|
||||||
fused_add_rms_norm(
|
fused_add_rms_norm(
|
||||||
output,
|
|
||||||
x,
|
x,
|
||||||
residual_out,
|
|
||||||
residual,
|
residual,
|
||||||
self.weight.data,
|
self.weight.data,
|
||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
@@ -180,7 +178,9 @@ class RMSNorm(CustomOp):
|
|||||||
return output, residual_out
|
return output, residual_out
|
||||||
except TypeError:
|
except TypeError:
|
||||||
fused_add_rms_norm(
|
fused_add_rms_norm(
|
||||||
|
output,
|
||||||
x,
|
x,
|
||||||
|
residual_out,
|
||||||
residual,
|
residual,
|
||||||
self.weight.data,
|
self.weight.data,
|
||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
|
||||||
import torch
|
import torch
|
||||||
from sglang.srt import _custom_ops as ops
|
from sglang.srt import _custom_ops as ops
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
@@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
|||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
dispatch_output: StandardDispatchOutput,
|
dispatch_output,
|
||||||
) -> CombineInput:
|
) :
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||||
x = dispatch_output.hidden_states
|
x = dispatch_output.hidden_states
|
||||||
topk_output = dispatch_output.topk_output
|
topk_output = dispatch_output.topk_output
|
||||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
@@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
|
|||||||
use_int4_w4a8=True,
|
use_int4_w4a8=True,
|
||||||
per_channel_quant=True,
|
per_channel_quant=True,
|
||||||
activation=layer.moe_runner_config.activation,
|
activation=layer.moe_runner_config.activation,
|
||||||
expert_map=layer.expert_map_gpu,
|
# expert_map=layer.expert_map_gpu,
|
||||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||||
global_num_experts=layer.moe_runner_config.num_experts,
|
global_num_experts=layer.moe_runner_config.num_experts,
|
||||||
w1_scale=(layer.w13_weight_scale),
|
w1_scale=(layer.w13_weight_scale),
|
||||||
|
|||||||
Reference in New Issue
Block a user