URL: https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html

Introduction to torch.compile: PyTorch Tutorials 2.12.0+cu130 documentation


Introduction to torch.compile#

Created On: Mar 15, 2023 | Last Updated: Apr 01, 2026 | Last Verified: Nov 05, 2024

Author: William Wen

torch.compile is the new way to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, while requiring minimal code changes.

torch.compile accomplishes this by tracing through your Python code, looking for PyTorch operations. Code that is difficult to trace will result in a graph break, which is a lost optimization opportunity rather than an error or silent incorrectness.

torch.compile is available in PyTorch 2.0 and later.

This introduction covers basic torch.compile usage and demonstrates the advantages of torch.compile over our previous PyTorch compiler solution, TorchScript.

For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.

To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.

Required pip dependencies for this tutorial

  • torch >= 2.0

  • numpy

  • scipy

System requirements

  • A C++ compiler, such as g++

  • Python development package (python-devel/python-dev)

Basic Usage#

We turn on some logging to help us see what torch.compile is doing under the hood in this tutorial. The following code will print out the PyTorch ops that torch.compile traced.

import torch


torch._logging.set_logs(graph_code=True)

torch.compile is a decorator that takes an arbitrary Python function.

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))


@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
 ===== __compiled_fn_1_f6719594_96eb_41a9_85b3_ba0aa2142b7d =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
 l_x_ = L_x_
 l_y_ = L_y_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:74 in foo, code: a = torch.sin(x)
 a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:75 in foo, code: b = torch.cos(y)
 b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:76 in foo, code: return a + b
 add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
 return (add,)


tensor([[-0.0239, -0.8411, 0.7322],
 [ 0.8814, 0.4477, -0.7143],
 [-0.0085, 1.2003, 1.5242]])
TRACED GRAPH
 ===== __compiled_fn_3_64334f99_06c4_474a_9f85_2ea2aa3bd40d =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
 l_x_ = L_x_
 l_y_ = L_y_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:85 in opt_foo2, code: a = torch.sin(x)
 a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:86 in opt_foo2, code: b = torch.cos(y)
 b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:87 in opt_foo2, code: return a + b
 add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
 return (add,)


tensor([[ 1.1946, 0.6039, 1.6637],
 [-0.0020, 1.8625, 0.1394],
 [-0.6990, 0.6818, 0.0399]])

torch.compile is applied recursively, so nested function calls within the top-level compiled function will also be compiled.

def inner(x):
    return torch.sin(x)


@torch.compile
def outer(x, y):
    a = inner(x)
    b = torch.cos(y)
    return a + b


print(outer(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
 ===== __compiled_fn_5_3f1a571e_4e88_415f_a822_fb7f563d8107 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
 l_x_ = L_x_
 l_y_ = L_y_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:98 in inner, code: return torch.sin(x)
 a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:104 in outer, code: b = torch.cos(y)
 b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:105 in outer, code: return a + b
 add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
 return (add,)


tensor([[ 0.1224, 0.6029, 0.6548],
 [ 0.6954, -0.0926, 0.6567],
 [-0.1284, 0.5222, 0.9308]])

We can also optimize torch.nn.Module instances, either by calling the module’s .compile() method or by passing the module directly to torch.compile. Both are equivalent to torch.compile-ing the module’s __call__ method (which indirectly calls forward).

t = torch.randn(10, 100)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))

mod2 = MyModule()
mod2 = torch.compile(mod2)
print(mod2(torch.randn(3, 3)))
TRACED GRAPH
 ===== __compiled_fn_7_84ac04cf_d4f0_492d_8428_d1f8f5bd701f =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
 l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
 l_self_modules_lin_parameters_bias_ = L_self_modules_lin_parameters_bias_
 l_x_ = L_x_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:126 in forward, code: return torch.nn.functional.relu(self.lin(x))
 linear: "f32[3, 3][3, 1]cpu" = torch._C._nn.linear(l_x_, l_self_modules_lin_parameters_weight_, l_self_modules_lin_parameters_bias_); l_x_ = l_self_modules_lin_parameters_weight_ = l_self_modules_lin_parameters_bias_ = None
 relu: "f32[3, 3][3, 1]cpu" = torch.nn.functional.relu(linear); linear = None
 return (relu,)


tensor([[0.0000, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.0288]], grad_fn=<CompiledFunctionBackward>)
tensor([[0.5102, 0.0020, 0.1122],
 [0.8609, 0.0000, 0.1189],
 [0.6553, 0.0000, 0.2550]], grad_fn=<CompiledFunctionBackward>)
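
As a quick sanity check of the equivalence described above, the following minimal sketch compares a compiled module against the same module run eagerly. It is an illustration rather than part of the original tutorial: TinyModule is a hypothetical name, and backend="eager" (a debugging backend that captures the graph but skips code generation) is assumed here only so the sketch runs without a C++ compiler.

```python
import torch


class TinyModule(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


eager_mod = TinyModule()
# Assumption: backend="eager" runs the captured graph without code generation.
# Drop the argument to use the default backend.
compiled_mod = torch.compile(eager_mod, backend="eager")

x = torch.randn(4, 3)
# Both call paths should produce the same values for the same input.
outputs_match = torch.allclose(eager_mod(x), compiled_mod(x))
# The wrapper returned by torch.compile reuses the original parameters.
shares_weights = compiled_mod.lin.weight is eager_mod.lin.weight
print(outputs_match, shares_weights)
```

Because the parameters are shared, training or mutating the original module also affects the compiled wrapper.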

Demonstrating Speedups#

Now let’s demonstrate how torch.compile speeds up a simple PyTorch example. For a demonstration on a more complex model, see our end-to-end torch.compile tutorial.

def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    u = z * 2
    return u


opt_foo3 = torch.compile(foo3)


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


inp = torch.randn(4096, 4096).cuda()
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])
TRACED GRAPH
 ===== __compiled_fn_9_dd03897c_486e_49c1_9808_1ba005bcf46f =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_x_: "f32[4096, 4096][4096, 1]cuda:0"):
 l_x_ = L_x_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:147 in foo3, code: y = x + 1
 y: "f32[4096, 4096][4096, 1]cuda:0" = l_x_ + 1; l_x_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:148 in foo3, code: z = torch.nn.functional.relu(y)
 z: "f32[4096, 4096][4096, 1]cuda:0" = torch.nn.functional.relu(y); y = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:149 in foo3, code: u = z * 2
 u: "f32[4096, 4096][4096, 1]cuda:0" = z * 2; z = None
 return (u,)


compile: 0.5330122680664062
eager: 0.033552383422851564

Notice that torch.compile appears to take a lot longer to complete compared to eager. This is because torch.compile takes extra time to compile the model on the first few executions. torch.compile re-uses compiled code whenever possible, so if we run our optimized model several more times, we should see a significant improvement compared to eager.

# turn off logging for now to prevent spam
torch._logging.set_logs(graph_code=False)

eager_times = []
for i in range(10):
    _, eager_time = timed(lambda: foo3(inp))
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)

compile_times = []
for i in range(10):
    _, compile_time = timed(lambda: opt_foo3(inp))
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
eager time 0: 0.0009297919869422913
eager time 1: 0.0008714240193367004
eager time 2: 0.0008734719753265381
eager time 3: 0.0008724480271339416
eager time 4: 0.0008735039830207825
eager time 5: 0.0008724480271339416
eager time 6: 0.0008734719753265381
eager time 7: 0.0008694400191307068
eager time 8: 0.0008724480271339416
eager time 9: 0.0008704000115394592
~~~~~~~~~~
compile time 0: 0.0005468159914016723
compile time 1: 0.00038707199692726135
compile time 2: 0.0003840000033378601
compile time 3: 0.00038502401113510134
compile time 4: 0.00038092800974845887
compile time 5: 0.00038092800974845887
compile time 6: 0.00038092800974845887
compile time 7: 0.00038092800974845887
compile time 8: 0.0003768320083618164
compile time 9: 0.0003911679983139038
~~~~~~~~~~
(eval) eager median: 0.0008724480271339416, compile median: 0.00038246400654315946, speedup: 2.2811245299117826x
~~~~~~~~~~

And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup mainly comes from reducing Python overhead and GPU reads/writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.

To see speedups on a real model, check out our end-to-end torch.compile tutorial.
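
The timed helper above relies on CUDA events, so it only works when a GPU is present. The sketch below shows one possible device-agnostic variant that falls back to time.perf_counter on CPU. The name timed_any_device and the smaller 1024 x 1024 input are assumptions made for illustration, and wall-clock timing on CPU is only a rough measurement.

```python
import time

import torch


def timed_any_device(fn):
    """Return (result, seconds). Hypothetical device-agnostic variant of `timed`."""
    if torch.cuda.is_available():
        # Same approach as the tutorial's helper: CUDA events plus a sync.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        torch.cuda.synchronize()
        return result, start.elapsed_time(end) / 1000  # milliseconds -> seconds
    # CPU fallback: plain wall-clock timing.
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


def foo3(x):
    return torch.nn.functional.relu(x + 1) * 2


device = "cuda" if torch.cuda.is_available() else "cpu"
inp = torch.randn(1024, 1024, device=device)

result, seconds = timed_any_device(lambda: foo3(inp))
print(f"eager: {seconds:.6f} s")
```

To time the compiled function, pass lambda: opt_foo3(inp) instead, and discard the first few calls so that compilation time is not counted.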

Benefits over TorchScript#

Why should we use torch.compile over TorchScript? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

Compare this to TorchScript, which has a tracing mode (torch.jit.trace) and a scripting mode (torch.jit.script). Tracing mode is susceptible to silent incorrectness, while scripting mode requires significant code changes and will raise errors on unsupported Python code.

For example, TorchScript tracing silently fails on data-dependent control flow (the if x.sum() < 0: line below) because only the actual control flow path is traced. In comparison, torch.compile is able to correctly handle it.

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y


# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:239: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
 if x.sum() < 0:
traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~

TorchScript scripting can handle data-dependent control flow, but it can require major code changes and will raise errors when unsupported Python is used.

In the example below, we omit TorchScript type annotations and receive a TorchScript error because the input type for argument y, an int, does not match the default argument type, torch.Tensor. In comparison, torch.compile works without requiring any type annotations.

import traceback as tb

torch._logging.set_logs(graph_code=True)


def f2(x, y):
    return x + y


inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
Traceback (most recent call last):
 File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 288, in <module>
 script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
TRACED GRAPH
 ===== __compiled_fn_18_c54b4be8_c287_455d_827b_9517be0e6449 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_x_: "f32[5, 5][5, 1]cpu"):
 l_x_ = L_x_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:280 in f2, code: return x + y
 add: "f32[5, 5][5, 1]cpu" = l_x_ + 3; l_x_ = None
 return (add,)


compile 2: True
~~~~~~~~~~
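
For contrast, here is a minimal sketch of the kind of code change TorchScript scripting asks for in this case: annotating y as an int so that it is no longer inferred to be a Tensor. The name f2_annotated is hypothetical and this variant is not part of the original tutorial. It should be run from a file, since scripting needs access to the function's source.

```python
import torch


def f2_annotated(x: torch.Tensor, y: int) -> torch.Tensor:
    # Annotating `y` as int stops TorchScript from inferring Tensor for it.
    return x + y


script_f2_annotated = torch.jit.script(f2_annotated)

x = torch.randn(5, 5)
scripted_out = script_f2_annotated(x, 3)
# With the annotation in place, the scripted function accepts the int argument.
print(torch.allclose(scripted_out, x + 3))
```

The annotation fixes this call, but it also pins y to int, so passing a Tensor for y would now fail. torch.compile needs no such choice.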

Graph Breaks#

The graph break is one of the most fundamental concepts within torch.compile. It allows torch.compile to handle arbitrary Python code by interrupting compilation, running the unsupported code, then resuming compilation. The term “graph break” comes from the fact that torch.compile attempts to capture and optimize the PyTorch operation graph. When unsupported Python code is encountered, then this graph must be “broken”. Graph breaks result in lost optimization opportunities, which may still be undesirable, but this is better than silent incorrectness or a hard crash.

Let’s look at a data-dependent control flow example to better see how graph breaks work.

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


opt_bar = torch.compile(bar)
inp1 = torch.ones(10)
inp2 = torch.ones(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
TRACED GRAPH
 ===== __compiled_fn_20_e2f97a6e_22da_4cd5_b8a2_6e2775a72466 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
 l_a_ = L_a_
 l_b_ = L_b_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
 abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
 add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
 x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
 sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
 lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
 return (lt, x)


TRACED GRAPH
 ===== __compiled_fn_24_76573886_3afe_4b9c_86ab_3d551b5ddef1 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
 l_x_ = L_x_
 l_b_ = L_b_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
 mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
 return (mul,)


TRACED GRAPH
 ===== __compiled_fn_26_8f29a917_800e_4796_bdac_0515df90b592 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
 l_b_ = L_b_
 l_x_ = L_x_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
 b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
 mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
 return (mul_1,)



tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
 0.5000])

The first time we run bar, we see that torch.compile traced 2 graphs corresponding to the following code (noting that b.sum() < 0 is False):

  1. x = a / (torch.abs(a) + 1); b.sum()

  2. return x * b

The second time we run bar, we take the other branch of the if statement and we get 1 traced graph corresponding to the code b = b * -1; return x * b. We do not see a graph of x = a / (torch.abs(a) + 1); b.sum() outputted the second time since torch.compile cached this graph from the first run and re-used it.

Let’s investigate by example how TorchDynamo would step through bar. Call the graph traced during the second run (b = b * -1; return x * b) graph 3. If b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 3. On the other hand, if not b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 2.

We can see all graph breaks by using torch._logging.set_logs(graph_breaks=True).

# Reset to clear the torch.compile cache
torch._dynamo.reset()
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
TRACED GRAPH
 ===== __compiled_fn_28_6428dc02_4d78_4786_a55f_e9a3c887ec99 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
 l_a_ = L_a_
 l_b_ = L_b_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
 abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
 add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
 x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
 sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
 lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
 return (lt, x)


TRACED GRAPH
 ===== __compiled_fn_32_9c42dab1_5001_4559_a841_6d7a30350c6a =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
 l_x_ = L_x_
 l_b_ = L_b_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
 mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
 return (mul,)


TRACED GRAPH
 ===== __compiled_fn_34_6eb5e265_84b5_4c70_9692_a5ea36b44d66 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
 l_b_ = L_b_
 l_x_ = L_x_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
 b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
 mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
 return (mul_1,)



tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
 0.5000])

In order to maximize speedup, graph breaks should be limited. We can force TorchDynamo to raise an error upon the first graph break encountered by using fullgraph=True:

# Reset to clear the torch.compile cache
torch._dynamo.reset()

opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
 File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 360, in <module>
 opt_bar_fullgraph(torch.randn(10), torch.randn(10))
 File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1058, in compile_wrapper
 raise e.with_traceback(None) from e.__cause__ # User compiler error
torch._dynamo.exc.Unsupported: Data-dependent branching
 Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.

 The branch condition involves a tensor computed as follows:
 # File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar, code: if b.sum() < 0:
 lt = lt(sum_1, 0)

 Hint: The branch condition uses a scalar integer tensor. Consider rewriting the computation to use plain Python ints (e.g. use int attributes instead of tensor buffers) so the condition becomes a shape guard instead of data-dependent branching.
 Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
 Hint: Use `torch.cond` to express dynamic control flow.

 Developer debug context: attempted to jump with TensorVariable()

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html

from user code:
 File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar
 if b.sum() < 0:

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

In our example above, we can work around this graph break by replacing the if statement with a torch.cond:

from functorch.experimental.control_flow import cond


@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    b = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b


bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)
TRACED GRAPH
 ===== __compiled_fn_37_e3f9f2b7_5488_4ebf_a7e0_8dd0cf56f2c7 =====
 /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
 def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
 l_a_ = L_a_
 l_b_ = L_b_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:373 in bar_fixed, code: x = a / (torch.abs(a) + 1)
 abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
 add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
 x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:382 in bar_fixed, code: b = cond(b.sum() < 0, true_branch, false_branch, (b,))
 sum_1: "f32[][]cpu" = l_b_.sum()
 lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
 cond_true_0 = self.cond_true_0
 cond_false_0 = self.cond_false_0
 cond = torch.ops.higher_order.cond(lt, cond_true_0, cond_false_0, (l_b_,)); lt = cond_true_0 = cond_false_0 = l_b_ = None
 b: "f32[10][1]cpu" = cond[0]; cond = None

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:383 in bar_fixed, code: return x * b
 mul: "f32[10][1]cpu" = x * b; x = b = None
 return (mul,)

 class cond_true_0(torch.nn.Module):
 def forward(self, l_b_: "f32[10][1]cpu"):
 l_b__1 = l_b_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:376 in true_branch, code: return y * -1
 mul: "f32[10][1]cpu" = l_b__1 * -1; l_b__1 = None
 return (mul,)

 class cond_false_0(torch.nn.Module):
 def forward(self, l_b_: "f32[10][1]cpu"):
 l_b__1 = l_b_

 # File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:380 in false_branch, code: return y.clone()
 clone: "f32[10][1]cpu" = l_b__1.clone(); l_b__1 = None
 return (clone,)



tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
 0.5000])
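
When both branches are cheap elementwise computations, another possible workaround is to remove the branch entirely with torch.where. This is a sketch of an alternative technique, not the approach the tutorial uses: unlike torch.cond, torch.where evaluates both candidates and then selects between them. The name bar_where is hypothetical, and backend="eager" is assumed only so the sketch runs without a C++ compiler.

```python
import torch


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


# Assumption: backend="eager" skips code generation; drop it to use the
# default backend. fullgraph=True raises an error on any graph break.
@torch.compile(fullgraph=True, backend="eager")
def bar_where(a, b):
    x = a / (torch.abs(a) + 1)
    # The 0-dim boolean condition broadcasts over b, choosing -b or b.
    b = torch.where(b.sum() < 0, -b, b)
    return x * b


inp1 = torch.ones(10)
inp2 = torch.ones(10)
same_pos = torch.allclose(bar(inp1, inp2), bar_where(inp1, inp2))
same_neg = torch.allclose(bar(inp1, -inp2), bar_where(inp1, -inp2))
print(same_pos, same_neg)
```

This trades a small amount of extra compute for a single graph, which is usually reasonable only when neither branch is expensive or has side effects.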

In order to serialize graphs or to run graphs on different (i.e. Python-less) environments, consider using torch.export instead (from PyTorch 2.1+). One important restriction is that torch.export does not support graph breaks. Please check the torch.export tutorial for more details on torch.export.

Check out our section on graph breaks in the torch.compile programming model for tips on how to work around graph breaks.

Troubleshooting#

Is torch.compile failing to speed up your model? Is compile time unreasonably long? Is your code recompiling excessively? Are you having difficulties dealing with graph breaks? Are you looking for tips on how to best use torch.compile? Or maybe you simply want to learn more about the inner workings of torch.compile?

Check out the torch.compile programming model.

Conclusion#

In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing to TorchScript, and briefly describing graph breaks.

For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.

To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.

We hope that you will give torch.compile a try!

Total running time of the script: (0 minutes 16.976 seconds)
