How PyTorch handles dynamic tensor shapes

Written by Nuno Lopes

[Figure: PyTorch graph 1]

AI models often process data that changes in size from one request to the next (a user prompt, a batch of images, or a video clip). This variability creates dynamic tensor shapes, a major challenge for performance optimization. In this post, we’ll explore how the PyTorch compiler tackles this challenge, how you can harness its power, and pitfalls to avoid.

What are dynamic tensor shapes?

All tensors have an associated shape that specifies their dimensions. For example, a matrix with m rows and n columns is a tensor with shape (m, n) and rank 2.

In PyTorch, tensors have additional properties, such as data type (e.g., INT8, FP32) and memory layout (given by the strides and contiguity properties). Together, these properties allow PyTorch to make sense of the bytes stored in memory.
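
For instance, here is a minimal sketch showing these properties on an ordinary tensor:

import torch

x = torch.randn(3, 4, dtype=torch.float32)
print(x.shape)            # torch.Size([3, 4]): a rank-2 tensor with shape (m, n) = (3, 4)
print(x.dtype)            # torch.float32
print(x.stride())         # (4, 1): moving one row skips 4 elements, one column skips 1
print(x.is_contiguous())  # True: the elements are laid out densely in memory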

In AI models, weights usually have a fixed size: they are matrices of parameters learned during training, and the number of parameters typically does not change (unless you are doing architecture search). This is what we call static shapes: the dimensions of the weight tensors are fixed.

By contrast, dynamic shapes refer to tensors whose dimensions can change during execution. For example, the batch size can change between calls (you process 64 images now, but only 32 in the next batch). For language models, the context length can vary significantly between prompts, hence it is dynamic, while the vocabulary size is fixed.
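
For a concrete (hypothetical) illustration, consider a linear layer: its weight shape is static, while the batch dimension of its input changes from call to call:

import torch

layer = torch.nn.Linear(128, 64)     # weight shape (64, 128) never changes: static

out_a = layer(torch.randn(64, 128))  # batch of 64 -> output shape (64, 64)
out_b = layer(torch.randn(32, 128))  # batch of 32 -> output shape (32, 64)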

Why do we need compiler support for dynamic shapes?

In theory, we could survive without compiler support for dynamic shapes by simply recompiling code for every new shape. However, this would be extremely inefficient. Proper compiler support is what makes dynamic shapes practical and performant, and we should know how to take advantage of it!

One of the main features of PyTorch 2 is the new compiler (torch.compile), designed to improve performance without sacrificing the flexibility and ease of use that characterize PyTorch. If you have been doing machine learning in Python for a while, you may have had flashbacks to TensorFlow’s static graphs when torch.compile was first announced, but fear not: PyTorch is still as usable as before. To understand where dynamic shapes fit into this puzzle, let’s start by looking at how PyTorch executes programs.

[Figures: PyTorch graph 2a and 2b]

PyTorch's default execution model is eager mode, where the program is executed one operation at a time. See the following example:

x = torch.randn(5)
y = x + x
z = y * 5
print(z)

The program’s execution essentially goes line by line, executing one operation at a time. When executing this code on the CPU, that might be fine. But if you are running the code on an AI accelerator (such as a GPU or Furiosa’s RNGD), the overhead of going back and forth between the CPU and the accelerator wastes time and energy.

PyTorch implements an optimization to hide this overhead: most operations execute asynchronously and thus return immediately. Only some operations, such as I/O (like the print above), block execution as they need to wait until their operands are computed. This technique allows the Python code to be executed on the CPU, while the accelerator works in parallel. This optimization was probably what made PyTorch successful: the eager mode was easier to use than other frameworks at the time, and the async execution model delivered decent performance.
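
Here is a minimal sketch of this behavior on a CUDA device (the same idea applies to other accelerators):

import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")

    y = x @ x       # the kernel is queued on the GPU; the call returns immediately
    z = y.relu()    # also queued; the CPU keeps running ahead of the accelerator

    print(z[0, 0])  # printing needs the actual value, so it waits for the queued work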

The async execution model is great at hiding overhead, but the overhead is still there: energy is wasted and performance is left on the table. Given that AI deployments now use millions of devices, even a small percentage of speed-up covers the salary of those implementing it!

This is where PyTorch 2’s compiler comes in. The goal was to execute PyTorch programs mostly unmodified while improving performance. The key for executing AI models efficiently is to have the whole model (or big chunks of it) in a compiler-friendly representation (like a dataflow graph). PyTorch’s compiler does a sort of symbolic execution of the code to obtain the dataflow graph(s) of the model. It starts by reading the Python bytecode of the initial function and then follows the called functions until it gets the whole model or until it can’t trace the code anymore (more on this later).

Here is a simple Python function:

def simple_fn(x):
    return x * 2 + 3

And here is the corresponding (stack-based) Python bytecode:

LOAD_FAST x     ; push x onto the stack
LOAD_CONST 2    ; push 2 onto the stack
BINARY_OP 5     ; pop the last two values on the stack, multiply them,
                ; and push the result: x * 2
LOAD_CONST 3    ; push 3 onto the stack
BINARY_OP 0     ; pop the last two values, add them,
                ; and push the result: (x * 2) + 3
RETURN_VALUE    ; return the last value on the stack

The bytecode is fairly straightforward, but nowhere does it mention PyTorch. Variable x may be a PyTorch tensor, but it can also be a mere integer or even a string. To figure out the types of the variables, the symbolic execution engine runs the code using “fake tensors,” which are tensors that do not contain any data. These tensors are used by the compiler to execute the code quickly and without allocating memory, while recording the operations and the tensor shapes involved.
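
You can experiment with fake tensors directly. FakeTensorMode lives under torch._subclasses and is an internal API, so treat this as an illustrative sketch rather than a stable interface:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(8)       # a fake tensor: shape/dtype metadata, no real storage
    y = x * 2 + 3            # executes instantly; only the metadata is propagated
    print(y.shape, y.dtype)  # torch.Size([8]) torch.float32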

You can trigger compilation as follows:

compiled_fn = torch.compile(simple_fn)

# First execution: triggers compilation
y1 = compiled_fn(torch.randn(8))

# Second execution: compiled code is cached
y2 = compiled_fn(torch.randn(8))

Since compilation is only triggered on the first execution (and not when torch.compile is called), the compiler knows which arguments are tensors and which are not. The ones that are tensors are replaced with fake tensors, and then the code is executed, producing an FX graph (a compiler-friendly representation):

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[8]"):
        l_x_ = L_x_
        mul: "f32[8]" = l_x_ * 2;  l_x_ = None
        add: "f32[8]" = mul + 3;  mul = None
        return (add,)

The FX graph is then optimized and may be compiled into a single CUDA kernel if you are using an Nvidia GPU, for example.
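
If you want to look at the captured graph yourself, one option (a sketch using the custom-backend hook) is to pass a backend that prints the GraphModule and then runs it unmodified:

import torch

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.print_readable()  # prints the FX graph, similar to the listing above
    return gm.forward    # run the captured graph as-is, with no further compilation

compiled_fn = torch.compile(simple_fn, backend=print_graph_backend)
compiled_fn(torch.randn(8))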

What’s the issue? Notice that the graph recorded the specific shape ([8]) and datatype (f32) of the tensors. What happens if we now call the compiled function with a tensor with a different shape? Maybe the compiler made some optimization that relied on the dimension being a power of two? Or maybe it did not, but so far, we have no way to tell.

(Bonus question: can you figure out why the FX graph above has all those assignments of None? Tip: Remember that Python is a reference-counted language.*)

Avoiding recompilation: dynamic shapes to the rescue

The naive way to handle tensors with different shapes is to recompile the code for each combination. The problem is that techniques such as continuous batching, paged attention, etc., as well as very different prompt sizes from users can easily produce thousands of different shape combinations. Compiling a program for each shape is usually not an option. Not only does compilation take a long time (that’s a topic for another post), but compiled programs take up memory (tens of megabytes, at least). The more RAM you spend on program binaries, the less space you have left for weights, KV-cache, etc.

PyTorch 2 takes a pragmatic approach: the first time it compiles a model, it assumes all shapes are static. If the model gets called later with at least one input with a different shape, the code is recompiled, but now those tensors are marked as having a dynamic shape. The remaining tensors remain with a static shape.
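
If you already know that a shape will vary, you can also skip the initial static compilation and ask for dynamic shapes upfront via the dynamic flag of torch.compile (a sketch; the compiler may still specialize some dimensions):

import torch

# Compile assuming dynamic shapes from the start, instead of specializing
# on the first call and recompiling when a new shape shows up.
compiled_fn = torch.compile(simple_fn, dynamic=True)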

[Figures: PyTorch graph 3a, 3b, and 3c]

Let’s go back to our example, and let’s call the compiled function twice with different shapes:

# First execution: triggers compilation with static shapes
y1 = compiled_fn(torch.randn(3))

# Second execution: triggers compilation with x having dynamic shape
y2 = compiled_fn(torch.randn(5))

When a model is compiled, tensors with dynamic shapes track the shape constraints symbolically instead of tracking the concrete shape. For example, if we add two tensors (x + y), the shapes of x and y must be the same, or one of them must be broadcastable to match the other’s shape. This condition is sufficient to describe the precondition for executing a tensor addition; we didn’t need to mention their concrete shapes.
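
For example, addition only requires the shapes to be equal or broadcastable, regardless of the concrete sizes:

import torch

x = torch.randn(4, 8)
y = torch.randn(1, 8)  # the size-1 dimension is broadcast to match x
z = x + y              # ok: the result has shape (4, 8)

w = torch.randn(3, 8)
# x + w  # error: shapes (4, 8) and (3, 8) are neither equal nor broadcastable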

PyTorch’s compiler uses “fake tensors” to execute the model’s code and record these symbolic constraints on dynamic shapes. In the end, we get a gigantic logic formula that represents the precondition that needs to be met to execute the code. PyTorch uses SymPy to simplify this formula and then generates a C++ function that tests whether the condition holds. Producing the minimal formula is undecidable in general, so SymPy does the best it can.

For our example, the generated C++ code for checking the guard (PyTorch’s name for the precondition) is:

int8_t guard(int64_t *int_values, double *float_values) {
    int64_t L_x_size_0_ = int_values[0];
    return (2L <= L_x_size_0_);
}

So the only condition we have is that the first dimension of x must be greater than or equal to 2. This is because most arithmetic operations in PyTorch behave differently when a dimension is 1: they have to perform a broadcast to match the other operand’s shape.

Note that this example is very simple and we ran it on the CPU, hence the compiler did not specialize the code for particular shapes.

Let’s look at a more interesting example. Assume we are performing a matrix multiplication between a tensor x of shape (m, k) and a tensor y of shape (k, n):

def simple_fn(x, y):
    return torch.matmul(x, y)

If we call this function twice with tensors that have different dimensions, the compiler generates the following guard that enforces the shape rules of matrix multiplication:

int8_t guard(int64_t *int_values, double *float_values) {
    int64_t L_y_size_0_ = int_values[0], L_x_size_0_ = int_values[1];
    return (L_y_size_0_ == L_x_size_0_) &&
           ((2L <= L_x_size_0_) & (L_x_size_0_ <= 4194304L));
}

As expected, matmul requires that the contracting dimension of the two tensors be equal (for matrices, the number of columns in x must equal the number of rows in y). The compiler also decided to limit the code to cases where the contracting dimension is between 2 and 2^22 (4,194,304). Note that if you call the function with the same tensor in both arguments, the compiler specializes the code for x and y being equal.

Compiling with dynamic shapes

Compiling programs with dynamic shapes is still in the research domain. It’s not a trivial task, since there’s usually an optimal compilation for each shape. Surely some shapes can be bucketized, but the question remains: what’s the strategy to bucketize shapes automatically?

In the past, it was common to pad tensors so they would fit a predetermined bucket size or boundary. For example, we may have a kernel that assumes that every dimension is a multiple of 8 as it performs operations over vectors of that size. Padding may be okay if done in small amounts, and even better if the chip has a fast path for operations with zero.
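
Here is a minimal sketch of this kind of padding (the multiple-of-8 requirement is just an assumption for illustration):

import torch
import torch.nn.functional as F

def pad_to_multiple_of_8(x: torch.Tensor) -> torch.Tensor:
    n = x.shape[-1]
    pad = (-n) % 8                           # 0 to 7 zeros appended to the last dimension
    return F.pad(x, (0, pad)) if pad else x

x = torch.randn(4, 13)
print(pad_to_multiple_of_8(x).shape)         # torch.Size([4, 16])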

Some algorithms can handle multiple shapes without padding. They can have a fast vectorized loop to handle most of the tensor, and then a slow loop to handle the leftover bits (running at most 7 iterations in our example of tensors needing to be aligned at a multiple of 8).

The advantage of these techniques is that they widen the range of shapes each compiled program can handle. Although they increase the code size of every kernel, we may end up with less code overall because a single program may be sufficient.

In the world of traditional compilers (for, say, C++), a lot of projects got good speed-ups when using profile-guided optimizations. It probably makes sense to do the same with shapes: we can measure the common shapes and make sure we optimize for those cases. The rare cases can even be run in eager mode; it may not be worth compiling them.
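
A sketch of what such shape profiling could look like (the serving loop and the threshold are hypothetical):

from collections import Counter

shape_counts = Counter()

def serve(model, x):
    shape_counts[tuple(x.shape)] += 1  # record the shape of every incoming request
    return model(x)

# After observing real traffic, compile only the hot shapes and keep the
# long tail in eager mode:
# hot_shapes = [shape for shape, count in shape_counts.most_common(5)]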

Advanced tricks

This post was a brief introduction to how dynamic shapes work in PyTorch 2. Here are a few pointers if you want to explore the topic further.

Perhaps the first thing to try is to enable all debug logging of the compiler with:

import logging
torch._logging.set_logs(all=logging.DEBUG)

This will make PyTorch produce a lot of information. It shows you how many times the model has been compiled, the captured FX graph, the guard of each compiled program, etc. It’s a good place to start learning about the internals of the compiler.

This debug output is also a great place to look for clues about things that can be optimized. You can see how complex the guard is and the evaluation latency (if this evaluation becomes a bottleneck, and you can guarantee that the input shapes will always satisfy the conditions, you can disable the guard check for maximum performance). Looking at the guard will also give you an indication if the compiler is leveraging the dimensions to produce optimized code or not. You can also explore using torch._dynamo.mark_dynamic to tell the compiler upfront that a tensor has dynamic shape and even give boundaries for the tensor dimensions. You can also inspect the kernels that the compiler produces and see if there’s room left for improvement.
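
For instance, here is a sketch of marking an input dimension as dynamic before the first call (the bounded variant exists in recent PyTorch versions; check the documentation for yours):

import torch

x = torch.randn(64, 1024)
torch._dynamo.mark_dynamic(x, 0)  # dimension 0 (the batch) may vary between calls

# Recent versions also accept bounds on the dynamic range, e.g.:
# torch._dynamo.mark_dynamic(x, 1, min=2, max=8192)

def fn(x):  # a hypothetical model
    return x * 2 + 3

compiled_fn = torch.compile(fn)
y = compiled_fn(x)  # compiled with dimension 0 treated as dynamic from the start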

Another important thing is to keep an eye on graph breaks. Sometimes PyTorch can’t capture the whole model into a single dataflow graph. This happens if you have some operation that requires materializing an intermediate value. For example, printing a value forces a graph break: you’ll have one graph until the print, and a second graph for anything after the print. Obviously, you want to avoid graph breaks whenever possible. You can use torch.compile(model, fullgraph=True) to ensure that you will always have a single graph (otherwise, the compiler aborts).
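
Here is a minimal sketch of a graph break, with fullgraph=True turning it into a hard error instead of a silent split:

import torch

@torch.compile(fullgraph=True)
def fn(x):
    y = x * 2
    print(y)      # needs the concrete value: this forces a graph break
    return y + 1

# fn(torch.randn(4))  # raises an error because fullgraph=True forbids graph breaks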

Finally, PyTorch has multiple compiler backends. They have different tradeoffs in terms of compilation time, memory usage, run-time performance, etc. You should try all the backends that apply to your chip and learn their characteristics and quirks.
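
A quick way to see which backends are registered in your installation and to compare a couple of them (the exact list depends on your setup):

import torch

print(torch._dynamo.list_backends())  # e.g. ['cudagraphs', 'inductor', 'onnxrt', ...]

fast_fn  = torch.compile(simple_fn, backend="inductor")  # the default backend
debug_fn = torch.compile(simple_fn, backend="eager")     # captures graphs but runs them eagerly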

If you want to learn more, read the official documentation on torch.compile and Dynamic Shapes.

Looking forward

The compiler community is still learning how best to compile AI models. Although AI models are very small programs (compared to, say, web browsers with tens of millions of lines of code!), squeezing all the performance out of hardware accelerators is not easy. Some accelerators make the compiler’s job easier than others (history is full of chips that looked great on paper but for which no one could write a compiler; none of them survived!), but it is never a trivial task.

We expect there will soon be further research on compiling programs with dynamic shapes. We need a feedback loop between the compiler and the execution engine. The compiler must be able to state under which conditions the code is valid (generate the weakest precondition), as well as the conditions where the code has excellent performance (likely tighter than the precondition) and even to specify the performance gap for shapes outside the optimal range. The execution engine can then decide when to recompile and when to reuse existing code based on run-time statistics.

Besides compilation, we believe that large inference deployments can sidestep the issue a bit by routing user prompts with similar sequence lengths to the same device, hence having devices specialized for particular intervals of sequence lengths. The orchestration system will decide in real time which specialization each device should be running based on the incoming and predicted traffic.

[Figure: PyTorch graph 4]

*The Torch FX representation is a Python program that can be executed as-is (i.e., it is interpreted by Python). In Python, a variable is deallocated only when there are no more references to it. Even if a variable is never accessed again, having a reference to it precludes Python from deleting it. Since tensors can be big objects (sometimes taking GBs each), deallocating them ASAP is very important, not only to prevent running out of memory, but also for performance reasons (such as reusing memory blocks already in the cache). PyTorch computes the liveness ranges of tensors and adds those assignments of None when the liveness of a variable ends, so it is deleted by Python ASAP. Having the results of liveness analysis also enables further optimizations, like performing operations in-place (reusing the memory of one of the operands instead of allocating memory for the result).
