torch compile

torch.compile

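torch.compile is built on four components: TorchDynamo (graph capture), AOT Autograd (ahead-of-time tracing of the backward pass), PrimTorch (operator canonicalization), and TorchInductor (code generation).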
[Figures from https://pytorch.org/get-started/pytorch-2.0/#user-experience]
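A minimal usage sketch of torch.compile from the user's side. It assumes a PyTorch 2.x install and passes backend="eager", which runs the captured graph with ordinary PyTorch kernels, so it does not need the C++/Triton toolchain that the default "inductor" backend uses. The function f is a hypothetical example.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # A short chain of pointwise ops: the kind of code Inductor can fuse
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

# backend="eager" executes the captured FX graph with regular PyTorch kernels;
# drop the argument to use the default "inductor" backend instead
compiled_f = torch.compile(f, backend="eager")

x = torch.randn(8)
out = compiled_f(x)               # compilation is lazy: the first call triggers tracing
print(torch.allclose(out, f(x)))  # True
```

Compilation happens on the first call, not at torch.compile time; later calls with compatible inputs reuse the cached result.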

TorchDynamo

torch.compile -> TorchDynamo -> FX graphs
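TorchDynamo compiles Python code just-in-time into FX graphs: it hooks CPython's frame evaluation (PEP 523), records the PyTorch operations it encounters, and falls back to regular Python (a "graph break") for code it cannot trace. The captured graphs are then handed to a backend for further optimization.
Its entry point is torch._dynamo.optimize(), which torch.compile calls internally.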
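One way to see the FX graphs TorchDynamo produces is to pass a custom backend: any callable taking (gm, example_inputs) and returning a callable. The sketch below assumes PyTorch 2.x; inspect_backend and g are hypothetical names. It goes through torch.compile, and in versions where torch.compile forwards to torch._dynamo.optimize (as in the source excerpt in this post), torch._dynamo.optimize(inspect_backend)(g) should behave the same way.

```python
from typing import Callable, List

import torch
import torch.fx

captured: List[torch.fx.GraphModule] = []

def inspect_backend(gm: torch.fx.GraphModule,
                    example_inputs: List[torch.Tensor]) -> Callable:
    # TorchDynamo calls the backend once per captured FX graph
    captured.append(gm)
    print(gm.code)     # Python source regenerated from the FX graph
    return gm.forward  # run the graph unchanged, without any optimization

def g(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ y) + 1

compiled_g = torch.compile(g, backend=inspect_backend)
out = compiled_g(torch.randn(3, 4), torch.randn(4, 5))
```

Returning gm.forward makes this a "do nothing" backend; a real backend would transform or lower gm before returning something callable.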

torch.compile source walkthrough

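```python
# Simplified excerpt of torch.compile from torch/__init__.py (PyTorch 2.x).
# _TorchCompileInductorWrapper and _TorchCompileWrapper are private helpers
# defined in the same module; the model=None (decorator) path is omitted.
import builtins
from typing import Callable, Dict, Optional, Union

import torch

def compile(model: Optional[Callable] = None, *,
            fullgraph: builtins.bool = False,
            dynamic: Optional[builtins.bool] = None,
            backend: Union[str, Callable] = "inductor",
            mode: Union[str, None] = None,
            options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
            disable: builtins.bool = False) -> Callable:
    # Fall back to "default" only when neither mode nor options was given
    if mode is None and options is None:
        mode = "default"
    # Wrap the backend so that mode/options/dynamic are forwarded to it
    if backend == "inductor":
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
        backend = _TorchCompileWrapper(backend, mode, options, dynamic)

    # fullgraph=True becomes nopython=True: graph breaks raise instead of falling back
    return torch._dynamo.optimize(backend=backend, nopython=fullgraph,
                                  dynamic=dynamic, disable=disable)(model)
```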
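The excerpt forwards fullgraph as nopython. A hedged sketch of what that changes, assuming PyTorch 2.x: torch._dynamo.graph_break() forces a break, counting_backend and h are hypothetical names, and torch._dynamo.reset() clears Dynamo's compile cache so h is traced again on the second run. The exact exception class raised under fullgraph=True differs between releases, so the sketch catches Exception.

```python
import torch
import torch._dynamo

graphs = []

def counting_backend(gm, example_inputs):
    graphs.append(gm)  # one entry per FX graph handed over by TorchDynamo
    return gm.forward

def h(x):
    y = x * 2
    torch._dynamo.graph_break()  # explicitly split the trace here
    return y + 1

# fullgraph=False (default, i.e. nopython=False): the break splits h into separate graphs
out = torch.compile(h, backend=counting_backend)(torch.ones(3))
print(len(graphs))

# fullgraph=True (forwarded as nopython=True): any graph break is an error
torch._dynamo.reset()  # drop cached compiled code so h is traced again
raised = False
try:
    torch.compile(h, backend=counting_backend, fullgraph=True)(torch.ones(3))
except Exception as err:  # concrete exception class varies across versions
    raised = True
    print(type(err).__name__)
print(raised)
```

With the default setting the function still runs correctly, just as several smaller graphs with Python in between; fullgraph=True is useful when you need to guarantee that the whole function was captured.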