[RL] fix register the same ops multiple times (#9564)
This commit is contained in:
@@ -146,27 +146,21 @@ def _quant_dequant_mxfp4_fake(
|
|||||||
return torch.empty_like(x)
|
return torch.empty_like(x)
|
||||||
|
|
||||||
|
|
||||||
# Register `dequant_mxfp4` as a torch custom op once at import time.
# AttributeError (e.g. missing torch.library machinery) is fatal here:
# callers rely on `torch.ops.sglang.dequant_mxfp4` existing.
try:
    direct_register_custom_op(
        op_name="dequant_mxfp4",
        op_func=_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_dequant_mxfp4_fake,
    )
    dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
except AttributeError as error:
    raise error
|
||||||
# Register `quant_dequant_mxfp4` as a torch custom op once at import time,
# mirroring the `dequant_mxfp4` registration above.
try:
    direct_register_custom_op(
        op_name="quant_dequant_mxfp4",
        op_func=_quant_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_quant_dequant_mxfp4_fake,
    )
    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
except AttributeError as error:
    raise error
class Mxfp4Config(QuantizationConfig):
|
class Mxfp4Config(QuantizationConfig):
|
||||||
|
|||||||
@@ -1665,9 +1665,29 @@ def direct_register_custom_op(
|
|||||||
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
|
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
|
||||||
library object. If you want to bind the operator to a different library,
|
library object. If you want to bind the operator to a different library,
|
||||||
make sure the library object is alive when the operator is used.
|
make sure the library object is alive when the operator is used.
|
||||||
|
|
||||||
|
Note: This function will silently skip registration if the operator
|
||||||
|
with the same name is already registered to avoid RuntimeError in
|
||||||
|
multi-engine scenarios (e.g., VERL framework).
|
||||||
"""
|
"""
|
||||||
import torch.library
|
import torch.library
|
||||||
|
|
||||||
|
my_lib = target_lib or sglang_lib
|
||||||
|
|
||||||
|
# Check if operator is already registered to avoid duplicate registration
|
||||||
|
# This is important for scenarios where multiple SGLang engines run in the same process
|
||||||
|
try:
|
||||||
|
# Try to access the operator to see if it's already registered
|
||||||
|
lib_name = my_lib.m.name if hasattr(my_lib.m, "name") else "sglang"
|
||||||
|
if hasattr(torch.ops, lib_name) and hasattr(
|
||||||
|
getattr(torch.ops, lib_name), op_name
|
||||||
|
):
|
||||||
|
# Operator already exists, skip registration
|
||||||
|
return
|
||||||
|
except (AttributeError, RuntimeError):
|
||||||
|
# Operator doesn't exist, proceed with registration
|
||||||
|
pass
|
||||||
|
|
||||||
if hasattr(torch.library, "infer_schema"):
|
if hasattr(torch.library, "infer_schema"):
|
||||||
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
||||||
else:
|
else:
|
||||||
@@ -1676,11 +1696,22 @@ def direct_register_custom_op(
|
|||||||
|
|
||||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||||
|
|
||||||
my_lib = target_lib or sglang_lib
|
try:
|
||||||
my_lib.define(op_name + schema_str)
|
my_lib.define(op_name + schema_str)
|
||||||
my_lib.impl(op_name, op_func, "CUDA")
|
my_lib.impl(op_name, op_func, "CUDA")
|
||||||
if fake_impl is not None:
|
if fake_impl is not None:
|
||||||
my_lib._register_fake(op_name, fake_impl)
|
my_lib._register_fake(op_name, fake_impl)
|
||||||
|
except RuntimeError as error:
|
||||||
|
if "Tried to register an operator" in str(e) and "multiple times" in str(e):
|
||||||
|
# Silently ignore duplicate registration errors
|
||||||
|
# This can happen in multi-engine scenarios
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Re-raise other RuntimeErrors
|
||||||
|
raise error
|
||||||
|
except AttributeError as error:
|
||||||
|
# Always re-raise AttributeError as it indicates missing dependencies
|
||||||
|
raise error
|
||||||
|
|
||||||
|
|
||||||
def set_gpu_proc_affinity(
|
def set_gpu_proc_affinity(
|
||||||
|
|||||||
Reference in New Issue
Block a user