Support and Behavior of JIT Compilation for Dynamic Shapes in XLA #32619

@diligentliu

Description

Hi, I have a question regarding dynamic shape support in TensorFlow + XLA.

When I use tf.function(jit_compile=True) with TensorSpecs that have partially unknown dimensions (e.g., [None, 4096]), XLA appears to generate a separate compiled kernel at runtime for each concrete input shape it encounters.
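
For concreteness, here is a minimal sketch of the behavior I am describing (dense_relu is just a placeholder op, and I am assuming experimental_get_compiler_ir is a reasonable way to inspect the generated HLO):

```python
import tensorflow as tf

@tf.function(
    input_signature=[tf.TensorSpec(shape=[None, 4096], dtype=tf.float32)],
    jit_compile=True,
)
def dense_relu(x):
    # Placeholder body; any op suffices to observe compilation behavior.
    return tf.nn.relu(tf.matmul(x, x, transpose_b=True))

# Although the traced signature is [None, 4096], each distinct batch size
# observed at runtime appears to trigger a fresh XLA compilation.
for batch in (1, 8, 32):
    _ = dense_relu(tf.random.normal([batch, 4096]))

# Inspect the HLO that XLA generated for one concrete input shape.
print(dense_relu.experimental_get_compiler_ir(
    tf.random.normal([8, 4096]))(stage="hlo"))
```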

From my understanding, XLA does not precompile a fully general kernel for dynamic shapes. Instead, it relies on just-in-time (JIT) compilation to produce a concrete CUDA kernel once the actual input shapes are known. This appears to differ from frameworks like TVM, where the graph can be optimized statically against known shapes before execution, potentially enabling more aggressive operator fusion and memory planning.

Could you please confirm if my understanding is correct? Specifically:

  1. Is dynamic shape support in XLA primarily implemented via this runtime JIT compilation approach?

  2. Is there any ongoing effort or recommended approach to allow more static-graph-like optimizations for partially dynamic shapes, similar to how TVM can optimize a graph based on partially known shapes before execution?

I am particularly interested in understanding whether XLA has mechanisms to reduce runtime compilation overhead for workloads with dynamic batch sizes or other partially unknown dimensions.
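
In case it helps clarify the question: the workaround I am currently considering is Python-level bucketing, i.e., padding the batch dimension up to a small fixed set of sizes so that XLA only ever compiles one kernel per bucket. This is a rough sketch under my own assumptions (BUCKETS and bucketed_call are my names, not an XLA or TensorFlow API):

```python
import tensorflow as tf

BUCKETS = (8, 32, 128)  # hypothetical bucket sizes chosen for illustration

@tf.function(
    input_signature=[tf.TensorSpec([None, 4096], tf.float32)],
    jit_compile=True,
)
def model_step(x):
    return tf.nn.relu(x)  # placeholder for the real model

def bucketed_call(x):
    # Pick the smallest bucket that fits the real batch, pad up to it,
    # run the compiled function, then slice the padding back off. XLA
    # then sees at most len(BUCKETS) distinct shapes.
    real = int(x.shape[0])
    bucket = next(b for b in BUCKETS if b >= real)
    padded = tf.pad(x, [[0, bucket - real], [0, 0]])
    return model_step(padded)[:real]

out = bucketed_call(tf.random.normal([20, 4096]))  # compiles for bucket 32 only
```

Is this kind of manual padding the recommended pattern, or does XLA provide a more automatic mechanism for this case?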
