.. _dsl_jit_caching:

.. |DSL| replace:: CuTe DSL

.. _JIT_Caching:

JIT Caching
====================

Zero Compile and JIT Executor
-----------------------------
Zero Compile is a feature that enables explicit kernel compilation on demand through ``cute.compile``.
When ``cute.compile`` is called, it compiles the kernel and returns a JIT Executor instance.
This JIT Executor instance can be cached and reused directly for subsequent executions without recompiling the kernel.

The JIT Executor is a component that independently executes compiled code.
It can be created either through ``cute.compile`` or through implicit compilation.
A JIT Executor instance behaves like a callable object that executes the compiled code.
Each JIT Executor instance maintains a single compiled host function.
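For illustration, here is a minimal sketch of both creation paths, using a trivial ``add`` kernel as a stand-in for a real kernel:

.. code-block:: python

    import cutlass.cute as cute

    @cute.jit
    def add(a, b):
        return a + b

    # Implicit compilation: calling the decorated function compiles it on first
    # use (or reuses the cache in |DSL|) and then executes it.
    result = add(1, 2)

    # Explicit compilation: cute.compile compiles immediately and returns a
    # JIT Executor instance, which is itself callable.
    jit_executor = cute.compile(add, 1, 2)
    result = jit_executor(1, 2)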
A JIT Executor instance encompasses all necessary execution components:

* The host function pointer and its MLIR execution engine
* CUDA modules (optional)
* Argument specifications defining how Python arguments are converted to C ABI-compatible types. Note that arguments with the ``cutlass.Constexpr`` hint are excluded from the argument specifications, since they are evaluated at compile time rather than at runtime.
For example, in the following code, ``print_result`` is a ``cutlass.Constexpr`` value that is **NOT** evaluated at runtime:

.. code-block:: python

    import cutlass
    import cutlass.cute as cute

    @cute.jit
    def add(a, b, print_result: cutlass.Constexpr):
        if print_result:
            cute.printf("Result: %d\n", a + b)
        return a + b

    jit_executor = cute.compile(add, 1, 2, True)

    jit_executor(1, 2)  # output: ``Result: 3``
The JIT Executor ensures all components are properly initialized and loaded after compilation.
For example, all CUDA modules are loaded (via ``cuModuleLoad``) and kernel function pointers are extracted (via ``cuModuleGetFunction``).

When a JIT Executor instance is called, it:

* Parses the Python runtime arguments and converts them to C ABI-compatible types according to the argument specifications
* Invokes the host function with the converted arguments
Custom Caching with ``cute.compile``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``cute.compile`` bypasses the cache in |DSL| and always performs compilation, returning a freshly compiled JIT Executor instance.
This allows implementing custom caching strategies, as shown below:
.. code-block:: python

    import cutlass.cute as cute

    @cute.jit
    def add(b):
        return a + b

    # Define a custom cache
    custom_cache = {}

    # ``a`` is a global variable captured by the kernel at compile time
    a = 1
    compiled_add_1 = cute.compile(add, 2)
    custom_cache[1] = compiled_add_1
    compiled_add_1(2)   # result = 3

    a = 2
    compiled_add_2 = cute.compile(add, 2)
    custom_cache[2] = compiled_add_2
    compiled_add_2(2)   # result = 4

    # Use the custom cache
    custom_cache[1](2)  # result = 3
    custom_cache[2](2)  # result = 4
Cache in |DSL|
-----------------

By default, caching in |DSL| is implicitly enabled to avoid recompilation when kernels are called repeatedly without changes.

The cache is implemented as a map storing compiled JIT Executor instances within |DSL|.

The cache key combines hashes of:

* The MLIR bytecode of the MLIR program generated by |DSL|
* All |DSL| Python source files
* All |DSL| shared libraries
* All |DSL| environment variables
The cache value is a compiled JIT Executor instance.

On a cache hit, compilation is skipped and the cached JIT Executor instance is reused.
On a cache miss, the kernel is compiled and the new JIT Executor instance is stored in the cache.
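Conceptually, the lookup behaves like the following sketch. The names ``jit_cache`` and ``get_or_compile`` are hypothetical, and the key is deliberately simplified; they only illustrate the hit/miss logic, not the actual |DSL| internals:

.. code-block:: python

    import cutlass.cute as cute

    jit_cache = {}  # hypothetical in-memory map: cache key -> JIT Executor instance

    def get_or_compile(kernel, *args):
        # Simplified stand-in for the real key, which combines hashes of the
        # generated MLIR bytecode, the DSL Python source files, the DSL shared
        # libraries, and the DSL environment variables.
        key = (kernel, args)
        executor = jit_cache.get(key)
        if executor is None:
            # Cache miss: compile and store the new JIT Executor instance
            executor = cute.compile(kernel, *args)
            jit_cache[key] = executor
        # Cache hit: the cached instance is reused without recompilation
        return executor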
Here is an example demonstrating automatic caching of the ``add`` kernel:

.. code-block:: python

    import cutlass.cute as cute

    # Global variable
    a = 1

    @cute.jit
    def add(b):
        return a + b

    # Cache is empty at the beginning

    # First call: cache miss triggers compilation
    result = add(2)  # result = 3
    # Cache now has one instance

    # Second call: cache hit reuses the cached JIT Executor
    result = add(2)  # result = 3

    a = 2
    # Third call: cache miss due to changed IR code triggers recompilation
    result = add(2)  # result = 4
    # Cache now has two instances
The cache can be serialized to files for subsequent runs.
After serialization, the compiled MLIR bytecode is stored in files.
The cache directory is ``/tmp/{current_user}/cutlass_python_cache``.
The cache is loaded from files into memory during |DSL| initialization and saved back to files when the process exits.

The following environment variables control file caching:
.. code-block:: bash

    # Disable file caching while keeping the in-memory cache available; defaults to False.
    export CUTE_DSL_DISABLE_FILE_CACHING=True

    # Maximum number of cache files allowed; defaults to 1000.
    export CUTE_DSL_FILE_CACHING_CAPACITY=1000
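The same settings can also be applied from Python; a minimal sketch, assuming the variables must be set before |DSL| is first imported, since the file cache is loaded during |DSL| initialization as described above:

.. code-block:: python

    import os

    # Assumption: set these before CuTe DSL is first imported so that DSL
    # initialization picks them up.
    os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "True"
    os.environ["CUTE_DSL_FILE_CACHING_CAPACITY"] = "500"

    import cutlass.cute as cute  # DSL initialization sees the settings above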
Limitations
~~~~~~~~~~~~~~~~~~~~~

The intention of caching is to reduce the host launch overhead before each execution. As the example above shows,
consistency between the original Python code and the MLIR program is hard to guarantee because of dynamic factors such as global variables.
Therefore, the MLIR program **MUST** always be generated to verify that the kernel content matches what was previously built.

For optimal host launch latency, we recommend the custom caching method with ``cute.compile`` described above.
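For example, building on the custom-caching example above, the dynamic factor (the global ``a``) can itself be used as the cache key. The ``run_add`` helper and the key scheme below are illustrative choices, not part of the |DSL| API:

.. code-block:: python

    import cutlass.cute as cute

    @cute.jit
    def add(b):
        return a + b

    custom_cache = {}  # user-managed cache: value of ``a`` -> JIT Executor instance

    def run_add(a_value, b):
        global a
        a = a_value
        executor = custom_cache.get(a_value)
        if executor is None:
            # First time this value of ``a`` is seen: compile and cache the executor
            executor = cute.compile(add, b)
            custom_cache[a_value] = executor
        # Subsequent calls skip compilation and MLIR generation entirely
        return executor(b)

    run_add(1, 2)  # compiles, result = 3
    run_add(2, 2)  # compiles, result = 4
    run_add(1, 2)  # reuses the cached executor, result = 3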