[RL] fix register the same ops multiple times (#9564)
This commit is contained in:
@@ -146,27 +146,21 @@ def _quant_dequant_mxfp4_fake(
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="dequant_mxfp4",
|
||||
op_func=_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
direct_register_custom_op(
|
||||
op_name="dequant_mxfp4",
|
||||
op_func=_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="quant_dequant_mxfp4",
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
direct_register_custom_op(
|
||||
op_name="quant_dequant_mxfp4",
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
|
||||
|
||||
|
||||
class Mxfp4Config(QuantizationConfig):
|
||||
|
||||
@@ -1665,9 +1665,29 @@ def direct_register_custom_op(
|
||||
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
|
||||
library object. If you want to bind the operator to a different library,
|
||||
make sure the library object is alive when the operator is used.
|
||||
|
||||
Note: This function will silently skip registration if the operator
|
||||
with the same name is already registered to avoid RuntimeError in
|
||||
multi-engine scenarios (e.g., VERL framework).
|
||||
"""
|
||||
import torch.library
|
||||
|
||||
my_lib = target_lib or sglang_lib
|
||||
|
||||
# Check if operator is already registered to avoid duplicate registration
|
||||
# This is important for scenarios where multiple SGLang engines run in the same process
|
||||
try:
|
||||
# Try to access the operator to see if it's already registered
|
||||
lib_name = my_lib.m.name if hasattr(my_lib.m, "name") else "sglang"
|
||||
if hasattr(torch.ops, lib_name) and hasattr(
|
||||
getattr(torch.ops, lib_name), op_name
|
||||
):
|
||||
# Operator already exists, skip registration
|
||||
return
|
||||
except (AttributeError, RuntimeError):
|
||||
# Operator doesn't exist, proceed with registration
|
||||
pass
|
||||
|
||||
if hasattr(torch.library, "infer_schema"):
|
||||
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
||||
else:
|
||||
@@ -1676,11 +1696,22 @@ def direct_register_custom_op(
|
||||
|
||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||
|
||||
my_lib = target_lib or sglang_lib
|
||||
my_lib.define(op_name + schema_str)
|
||||
my_lib.impl(op_name, op_func, "CUDA")
|
||||
if fake_impl is not None:
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
try:
|
||||
my_lib.define(op_name + schema_str)
|
||||
my_lib.impl(op_name, op_func, "CUDA")
|
||||
if fake_impl is not None:
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
except RuntimeError as error:
|
||||
if "Tried to register an operator" in str(error) and "multiple times" in str(error):
|
||||
# Silently ignore duplicate registration errors
|
||||
# This can happen in multi-engine scenarios
|
||||
pass
|
||||
else:
|
||||
# Re-raise other RuntimeErrors
|
||||
raise error
|
||||
except AttributeError as error:
|
||||
# Always re-raise AttributeError as it indicates missing dependencies
|
||||
raise error
|
||||
|
||||
|
||||
def set_gpu_proc_affinity(
|
||||
|
||||
Reference in New Issue
Block a user