Support DP attention with GPT-OSS (#9359)
This commit is contained in:
@@ -1091,7 +1091,7 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
if name in params_dict.keys():
|
if name in params_dict.keys():
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
if "sinks" in name:
|
if "sinks" in name:
|
||||||
start = tp_rank * param.numel()
|
start = get_attention_tp_rank() * param.numel()
|
||||||
param.data.copy_(
|
param.data.copy_(
|
||||||
loaded_weight[start : start + param.numel()]
|
loaded_weight[start : start + param.numel()]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2183,10 +2183,11 @@ class ServerArgs:
|
|||||||
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
|
), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
|
||||||
|
|
||||||
if is_sm100_supported():
|
if is_sm100_supported():
|
||||||
self.enable_flashinfer_allreduce_fusion = True
|
if not self.enable_dp_attention:
|
||||||
logger.info(
|
self.enable_flashinfer_allreduce_fusion = True
|
||||||
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
|
logger.info(
|
||||||
)
|
"Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
|
||||||
|
)
|
||||||
quantization_config = getattr(hf_config, "quantization_config", None)
|
quantization_config = getattr(hf_config, "quantization_config", None)
|
||||||
is_mxfp4_quant_format = (
|
is_mxfp4_quant_format = (
|
||||||
quantization_config is not None
|
quantization_config is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user