torch.export Error: Dynamic Slice With Dim And Step > 1
Introduction
Hey guys! Today, we're diving deep into a tricky bug encountered while using torch.export in PyTorch. Specifically, this issue arises when you're trying to export a model that slices an input tensor with dynamic shapes along a dynamic dimension, and you're using Dim("t", min=, max=) to define the dynamic shape. Let's break down the problem and see how we can work around it.
The Bug: Slicing Dynamic Dimensions with Steps
The core of the problem lies in how torch.export handles slicing operations on tensors with dynamic shapes, especially when a step size greater than 1 is involved. When you define a dynamic dimension using Dim("t", min=, max=), and then attempt to slice that dimension with a step (e.g., input1[::9, :]), torch.export might throw a ConstraintViolationError. This error essentially means that the tool can't guarantee that all possible values within the specified dynamic range will satisfy the constraints imposed by the slicing operation.
Code Example
To illustrate this, consider the following code snippet:
```python
import torch
import torch.nn as nn
from torch.export import export, save, Dim

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input1):
        # Keep every 9th row of dim 0, which is declared dynamic below.
        out1 = input1[::9, :]
        return out1

model = CustomModel()
model.eval()
input1 = torch.randn(41, 6)
exported_model = export(
    model,
    (input1,),
    dynamic_shapes={"input1": {0: Dim("t1", min=1, max=1000)}},
)
exportedmodel_filename = "exported_model_reduced.pt2"
save(exported_model, exportedmodel_filename)  # not reached: export() raises first
```
When you run this code, you'll likely encounter the dreaded ConstraintViolationError. The error message will tell you that not all values of t1 (the dynamic dimension) within the range 1 <= t1 <= 1000 satisfy the generated guard condition ((8 + L['input1'].size()[0]) // 9) != 1.
Why Does This Happen?
The reason for this error is that torch.export records guard conditions (assumptions about the symbolic sizes that it makes while tracing) and then checks that every value in the range you declared satisfies them. With a step size greater than 1, the length of the sliced dimension depends on the exact value of the dynamic dimension: input1[::9, :] keeps ceil(t1 / 9) = (t1 + 8) // 9 rows. Here, tracing produced a guard saying that this length is not 1, and that guard doesn't hold for every t1 between 1 and 1000, so export reports a violation instead of exporting a program that might be wrong for some of those sizes. In simpler terms, PyTorch is checking whether the graph it traced for input1[::9, :] stays valid no matter which input size in your range you pass.
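To see why the output shape depends on t1, here is a small pure-Python sketch (no PyTorch needed; `sliced_length` is a hypothetical helper, not a torch API). It checks the closed form (n + step - 1) // step against Python's own slicing rules; for step 9 this is exactly the (8 + size) // 9 term that appears in the guard.

```python
def sliced_length(n: int, step: int) -> int:
    """Length of x[::step] along a dimension of size n (n >= 0, step >= 1)."""
    return (n + step - 1) // step


# Check the closed form against Python's built-in slicing semantics.
for step in (1, 2, 9):
    for n in range(200):
        assert sliced_length(n, step) == len(range(n)[::step])

# With step 9 the sliced length changes with n:
# 5 -> 1, 9 -> 1, 10 -> 2, 41 -> 5, 1000 -> 112
for n in (5, 9, 10, 41, 1000):
    print(n, "->", sliced_length(n, 9))
```

With step 1 the formula reduces to n itself, which matches the observation later in this post that step-1 slices don't trigger the problem.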
Workaround: Using Dim.DYNAMIC
Fortunately, there's a simple workaround for this issue. Instead of using Dim("t", min=, max=) to define the dynamic shape, you can use Dim.DYNAMIC. This marks the dimension as dynamic without you declaring a minimum or maximum, leaving torch.export to infer the valid range from the traced program. In this case, that sidesteps the constraint violation.
Updated Code Example
Here's how you can modify the code to use Dim.DYNAMIC:
```python
import torch
import torch.nn as nn
from torch.export import export, save, Dim

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input1):
        # Same step-9 slice along the dynamic dim 0.
        out1 = input1[::9, :]
        return out1

model = CustomModel()
model.eval()
input1 = torch.randn(41, 6)
exported_model = export(
    model,
    (input1,),
    # Dim.DYNAMIC: dim 0 is dynamic, no user-declared min/max.
    dynamic_shapes={"input1": {0: Dim.DYNAMIC}},
)
exportedmodel_filename = "exported_model_reduced.pt2"
save(exported_model, exportedmodel_filename)
```
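By using Dim.DYNAMIC, you're essentially telling torch.export to work out the valid range itself instead of validating the program against bounds you declared. With no user-specified range for the guard to contradict, the export proceeds without the ConstraintViolationError. Keep in mind that the guard may still limit which input sizes the exported program accepts.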
Analyzing the Error Message
Let's take a closer look at the error message to understand what's going on behind the scenes:
```text
Traceback (most recent call last):
  File "/python3.11/site-packages/torch/export/_trace.py", line 1798, in _export_to_aten_ir_make_fx
    produce_guards_callback(gm)
  File "/python3.11/site-packages/torch/export/_trace.py", line 1944, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python3.11/site-packages/torch/_export/non_strict_utils.py", line 549, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/python3.11/site-packages/torch/_export/non_strict_utils.py", line 514, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5200, in produce_guards
    return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5932, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (t1)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of t1 = L['input1'].size()[0] in the specified range t1 <= 1000 satisfy the generated guard ((8 + L['input1'].size()[0]) // 9) != 1.

The error above occurred when calling torch.export.export.
```
Key Parts of the Error
- ConstraintViolationError: Constraints violated (t1)!: This tells you that the constraint associated with the dynamic dimension t1 is being violated.
- Not all values of t1 = L['input1'].size()[0] in the specified range t1 <= 1000 satisfy the generated guard ((8 + L['input1'].size()[0]) // 9) != 1.: This is the most important part. It says that a guard generated while tracing input1[::9, :] doesn't hold for every size of input1 along dimension 0 that you allowed (up to 1000).
- ((8 + L['input1'].size()[0]) // 9) != 1: This is the generated guard itself. The // operator is integer division, so (8 + size) // 9 is the number of rows input1[::9, :] returns, and the guard requires that number not to be 1.
Basically, PyTorch is making sure that the graph it traced stays valid for every input size you declared. Here the guard is false exactly when the slice returns a single row, which happens for 1 <= t1 <= 9, and those sizes lie inside the declared range (min=1, max=1000). So the slice itself isn't unsafe; export just can't promise that the traced program is correct for those small sizes.
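The guard is simple enough to check by brute force. This pure-Python sketch (`guard_holds` is a hypothetical helper that just restates the guard from the error message) lists the sizes in the declared range 1..1000 that fail it, plus the smallest lower bound for which every size passes.

```python
def guard_holds(t1: int) -> bool:
    # Restates the guard from the error message: ((8 + t1) // 9) != 1
    return ((8 + t1) // 9) != 1


lo, hi = 1, 1000  # the range declared via Dim("t1", min=1, max=1000)
violations = [t for t in range(lo, hi + 1) if not guard_holds(t)]
print(violations)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]

# Smallest lower bound such that every size from it up to `hi` passes the guard.
safe_min = next(
    m for m in range(lo, hi + 1)
    if all(guard_holds(t) for t in range(m, hi + 1))
)
print(safe_min)  # 10
```

Read against the error, this suggests the guard only rules out inputs with 1 to 9 rows, i.e. the sizes for which input1[::9, :] collapses to a single row. Declaring Dim("t1", min=10, max=1000) might therefore also get past the check while keeping explicit bounds; that's an inference from the message, not something tested in the original report.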
When the Issue Disappears
It's worth noting that this issue seems to disappear under certain conditions:
- Slicing a Static Dimension: If you're slicing along a dimension that's not dynamic, the error doesn't occur. For example, if you have a static dimension of size 41 and you slice it with a step, torch.export can easily determine the resulting size.
- Dynamic Dimension with Dim.DYNAMIC: As we discussed earlier, using Dim.DYNAMIC instead of Dim("t", min=, max=) avoids the constraint violation.
- Slicing with Step = 1: If you slice the dynamic dimension with a step size of 1 (e.g., input1[::1, :]), the issue doesn't arise. This is because the resulting size is simply the original size of the dynamic dimension.
Slicing at a Static Dimension
The issue also disappears if you keep the same slice, input1[::9, :], on dim 0 but make dim 1 the dynamic dimension, so that the sliced dimension is static. For example:
```python
import torch
import torch.nn as nn
from torch.export import export, save, Dim

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input1):
        # Step-9 slice on dim 0, which is static (size 41); dim 1 is dynamic.
        out1 = input1[::9, :]
        return out1

model = CustomModel()
model.eval()
input1 = torch.randn(41, 6)
exported_model = export(
    model,
    (input1,),
    dynamic_shapes={"input1": {1: Dim("t1", min=1, max=1000)}},
)
exportedmodel_filename = "exported_model_reduced.pt2"
save(exported_model, exportedmodel_filename)
```
In this case, the issue disappears because the step-9 slice is applied to dim 0, which is static: it always has 41 rows, so the result always has 5 rows, and the sliced length no longer depends on the dynamic dimension (dim 1).
Conclusion
So, there you have it! A deep dive into a quirky torch.export bug that surfaces when you slice dynamic dimensions with a step size greater than 1 and use Dim("t", min=, max=). Remember, the workaround is to use Dim.DYNAMIC instead. While this means you no longer declare the dimension's bounds yourself, it allows you to sidestep the ConstraintViolationError and successfully export your model. I hope this helps you navigate the wild world of dynamic shapes in PyTorch!
For more information, you can check out the official PyTorch documentation on torch.export, which covers dynamic shapes and model exporting: https://pytorch.org/docs/stable/export.html