Fix torch compile for deepseek-v2 (#1442)
This commit is contained in:
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
||||
for sub in model._modules.values():
|
||||
if isinstance(sub, CustomOp):
|
||||
# NOTE: FusedMoE torch native implementation is not efficient
|
||||
if "FusedMoE" in sub.__class__.__name__:
|
||||
continue
|
||||
if reverse:
|
||||
sub._forward_method = sub.forward_cuda
|
||||
setattr(sub, "is_torch_compile", False)
|
||||
@@ -105,7 +108,15 @@ class CudaGraphRunner:
|
||||
self.capture_bs = list(range(1, 32)) + [64, 128]
|
||||
else:
|
||||
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
|
||||
self.compile_bs = (
|
||||
[
|
||||
bs
|
||||
for bs in self.capture_bs
|
||||
if bs <= self.model_runner.server_args.max_torch_compile_bs
|
||||
]
|
||||
if self.use_torch_compile
|
||||
else []
|
||||
)
|
||||
|
||||
# Common inputs
|
||||
self.max_bs = max(self.capture_bs)
|
||||
|
||||
@@ -653,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@@ -110,6 +110,7 @@ class ServerArgs:
|
||||
disable_custom_all_reduce: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_torch_compile: bool = False
|
||||
max_torch_compile_bs: int = 32
|
||||
torchao_config: str = ""
|
||||
enable_p2p_check: bool = False
|
||||
enable_mla: bool = False
|
||||
@@ -523,6 +524,12 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Optimize the model with torch.compile. Experimental feature.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-torch-compile-bs",
|
||||
type=int,
|
||||
default=ServerArgs.max_torch_compile_bs,
|
||||
help="Set the maximum batch size when using torch compile.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torchao-config",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user