Torch.export Error: Dynamic Slice With Dim And Step > 1

Alex Johnson

Introduction

Hey guys! Today, we're diving deep into a tricky bug encountered while using torch.export in PyTorch. Specifically, this issue arises when you export a model that slices an input tensor along a dynamic dimension, where the dynamic shape is declared with Dim("t", min=..., max=...). Let's break down the problem and see how we can work around it.

The Bug: Slicing Dynamic Dimensions with Steps

The core of the problem lies in how torch.export handles slicing operations on tensors with dynamic shapes, especially when a step size greater than 1 is involved. When you define a dynamic dimension using Dim("t", min=..., max=...) and then slice that dimension with a step (e.g., input1[::9, :]), torch.export can throw a ConstraintViolationError. This error means the tool can't prove that every value in the declared dynamic range satisfies the guard conditions generated for the slicing operation.

Code Example

To illustrate this, consider the following code snippet:

import torch
import torch.nn as nn
from torch.export import Dim, export, save

class CustomModel(nn.Module):
    def forward(self, input1):
        # Step-9 slice along dim 0, which is declared dynamic below
        return input1[::9, :]

model = CustomModel()
model.eval()
input1 = torch.randn(41, 6)

exported_model = export(
    model,
    (input1,),
    dynamic_shapes={"input1": {0: Dim("t1", min=1, max=1000)}},
)

exportedmodel_filename = "exported_model_reduced.pt2"
save(exported_model, exportedmodel_filename)

When you run this code, you'll likely encounter the dreaded ConstraintViolationError. The error message will tell you that not all values of t1 (the dynamic dimension) within the range 1 <= t1 <= 1000 satisfy the generated guard condition ((8 + L['input1'].size()[0]) // 9) != 1.

Why Does This Happen?

The reason for this error is that torch.export tries to create guard conditions to ensure that the slicing operation is valid for all possible shapes within the dynamic range. When you use a step size greater than 1, the resulting shape after slicing depends on the exact value of the dynamic dimension. The tool struggles to create a single, universal guard condition that works for all values in the specified range. In simpler terms, PyTorch is trying to figure out if the slice input1[::9, :] will always produce a valid tensor, no matter the size of the input.
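To see concretely why the output size can't be pinned down, here's a small pure-Python sketch of the arithmetic (no torch needed): the number of rows that survive a step-9 slice of a length-n input is ceil(n / 9), i.e. (n + 8) // 9, which shifts as n moves through the dynamic range.

```python
def sliced_len(n, step=9):
    # len(x[::step]) for a length-n input equals len(range(0, n, step)),
    # which is ceil(n / step)
    return len(range(0, n, step))

for n in (1, 9, 10, 41, 1000):
    # ceil-division identity: (n + step - 1) // step
    assert sliced_len(n) == (n + 8) // 9

print([sliced_len(n) for n in (1, 9, 10, 41, 1000)])  # [1, 1, 2, 5, 112]
```

Because the result is a function of n rather than a constant, there is no single static output shape torch.export can commit to for the whole range.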

Workaround: Using Dim.DYNAMIC

Fortunately, there's a simple workaround for this issue. Instead of declaring the dynamic shape with Dim("t", min=..., max=...), you can use Dim.DYNAMIC. This tells torch.export that the dimension is dynamic without pinning it to an explicit minimum or maximum, and that flexibility sidesteps the constraint violation.

Updated Code Example

Here's how you can modify the code to use Dim.DYNAMIC:

import torch
import torch.nn as nn
from torch.export import Dim, export, save

class CustomModel(nn.Module):
    def forward(self, input1):
        # Step-9 slice along dim 0, now declared with Dim.DYNAMIC
        return input1[::9, :]

model = CustomModel()
model.eval()
input1 = torch.randn(41, 6)

exported_model = export(
    model,
    (input1,),
    dynamic_shapes={"input1": {0: Dim.DYNAMIC}},
)

exportedmodel_filename = "exported_model_reduced.pt2"
save(exported_model, exportedmodel_filename)

By using Dim.DYNAMIC, you're essentially telling torch.export to be more flexible and not enforce strict size constraints during the export process. This allows the export to proceed without the ConstraintViolationError.

Analyzing the Error Message

Let's take a closer look at the error message to understand what's going on behind the scenes:

Traceback (most recent call last):
 File "/python3.11/site-packages/torch/export/_trace.py", line 1798, in _export_to_aten_ir_make_fx
 produce_guards_callback(gm)
 File "/python3.11/site-packages/torch/export/_trace.py", line 1944, in _produce_guards_callback
 return produce_guards_and_solve_constraints(
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/python3.11/site-packages/torch/_export/non_strict_utils.py", line 549, in produce_guards_and_solve_constraints
 raise constraint_violation_error
 File "/python3.11/site-packages/torch/_export/non_strict_utils.py", line 514, in produce_guards_and_solve_constraints
 shape_env.produce_guards(
 File "/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5200, in produce_guards
 return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5932, in produce_guards_verbose
 raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (t1)! For more information, run with TORCH_LOGS="+dynamic".
 - Not all values of t1 = L['input1'].size()[0] in the specified range t1 <= 1000 satisfy the generated guard ((8 + L['input1'].size()[0]) // 9) != 1.

The error above occurred when calling torch.export.export.

Key Parts of the Error

  • ConstraintViolationError: Constraints violated (t1)!: This tells you that the constraint associated with the dynamic dimension t1 is being violated.
  • Not all values of t1 = L['input1'].size()[0] in the specified range t1 <= 1000 satisfy the generated guard ((8 + L['input1'].size()[0]) // 9) != 1.: This is the most important part. It says that for the slicing operation input1[::9, :], the tool couldn't find a condition that holds true for all possible sizes of input1 along dimension 0 (up to 1000).
  • ((8 + L['input1'].size()[0]) // 9) != 1: This is the generated guard condition. It's a mathematical expression that torch.export uses to check if the slicing operation is valid. The // operator represents integer division.

Basically, PyTorch is doing its best to make sure that the slicing operation won't cause any problems when you actually run the exported model with different input sizes. But in this case, it's being overly cautious and throwing an error even though the slicing might be perfectly safe in practice.
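You can evaluate that guard with plain Python arithmetic to see exactly which sizes break it. For every size from 1 through 9, (8 + n) // 9 equals 1, so all of those sizes violate the guard — which is why a range that starts at min=1 fails:

```python
# Evaluate the generated guard ((8 + n) // 9) != 1 over the declared
# range 1 <= n <= 1000 and collect the sizes that violate it.
violations = [n for n in range(1, 1001) if (8 + n) // 9 == 1]
print(violations)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```

For those nine sizes the step-9 slice collapses to a single row, unlike the five rows seen at trace time with the size-41 example input, so no single guard covers the whole range.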

When the Issue Disappears

It's worth noting that this issue seems to disappear under certain conditions:

  • Slicing a Static Dimension: If you're slicing along a dimension that's not dynamic, the error doesn't occur. For example, if you have a static dimension of size 41 and you slice it with a step, torch.export can easily determine the resulting size.
  • Dynamic Dimension with Dim.DYNAMIC: As we discussed earlier, using Dim.DYNAMIC instead of Dim("t", min=..., max=...) avoids the constraint violation.
  • Slicing with Step = 1: If you slice the dynamic dimension with a step size of 1 (e.g., input1[::1, :]), the issue doesn't arise. This is because the resulting size is simply the original size of the dynamic dimension.

Slicing at a static dimension

This issue also disappears if you export the same slice (e.g. input1[::9, :]) but make dim 0 static and dim 1 dynamic, so the step applies to a static dimension. For example:

import torch
import torch.nn as nn
from torch.export import Dim, export, save

class CustomModel(nn.Module):
    def forward(self, input1):
        # Step-9 slice along dim 0, which stays static here
        return input1[::9, :]

model = CustomModel()
model.eval()
input1 = torch.randn(41, 6)

exported_model = export(
    model,
    (input1,),
    dynamic_shapes={"input1": {1: Dim("t1", min=1, max=1000)}},
)

exportedmodel_filename = "exported_model_reduced.pt2"
save(exported_model, exportedmodel_filename)

In this case, the issue disappears because the step-9 slice runs along dim 0, which is static; only dim 1 is dynamic, and its size passes through the slice unchanged.

Conclusion

So, there you have it! A deep dive into a quirky torch.export bug that surfaces when you slice dynamic dimensions with a step size greater than 1 and declare the range with Dim("t", min=..., max=...). Remember, the workaround is to use Dim.DYNAMIC instead. While this makes the export process less strict about the allowed range, it lets you sidestep the ConstraintViolationError and successfully export your model. I hope this helps you navigate the wild world of dynamic shapes in PyTorch!

For more information, you can check out the official torch.export documentation, which covers dynamic shapes and model exporting: https://pytorch.org/docs/stable/export.html
