CUTLASS 2.9 (#468)
This commit is contained in:
796
tools/library/scripts/rt.py
Normal file
796
tools/library/scripts/rt.py
Normal file
@ -0,0 +1,796 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
# System imports
|
||||
import struct
|
||||
import io
|
||||
import ctypes
|
||||
|
||||
# CUDA Python import
|
||||
from cuda import cuda
|
||||
from cuda import nvrtc
|
||||
|
||||
# CUTLASS imports
|
||||
from library import *
|
||||
from gemm_operation import EmitGemmUniversalInstance
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# CUTLASS Py Runtime Components
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
def MaxAlignment(fmt):
    """Largest element size (bytes) in a struct format string; at least 1."""
    return max((struct.calcsize(ch) for ch in fmt), default=1)
|
||||
|
||||
#
|
||||
def AlignedOffset(offset, align):
    """Round *offset* up to the next multiple of *align*."""
    leftover = offset % align
    return offset + (align - leftover) if leftover else offset
|
||||
|
||||
#
|
||||
def PackInteger(host_workspace, offset, value):
    """
    Pack a 32-bit integer into *host_workspace* at the next 4-byte-aligned
    offset and return the offset just past the packed value.

    Fix: the original wrote the value at the *unaligned* offset while
    returning an aligned cursor, so a misaligned call left the value at a
    different position than the returned offset implied. This now aligns
    first and then packs, matching PackDevicePointer's ordering.
    """
    fmt = "i"
    # Round up to 4-byte alignment before packing.
    remainder = offset % 4
    if remainder:
        offset += 4 - remainder
    struct.pack_into(fmt, host_workspace, offset, value)
    return offset + struct.calcsize(fmt)
|
||||
|
||||
#
|
||||
def PackDevicePointer(host_workspace, offset, value):
    """Pack a pointer-sized ("P") value at the next 8-byte-aligned offset; return the new offset."""
    fmt = "P"
    # align to 8 bytes (AlignedOffset inlined)
    misalign = offset % 8
    if misalign:
        offset += 8 - misalign
    struct.pack_into(fmt, host_workspace, offset, value)
    return offset + struct.calcsize(fmt)
|
||||
|
||||
#
|
||||
def ceil_div(a, b):
    """Ceiling of a / b via floor division against the negated divisor."""
    quotient = a // -b
    return -quotient
|
||||
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class PitchLinearCoord:
    """Coordinate in pitch-linear space: a contiguous and a strided component."""

    def __init__(self, contiguous, strided):
        # plain value holder; no derived state
        self.contiguous = contiguous
        self.strided = strided
|
||||
|
||||
#
|
||||
class GemmCoord:
    """GEMM problem-size coordinate (m, n, k), packable as three 32-bit ints."""

    def __init__(self, m = 1, n = 1, k = 1):
        self.m = m
        self.n = n
        self.k = k
        self.fmt = "iii"

    #
    def ceil_div(self, rhs):
        """Element-wise ceiling division by another GemmCoord."""
        return GemmCoord(-(self.m // -rhs.m), -(self.n // -rhs.n), -(self.k // -rhs.k))

    #
    def size(self):
        """Packed size in bytes."""
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        """Alignment of the packed layout."""
        return max((struct.calcsize(ch) for ch in self.fmt), default=1)

    #
    def pack_into(self, host_workspace, offset):
        """Pack (m, n, k) at the next 4-byte-aligned offset; return the advanced offset."""
        offset += (-offset) % 4
        struct.pack_into(self.fmt, host_workspace, offset, self.m, self.n, self.k)
        return offset + self.size()
|
||||
|
||||
#
|
||||
class TensorRef:
    """Pair of (device pointer, layout/stride) describing one tensor operand."""

    def __init__(self, pointer = None, layout = 0):
        self.pointer = pointer
        self.layout = layout

    def __str__(self):
        # pointer is expected to expose ._ptr (CUDA driver pointer wrapper)
        return "(%x, %d)" % (self.pointer._ptr, self.layout)
|
||||
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class PredicatedTileAccessIteratorDesc:
    '''
    Static description of a predicated tile access iterator: element width,
    advance rank, threadblock tile shape, and the per-thread iteration map.
    '''

    def __init__(
            self,
            element_size_bits,
            advance_rank,
            threadblock_shape,
            threadmap_iterations,
            threadmap_delta):
        # element width in bits and which rank the iterator advances along
        self.element_size_bits = element_size_bits
        self.advance_rank = advance_rank
        # tile shape plus the thread map's iteration counts and deltas
        self.threadblock_shape = threadblock_shape
        self.threadmap_iterations = threadmap_iterations
        self.threadmap_delta = threadmap_delta
|
||||
|
||||
#
|
||||
class PredicatedTileAccessIteratorParams:
    '''
    Host-side mirror of the tile access iterator Params: packs the leading
    stride and the three derived byte increments as four 64-bit values.
    '''

    #
    def __init__(self, desc, label):
        self.desc = desc       # PredicatedTileAccessIteratorDesc
        self.label = label     # operand tag, e.g. 'A' or 'B'
        self.fmt = "qqqq"

    #
    def size(self):
        """Packed size in bytes of the Params structure."""
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        """Alignment required by the packed layout."""
        return max((struct.calcsize(ch) for ch in self.fmt), default=1)

    #
    def initialize(self, host_workspace, offset, stride):
        """Compute the pointer increments for *stride* (in elements), pack them, return the new offset."""
        desc = self.desc

        # step between strided accesses within one tile
        inc_strided = stride * desc.threadmap_delta.strided * desc.element_size_bits // 8

        # step to the next tile, depending on which rank advances
        if desc.advance_rank:
            inc_advance = desc.threadblock_shape.strided * stride * desc.element_size_bits // 8
        else:
            inc_advance = desc.threadblock_shape.contiguous * desc.element_size_bits // 8

        # step from the last access of one tile to the first of the next
        inc_next = inc_advance - (desc.threadmap_iterations.strided - 1) * \
            desc.threadmap_delta.strided * stride * desc.element_size_bits // 8

        offset += (-offset) % self.alignment()

        struct.pack_into(
            self.fmt, host_workspace, offset,
            stride, inc_strided, inc_next, inc_advance)

        return offset + self.size()
    #
|
||||
#
|
||||
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class EpilogueTileDesc:
    '''
    Five-component epilogue tile coordinate: column, row, group, cluster, tile.
    '''
    def __init__(self, column, row, group, cluster, tile):
        self.column = column
        self.row = row
        self.group = group
        self.cluster = cluster
        self.tile = tile
|
||||
|
||||
#
|
||||
class EpilogueThreadMap:
    '''
    Thread mapping for the epilogue tile iterator: thread count, access
    width, element width, and the shape/iterations/delta/count tile
    descriptors.
    '''
    def __init__(self, threads, elements_per_access, element_size_bits, shape, iterations, delta, count):
        self.threads = threads
        self.elements_per_access = elements_per_access
        self.element_size_bits = element_size_bits
        self.shape = shape
        self.iterations = iterations
        self.delta = delta
        self.count = count
|
||||
|
||||
#
|
||||
class EpilogueTileIteratorParams:
    '''
    Host-side mirror of the epilogue tile iterator Params: eight 64-bit
    values (byte stride plus row/group/cluster/tile increments and advances).
    '''

    #
    def __init__(self, desc, label):
        self.desc = desc       # EpilogueThreadMap describing the tiling
        self.label = label     # operand tag, e.g. 'C' or 'D'
        self.fmt = "qqqqqqqq"

    #
    def size(self):
        """Packed size in bytes."""
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        """Alignment required by the packed layout."""
        return max((struct.calcsize(ch) for ch in self.fmt), default=1)

    #
    def initialize(self, host_workspace, offset, stride):
        """Derive the epilogue increments from *stride* (in elements), pack them, return the new offset."""
        desc = self.desc

        # leading stride in bytes
        stride_bytes = stride * desc.element_size_bits // 8

        offset += (-offset) % self.alignment()

        # within-tile pointer increments (each backs out the smaller loops)
        inc_row = stride_bytes * desc.delta.row

        inc_group = stride_bytes * desc.delta.group \
            - stride_bytes * desc.delta.row * (desc.iterations.row - 1)

        inc_cluster = stride_bytes * desc.delta.cluster \
            - stride_bytes * desc.delta.group * (desc.iterations.group - 1) \
            - stride_bytes * desc.delta.row * (desc.iterations.row - 1)

        # whole-structure advances
        adv_row = stride_bytes * desc.shape.row

        adv_group = stride_bytes * \
            (desc.shape.group - 1) * desc.shape.row * desc.count.row

        adv_cluster = stride_bytes * \
            desc.count.group * desc.shape.group * desc.count.row * desc.shape.row

        adv_tile = stride_bytes * \
            desc.shape.group * desc.shape.row * desc.shape.cluster * desc.shape.tile

        struct.pack_into(
            self.fmt, host_workspace, offset,
            stride_bytes,
            inc_row, inc_group, inc_cluster,
            adv_row, adv_group, adv_cluster, adv_tile)

        return offset + self.size()
    #
|
||||
#
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# Launch configuration
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
class LaunchConfiguration:
    """Kernel launch shape: grid dims, block dims, and dynamic shared memory bytes."""

    def __init__(self, grid = None, block = None, smem = 0):
        # Fix: the original used mutable default arguments ([1,1,1]) that are
        # shared across every call; defaults are now created per instance.
        self.grid = [1, 1, 1] if grid is None else grid
        self.block = [1, 1, 1] if block is None else block
        self.shared_memory_capacity = smem
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# Functors
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class Functor:
    """Base class for epilogue functors: holds C++ text and a packable Params format."""

    def __init__(self):
        self.decl = ''           # C++ declaration text
        self.definition = ''     # C++ definition text
        self.fmt = ''            # struct format of the packed Params
        self.identifier = ''     # short name used in emitted code

    #
    def emit_declaration(self):
        """Return the C++ declaration for this functor."""
        return self.decl

    #
    def emit_definition(self):
        """Return the C++ definition for this functor."""
        return self.definition

    #
    def size(self):
        '''
        Size of the packed Params structure
        '''
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        """Alignment of the packed Params structure."""
        return max((struct.calcsize(ch) for ch in self.fmt), default=1)

    #
    def initialize(self, host_workspace, offset, arguments):
        """Base functor packs nothing; returns offset advanced by size()."""
        return offset + self.size()
|
||||
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class LinearCombinationFunctorArguments:
    """Arguments for the linear-combination epilogue: alpha/beta scalars plus optional device pointers."""

    def __init__(self, alpha = 1.0, beta = 0.0):
        self.alpha = alpha
        self.beta = beta
        # NOTE(review): 0 presumably means "no device pointer; use the host
        # scalars above" — confirm against the kernel-side Params.
        self.alpha_ptr = 0
        self.beta_ptr = 0
|
||||
|
||||
#
|
||||
class LinearCombinationFunctor(Functor):
    """Linear-combination epilogue functor (float accumulator, alpha/beta scaling, vector width 1)."""

    def __init__(self):
        super().__init__()

        self.decl = """
      cutlass::epilogue::thread::LinearCombination<
        float,
        1,
        float,
        float
      >"""
        self.identifier = 'linear_combination'
        # alpha, beta (float) followed by their device pointers
        self.fmt = "ffPP"

    #
    def size(self):
        '''
        Size of the packed Params structure
        '''
        return struct.calcsize(self.fmt)

    #
    def alignment(self):
        """Alignment required by the packed Params structure."""
        return MaxAlignment(self.fmt)

    #
    def initialize(self, host_workspace, offset, arguments):
        """Pack (alpha, beta, alpha_ptr, beta_ptr) at the next aligned offset; return the new offset."""
        offset = AlignedOffset(offset, self.alignment())

        struct.pack_into(
            self.fmt,
            host_workspace, offset,
            arguments.alpha, arguments.beta,
            arguments.alpha_ptr, arguments.beta_ptr)

        return offset + self.size()
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# Base class for an executable operation
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class ExecutableOperation:
    '''
    Base class for operations that can be compiled, loaded, and launched
    through the CUDA driver API.
    '''
    def __init__(self, operation):
        self.operation = operation   # description object providing procedural_name()
        self.module = None           # owning Module once loaded
        self.kernel = None           # CUfunction handle once resolved

    #
    def name(self):
        """Procedural name of the underlying operation."""
        return self.operation.procedural_name()

    #
    def emit(self):
        """C++ source for this operation; the base emits nothing."""
        return ''

    #
    def can_implement(self, configuration, arguments):
        """Whether this operation supports the given configuration; the base cannot."""
        return False

    #
    def get_host_workspace_size(self, arguments):
        """Bytes of host memory needed to stage the packed Params."""
        return 0

    #
    def get_device_workspace_size(self, arguments):
        """Bytes of device memory needed by the kernel."""
        return 0

    #
    def plan(self, arguments):
        """Launch configuration for the given arguments; the base launches 1x1x1."""
        return LaunchConfiguration()

    #
    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream = cuda.CUstream(0)):
        """Pack arguments into the workspaces; subclasses must implement."""
        raise NotImplementedError()

    #
    def run(self, host_workspace, device_workspace, launch_config, stream = cuda.CUstream(0)):
        """Launch the kernel, passing the packed host workspace as the single kernel parameter."""
        raw_arg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
        kernel_params = (ctypes.c_void_p * 1)()
        kernel_params[0] = ctypes.addressof(raw_arg)

        err, = cuda.cuLaunchKernel(
            self.kernel,
            launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
            launch_config.block[0], launch_config.block[1], launch_config.block[2],
            launch_config.shared_memory_capacity,
            stream,
            kernel_params,
            0)

        return err
|
||||
|
||||
#################################################################################################
|
||||
|
||||
|
||||
#
|
||||
class GemmArguments:
    '''
    Argument bundle for one GEMM: problem size, the four operand tensor
    references (A, B, C source, D destination), and epilogue arguments.
    '''
    def __init__(self):
        self.problem_size = GemmCoord(0, 0, 0)
        self.A = TensorRef()
        self.B = TensorRef()
        self.C = TensorRef()
        self.D = TensorRef()
        self.output_op = LinearCombinationFunctorArguments()
|
||||
|
||||
#
|
||||
class ThreadblockSwizzle:
    """Maps a GEMM problem size onto a grid of threadblock tiles."""

    def __init__(self, threadblock_shape, log_threadblock_cohort = 0):
        self.threadblock_shape = threadblock_shape
        self.log_threadblock_cohort = log_threadblock_cohort

    def grid_tiled_shape(self, problem_size):
        """Number of threadblock tiles covering the problem in m and n; k is not tiled."""
        tiles_m = ceil_div(problem_size.m, self.threadblock_shape.m)
        tiles_n = ceil_div(problem_size.n, self.threadblock_shape.n)
        return GemmCoord(tiles_m, tiles_n, 1)
|
||||
|
||||
#
|
||||
class Gemm(ExecutableOperation):
    '''
    GEMM manages the CUTLASS runtime components
    '''
    #
    def __init__(self, operation):
        super().__init__(operation)

        # Emits the C++ instance definition for this operation
        self.emitter = EmitGemmUniversalInstance('_type')

        # Maps the problem size onto 128x128x8 threadblock tiles
        self.threadblock_swizzle = ThreadblockSwizzle(GemmCoord(128, 128, 8))

        # Launch shape: 256 threads per block, 32 KiB dynamic shared memory
        self.threads = 256
        self.shared_memory_capacity = (32 << 10)

        # Iterator parameters for the A and B operands (32-bit elements,
        # advance rank 1, 128x8 pitch-linear tile)
        self.params_A = PredicatedTileAccessIteratorParams(
            PredicatedTileAccessIteratorDesc(
                32,
                1,
                PitchLinearCoord(128, 8),
                PitchLinearCoord(1, 4),
                PitchLinearCoord(1, 2)), 'A')

        self.params_B = PredicatedTileAccessIteratorParams(
            PredicatedTileAccessIteratorDesc(
                32,
                1,
                PitchLinearCoord(128, 8),
                PitchLinearCoord(1, 4),
                PitchLinearCoord(1, 2)), 'B')

        # Epilogue iterator parameters for the C (source) and D (destination)
        # operands; both use the same thread map
        self.params_C = EpilogueTileIteratorParams(
            EpilogueThreadMap(
                256,
                1,
                32,
                EpilogueTileDesc(128, 1, 4, 4, 1),
                EpilogueTileDesc(4, 1, 2, 1, 1),
                EpilogueTileDesc(32, 1, 8, 1, 1),
                EpilogueTileDesc(1, 4, 2, 1, 8)), 'C')

        self.params_D = EpilogueTileIteratorParams(
            EpilogueThreadMap(
                256,
                1,
                32,
                EpilogueTileDesc(128, 1, 4, 4, 1),
                EpilogueTileDesc(4, 1, 2, 1, 1),
                EpilogueTileDesc(32, 1, 8, 1, 1),
                EpilogueTileDesc(1, 4, 2, 1, 8)), 'D')

        # Elementwise epilogue functor (alpha/beta linear combination)
        self.output_op = LinearCombinationFunctor()

    #
    def emit(self):
        """C++ instance definition for this operation."""
        return self.emitter.emit(self.operation)

    #
    def can_implement(self, configuration, arguments):
        # NOTE(review): not implemented — returns None rather than a boolean;
        # the base class returns False. Callers should not rely on this yet.
        pass

    #
    def get_host_workspace_size(self, arguments):
        # Byte size of the packed Params written by initialize(). Hard-coded
        # for this kernel configuration — TODO confirm it matches the sum of
        # the offsets packed below.
        return 336

    #
    def get_device_workspace_size(self, arguments):
        """This kernel needs no device-side scratch workspace."""
        return 0

    #
    def plan(self, arguments):
        """Launch configuration: one threadblock per output tile; k is not split."""
        grid = self.threadblock_swizzle.grid_tiled_shape(arguments.problem_size)
        return LaunchConfiguration([grid.m, grid.n, grid.k], [self.threads, 1, 1], self.shared_memory_capacity)

    #
    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream = cuda.CUstream(0)):
        """
        Packs the kernel's Params structure into host_workspace field by
        field, in the order the device-side layout expects, and returns the
        total number of bytes packed.
        """
        offset = 0

        # Compute intermediate results
        swizzle_log_tile = 0                        # no threadblock swizzling
        gemm_mode = 0                               # presumably universal mode kGemm — confirm
        batch_count = 1
        gemm_k_size = arguments.problem_size.k      # full k handled by each tile

        # Pack into the host workspace buffer
        offset = arguments.problem_size.pack_into(host_workspace, offset)

        grid_tiled_shape = self.threadblock_swizzle.grid_tiled_shape(arguments.problem_size)
        offset = grid_tiled_shape.pack_into(host_workspace, offset)

        offset = PackInteger(host_workspace, offset, swizzle_log_tile)

        # Per-operand iterator parameters (layout carries the leading stride)
        offset = self.params_A.initialize(host_workspace, offset, arguments.A.layout)
        offset = self.params_B.initialize(host_workspace, offset, arguments.B.layout)
        offset = self.params_C.initialize(host_workspace, offset, arguments.C.layout)
        offset = self.params_D.initialize(host_workspace, offset, arguments.D.layout)

        # Epilogue scalars (alpha, beta and their device pointers)
        offset = self.output_op.initialize(host_workspace, offset, arguments.output_op)

        # Mode / batch / k-partition scalars, then the four operand pointers
        offset = PackInteger(host_workspace, offset, gemm_mode)
        offset = PackInteger(host_workspace, offset, batch_count)
        offset = PackInteger(host_workspace, offset, gemm_k_size)
        offset = PackDevicePointer(host_workspace, offset, int(arguments.A.pointer))
        offset = PackDevicePointer(host_workspace, offset, int(arguments.B.pointer))
        offset = PackDevicePointer(host_workspace, offset, int(arguments.C.pointer))
        offset = PackDevicePointer(host_workspace, offset, int(arguments.D.pointer))

        return offset
|
||||
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# Module represents a compilation unit
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class CompilationOptions:
    '''
    Compilation options passed to NVRTC.
    '''

    #
    def __init__(self, architectures = None, include_paths = None):
        """
        architectures: list of SM versions to target (defaults to [80]).
        include_paths: list of include directories (defaults to []).

        Fix: the original used mutable default arguments, which are shared
        across all calls; defaults are now created per instance.
        """
        self.includes = []
        self.include_paths = [] if include_paths is None else include_paths
        self.flags = ['-std=c++11', '-default-device']
        self.architectures = [80] if architectures is None else architectures

    #
    def get(self):
        """Return the NVRTC option list as bytes objects."""
        options = []

        for flag in self.flags:
            options.append(bytes(str.encode(flag)))

        for incl in self.include_paths:
            options.append(bytes(str.encode('--include-path=%s' % incl)))

        # single -arch option listing every requested SM architecture
        arch_list = "-arch="
        for idx, arch in enumerate(self.architectures):
            if idx:
                arch_list += ","
            arch_list += "sm_%d" % arch

        options.append(bytes(str.encode(arch_list)))

        return options
|
||||
|
||||
# Emits one #include directive; ${include} is the header path.
IncludeTemplate = r'''#include "${include}"
'''

# Wraps a CUTLASS kernel type in an extern "C" __global__ entry point so the
# mangled-free symbol can be resolved by name with cuModuleGetFunction.
# ${operation_name} and ${operation_suffix} name the kernel type emitted by
# the operation's emitter.
KernelTemplate = r'''
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {

  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);

  ${operation_name}${operation_suffix} op;

  op(params, *shared_storage);
}

'''
|
||||
|
||||
#
|
||||
class Module:
    '''
    A compilation unit: emits C++ source for a set of operations, compiles it
    with NVRTC into a cubin image, and loads the kernels through the CUDA
    driver API.
    '''
    def __init__(self, name, operations, compilation_options):
        # name: program name handed to NVRTC
        # operations: ExecutableOperation instances compiled into this module
        # compilation_options: CompilationOptions supplying the NVRTC flags
        self.name = name
        self.operations = operations
        self.module = None          # CUmodule handle, set by load_()
        self.log = None             # NVRTC build log (bytes), set on compile failure
        self.cubin_image = None     # compiled cubin (bytes), set by compile_()
        self.source_buffer = ''     # generated C++ source, filled by emit_()

        #
        # Emit source
        #
        self.emit_()

        #
        # Compile
        #
        self.compile_(compilation_options)

        #
        # Load module
        #
        self.load_()

        # Done
        return

    # Emit a source buffer
    def emit_(self):
        """Generate the C++ translation unit for all operations into self.source_buffer."""

        # 1. Includes: de-duplicated union of every operation's required
        #    headers, preserving first-seen order
        includes = []
        for operation in self.operations:
            for incl in operation.emitter.includes:
                if incl not in includes:
                    includes.append(incl)

        for incl in includes:
            self.source_buffer += SubstituteTemplate(IncludeTemplate, { 'include': incl} )

        # 2. Operations: the instance definition followed by an extern "C"
        #    __global__ entry point for each operation
        for operation in self.operations:
            self.source_buffer += operation.emit()
            values = {
                'operation_name': operation.name(),
                'operation_suffix': operation.emitter.operation_suffix
            }
            self.source_buffer += SubstituteTemplate(KernelTemplate, values)

        # Done
        return

    # Compile with NVRTC
    def compile_(self, compilation_options):
        """Compile self.source_buffer with NVRTC, storing the cubin in self.cubin_image."""

        # Create the NVRTC program from the emitted source
        err, program = nvrtc.nvrtcCreateProgram(
            str.encode(self.source_buffer),
            bytes(str.encode(self.name)),
            0, [], [])

        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        # Compile program
        options = compilation_options.get()

        err, = nvrtc.nvrtcCompileProgram(program, len(options), options)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:

            error_string = 'NVRTC Error: {}\n'.format(err)

            # Get log from compilation
            err, logSize = nvrtc.nvrtcGetProgramLogSize(program)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            # NOTE(review): an immutable bytes object is handed to the binding
            # as the output buffer and filled in place — confirm this is the
            # intended cuda-python usage (a bytearray would be the mutable
            # alternative).
            self.log = b' ' * logSize
            err, = nvrtc.nvrtcGetProgramLog(program, self.log)
            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise RuntimeError('NVRTC Error: {}'.format(err))

            # Surface the build log and generated source in the exception
            raise RuntimeError(error_string + self.log.decode() + self.source_buffer)

        # Get data from compilation
        err, dataSize = nvrtc.nvrtcGetCUBINSize(program)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        # Same bytes-as-output-buffer pattern as the log above
        self.cubin_image = b' ' * dataSize
        err, = nvrtc.nvrtcGetCUBIN(program, self.cubin_image)
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError('NVRTC Error: {}'.format(err))

        return

    #
    def load_(self):
        """Load the cubin via the driver API and resolve each operation's kernel handle."""

        # Load data as module data
        err, self.module = cuda.cuModuleLoadData(self.cubin_image)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError('Cuda Error: {}'.format(err))

        # Get functions: resolve each kernel by its procedural name and link
        # the operation back to this module
        for operation in self.operations:
            err, operation.kernel = cuda.cuModuleGetFunction(
                self.module,
                bytes(str.encode(operation.name())))

            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError('Cuda Error: {}'.format(err))

            operation.module = self

        return
|
||||
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# Manifest represents an 'owner' for modules and operations
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
class Manifest:
    """Owns compiled Modules and indexes their operations by procedural name."""

    #
    def __init__(self):
        self.operations = {}   # name -> operation
        self.modules = []      # owned Module instances

    #
    def append_module(self, module):
        '''
        Appends a module and takes ownership of operations used to construct it.
        '''
        self.modules.append(module)

        for operation in module.operations:
            self.operations[operation.name()] = operation
|
||||
|
||||
|
||||
#################################################################################################
|
||||
Reference in New Issue
Block a user