TensorRT examples

Examples of exporting a PyTorch model to TensorRT via ONNX.

Dynamic shapes

Prepare the runtime environment:

pip install cuda-python
pip install tensorrt==10.3.0
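
Optionally, verify the environment before exporting. The snippet below is a minimal sanity check using only the packages installed above:

import torch
import tensorrt as trt
from cuda import cudart

print("torch:", torch.__version__, "CUDA available:", torch.cuda.is_available())
print("tensorrt:", trt.__version__)
# cuda-python's cudart calls return an (error, result) tuple
err, runtime_version = cudart.cudaRuntimeGetVersion()
print("cudart runtime version:", runtime_version)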

An example of TensorRT conversion with dynamically shaped inputs and outputs.

import os
import torch

from pathlib import Path
from copy import deepcopy

from functools import reduce
import onnx
import onnxruntime
import tensorrt as trt
from cuda import cudart


def compute_rms_norm(x, dim, eps):
    dtype = reduce(torch.promote_types, (x.dtype, torch.float32))
    mean_sq = x.to(dtype).square().mean(dim=dim, keepdim=True)
    return (mean_sq + eps).rsqrt().to(x.dtype)


class TopK(torch.nn.Module):
    def __init__(self, k: int, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.k = k
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        v = compute_rms_norm(x, -1, self.eps)
        values, indices = torch.topk(v, self.k, dim=self.dim)
        return v, indices


class SlicedLinearModel(torch.nn.Module):
    def __init__(self, input_size, output_size, topk, dim):
        super(SlicedLinearModel, self).__init__()
        self.linear = torch.nn.Linear(input_size, output_size)
        self.topk = TopK(topk, dim)
        self.dim = dim

    def forward(self, x):
        x = self.linear(x)
        v, k = self.topk(x)
        x = torch.index_select(x, self.dim, k.flatten())
        return x, v, k


def load_model(input_size=5, output_size=5, topk=4):
    model = SlicedLinearModel(input_size, output_size, topk=topk, dim=1)
    torch.nn.init.xavier_uniform_(model.linear.weight)
    torch.nn.init.zeros_(model.linear.bias)
    return model


def compile_onnx_model(model, input_x, onnx_path="slicing_test_dynamic.onnx", force_export=False):
    example_input = (input_x,)
    input_names = ["input_x"]
    output_names = ["pred_x", "pred_v", "pred_k"]

    dynamic_axes = {
        "input_x" : {1: "sequence_length"},
        "pred_x" : {1: "sequence_length"},
        "pred_v" : {1: "sequence_length"},
        "pred_k" : {1: "sequence_length"}
    }

    if force_export or not os.path.exists(onnx_path):
        with torch.no_grad():
            torch.onnx.export(
                model,
                example_input,
                onnx_path,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                export_params=True,
                do_constant_folding=True,
                opset_version=18)

            pred_x, pred_v, pred_k = model(input_x)

            print(f"Calibration input_x[{input_x.shape}]:\n{input_x}")
            print(f"Calibration pred_x[{pred_x.shape}]:\n{pred_x}")
            print(f"Calibration pred_v[{pred_v.shape}]:\n{pred_v}")
            print(f"Calibration pred_k[{pred_k.shape}]:\n{pred_k}")
    else:
        print(f"ONNX model already exists at {onnx_path}")


TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
TRT_DTYPE_TO_TORCH = {
    trt.float32: torch.float32,
    trt.float16: torch.float16,
    trt.int32: torch.int32,
    trt.int64: torch.int64,
    trt.int8: torch.int8,
    trt.bool: torch.bool,
    trt.bfloat16: torch.bfloat16,
}


def compile_trt_model(input_shape, output_shapes, onnx_path, trt_path=None):
    # Convert onnx model to tensorrt
    batch_size, seq_length, input_size = input_shape
    _, topk, output_size = output_shapes[0]

    # Load the ONNX model
    onnx_model = onnx.load(onnx_path)

    # Create a TensorRT builder and network
    builder = trt.Builder(TRT_LOGGER)
    config = builder.create_builder_config()
    # Set cache
    cache = config.create_timing_cache(b"")
    config.set_timing_cache(cache, ignore_mismatch=False)
    config.set_flag(trt.BuilderFlag.FP16)  # or trt.BuilderFlag.BF16
    # config.set_flag(trt.BuilderFlag.STRICT_TYPES)
    config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
    config.runtime_platform = trt.RuntimePlatform.SAME_AS_BUILD
    # config.set_flag(trt.BuilderFlag.ExportLayerInfo)

    # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#build_engine_python
    # max_workspace = (1 << 30)
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace)

    # Set profile
    # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work_dynamic_shapes
    opt_profile = builder.create_optimization_profile()
    opt_profile.set_shape("input_x", (batch_size, 1, input_size), (batch_size, topk, input_size), (batch_size, 50, input_size))
    # opt_profile.set_shape("pred_x", (batch_size, topk, output_size), (batch_size, topk, output_size), (batch_size, topk, output_size))
    opt_profile.set_shape("pred_v", (batch_size, 1, 1), (batch_size, topk, 1), (batch_size, 50, 1))
    # opt_profile.set_shape("pred_k", (batch_size, topk, 1), (batch_size, topk, 1), (batch_size, topk, 1))
    # preprocessorConfig = builder.create_network_config()
    # preprocessorConfig.add_optimization_profile(profile)
    config.add_optimization_profile(opt_profile)

    # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#version-compat
    # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#explicit-implicit-batch
    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)

    # Create an ONNX parser and parse the ONNX model into the TensorRT network
    parser = trt.OnnxParser(network, TRT_LOGGER)
    # Parse the ONNX model from file
    # with open(onnx_path, "rb") as f:
    #     if not parser.parse(f.read()):
    #         print(f"ERROR: Failed to parse the ONNX file {onnx_path}")
    #         for error in range(parser.num_errors):
    #             print(parser.get_error(error))
    onnx_model_str = onnx_model.SerializeToString()
    if not parser.parse(onnx_model_str):
        print(f"ERROR: Failed to parse the ONNX file {onnx_model_str}")
        for error in range(parser.num_errors):
            print(parser.get_error(error))

    # export_layer_info = "slicing_test_layer_info.json"
    # https://github.com/NVIDIA/TensorRT/issues/2247
    # https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/EngineInspector.html
    # https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Engine.html#tensorrt.ICudaEngine.create_engine_inspector
    # https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/BuilderConfig.html#tensorrt.ProfilingVerbosity
    # https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/ExecutionContext.html#tensorrt.IExecutionContext.report_to_profiler

    # Build the TensorRT engine
    engine_bytes = builder.build_serialized_network(network, config)
    if engine_bytes:
        # Save the TensorRT engine to a file
        trt_path = onnx_path[:-4] + "trt" if trt_path is None else trt_path
        with open(trt_path, "wb") as f:
            f.write(engine_bytes)

        print("TensorRT engine saved successfully!")
    else:
        print("ERROR: Failed to build the TensorRT engine!")


def main():
    input_size = 5
    output_size = 5
    topk = 4
    model = load_model(input_size, output_size, topk)

    # To avoid accuracy issues with float16 between CPU and GPU, use float32 to test the ONNX path
    dtype = torch.bfloat16  # torch.float32 # torch.float16
    device = "cuda"
    model = model.to(dtype).to(device).eval()

    batch_size = 1
    seq_length = 20
    input_x = torch.rand(batch_size, seq_length, input_size).to(dtype).to(device)
    output_shapes = [torch.Size([batch_size, topk, input_size]), torch.Size([batch_size, seq_length, 1]), torch.Size([batch_size, topk, 1])]

    onnx_path="slicing_test_dynamic.onnx"
    compile_onnx_model(model, input_x, onnx_path=onnx_path, force_export=True)
    trt_path = onnx_path[:-4]+ "trt"
    compile_trt_model(input_shape=(batch_size, seq_length, input_size), output_shapes=output_shapes, onnx_path=onnx_path, trt_path=trt_path)

    # ONNX Runtime verification does not work with torch.bfloat16
    if dtype == torch.float32:
        verify_onnx_model(model, (batch_size, seq_length, input_size), onnx_path, dtype, device)

    verify_trt_model(model, (batch_size, seq_length, input_size), topk, trt_path, dtype, device, dynamic_shape=True)

    print("All checks pass!")


def verify_onnx_model(model, input_shape, onnx_path, dtype, device="cuda"):
    # Verify the ONNX model and TensorRT engine
    # Load the ONNX model
    onnx_model = onnx.load(onnx_path)
    # Create an inference session with the ONNX model
    session = onnxruntime.InferenceSession(onnx_model.SerializeToString())

    batch_size, seq_length, input_size = input_shape
    for i in range(2):
        seq_length = torch.randint(5, 51, (1,)).item()
        input_x = torch.rand(batch_size, seq_length, input_size).to(dtype).to(device)
        with torch.no_grad():
            pred_x, pred_v, pred_k = model(input_x)

        print(f"input_x[{input_x.shape}]:\n{input_x}")
        print(f"pred_x[{pred_x.shape}]:\n{pred_x}")
        print(f"pred_v[{pred_v.shape}]:\n{pred_v}")
        print(f"pred_k[{pred_k.shape}]:\n{pred_k}")

        # Prepare the input data
        input_data = {"input_x": input_x.cpu().numpy()}

        # Run the model
        output = session.run(None, input_data)

        # Get the output tensor
        onnx_pred_x = torch.from_numpy(output[0]).to(device)
        onnx_pred_v = torch.from_numpy(output[1]).to(device)
        onnx_pred_k = torch.from_numpy(output[2]).to(device)

        print(f"onnx run pred_x[{onnx_pred_x.shape}]:\n{onnx_pred_x}")
        print(f"onnx run pred_v[{onnx_pred_v.shape}]:\n{onnx_pred_v}")
        print(f"onnx run pred_k[{onnx_pred_k.shape}]:\n{onnx_pred_k}")

        # Verify if pred_x is equal to onnx_pred_x
        assert torch.allclose(pred_x, onnx_pred_x.to(device), atol=1e-3), "pred_x and onnx_pred_x are not equal"
        assert torch.allclose(pred_v, onnx_pred_v.to(device), atol=1e-3), "pred_v and onnx_pred_v are not equal"
        assert torch.allclose(pred_k, onnx_pred_k.to(device), atol=1e-3), "pred_k and onnx_pred_k are not equal"

    print("Onnx model checks pass!")


def verify_trt_model(model, input_shape, topk, trt_path, dtype, device="cuda", dynamic_shape=False):
    # Verify TensorRT engine
    # Load the TRT model

    profile_idx = 0 if dynamic_shape else None
    runtime = trt.Runtime(TRT_LOGGER)

    with open(trt_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
        if engine is None:
            raise RuntimeError(f"Failed to reload TRT cuda engine from {trt_path}.")
        else:
            print(f"Finished loading TRT engine from {trt_path}")

    trt_stream = cudart.cudaStreamCreate()[1]
    trt_context = engine.create_execution_context()
    trt_inputs: dict[str, torch.Tensor] = {}
    trt_outputs: dict[str, torch.Tensor] = {}

    # Allocate buffers for inputs and outputs
    default_max_dim_size = 50
    with torch.cuda.device(device="cuda:0"):
        # current_device = torch.cuda.current_device()
        for i in range(engine.num_io_tensors):
            tensor_name = engine.get_tensor_name(i)

            trt_shape = engine.get_tensor_shape(tensor_name)
            if torch.any(torch.tensor(trt_shape) == -1):
                # Must define all dynamic shapes in the profile
                trt_shape_profile = engine.get_tensor_profile_shape(tensor_name, profile_idx)[-1]
                replace_negative_dims = False
                try:
                    if len(trt_shape_profile) == len(trt_shape):
                        trt_shape = trt_shape_profile
                    else:
                        replace_negative_dims = True
                except Exception as e:
                    replace_negative_dims = True

                if replace_negative_dims:
                    # breakpoint()
                    print(f"Profile shape {trt_shape_profile} does not match the engine shape {trt_shape}")
                    trt_shape = tuple([default_max_dim_size if dim == -1 else dim for dim in trt_shape])
                    print(f"Warnning: Using default max shape {trt_shape}")

            trt_tensor = torch.empty(tuple(trt_shape),
                                    dtype=TRT_DTYPE_TO_TORCH[engine.get_tensor_dtype(tensor_name)],
                                    device="cuda:0")

            if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                trt_inputs[tensor_name] = trt_tensor
            else:
                trt_outputs[tensor_name] = trt_tensor
            trt_context.set_tensor_address(tensor_name, trt_tensor.data_ptr())

    batch_size, seq_length, input_size = input_shape
    try:
        for i in range(2):
            seq_length = torch.randint(5, 51, (1,)).item()
            input_x = torch.rand(batch_size, seq_length, input_size).to(dtype).to(device)
            output_shapes = [torch.Size([batch_size, topk, input_size]), torch.Size([batch_size, seq_length, 1]), torch.Size([batch_size, topk, 1])]

            with torch.no_grad():
                pred_x, pred_v, pred_k = model(input_x)

            print(f"input_x[{input_x.shape}]:\n{input_x}")
            print(f"pred_x[{pred_x.shape}]:\n{pred_x}")
            print(f"pred_v[{pred_v.shape}]:\n{pred_v}")
            print(f"pred_k[{pred_k.shape}]:\n{pred_k}")

            # Run TRT model
            input_data = {"input_x": input_x}
            with torch.cuda.device(device="cuda:0"):
                for key, value in input_data.items():
                    trt_inputs[key][:, :int(value.shape[1]), :].copy_(value)
                    trt_context.set_input_shape(key, tuple(value.shape))
                trt_context.execute_async_v3(trt_stream)
                # Wait for TRT execution to finish before reading the output buffers
                cudart.cudaStreamSynchronize(trt_stream)

            trt_pred_x_, trt_pred_v_, trt_pred_k_ = trt_outputs["pred_x"], trt_outputs["pred_v"], trt_outputs["pred_k"]

            # Flatten the tensors
            trt_pred_x = trt_pred_x_.flatten()[:output_shapes[0].numel()].view(output_shapes[0])
            trt_pred_v = trt_pred_v_.flatten()[:output_shapes[1].numel()].view(output_shapes[1])
            trt_pred_k = trt_pred_k_.flatten()[:output_shapes[2].numel()].view(output_shapes[2])

            print(f"trt run pred_x[{trt_pred_x.shape}]:\n{trt_pred_x}")
            print(f"trt run pred_v[{trt_pred_v.shape}]:\n{trt_pred_v}")
            print(f"trt run pred_k[{trt_pred_k.shape}]:\n{trt_pred_k}")

            # Verify that the torch outputs match the TRT outputs.
            # Default tolerances for bfloat16 are atol=1e-5 and rtol=1.6e-2.
            # Due to accuracy differences in pred_v, the topk order may not match,
            # so the following assertions occasionally fail.
            torch.testing.assert_close(pred_x, trt_pred_x.to(device), msg="pred_x and trt_pred_x are not equal")
            torch.testing.assert_close(pred_v, trt_pred_v.to(device), msg="pred_v and trt_pred_v are not equal")
            torch.testing.assert_close(pred_k, trt_pred_k.to(device), msg="pred_k and trt_pred_k are not equal")
    finally:
        # Deallocate buffers
        for i in range(engine.num_io_tensors):
            binding = engine[i]
            if binding in trt_inputs.keys():
                del trt_inputs[binding]
            else:
                del trt_outputs[binding]
            print(f"Deleted {binding}")


if __name__ == "__main__":
    main()

Multiple inputs

Example of exporting a TorchScript function that takes multiple tensor inputs and uses data-dependent control flow.

import os
import torch

import onnx
import onnxruntime
import tensorrt as trt
from cuda import cudart

# Reference:
# https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#conditional-examples

@torch.jit.script
def sum_k_mods(items: torch.IntTensor, k: torch.Tensor) -> torch.Tensor:
    s = torch.zeros(1, dtype=torch.float, device=items.device)
    # Must use tensor operations when the condition involves an input tensor.
    m = torch.remainder(items, k) == 0
    for i, c in enumerate(items):
        if m[i]:
            s += c
    return s

# Error with the version below:
# Param k is considered as a constant and removed from the input list.
# @torch.jit.script
# def sum_k_mods(items: torch.IntTensor, k: torch.Tensor) -> torch.Tensor:
#     k = k.item()
#     s = torch.zeros(1, dtype=torch.float, device=items.device)
#     for c in items:
#         if c % k == 0:
#             s += c
#     return s

class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    # Only use tensor inputs
    def forward(self, items, k):
        return sum_k_mods(items, k)


def load_model():
    model = ExampleModel()
    return model


def generate_input(batch_size):
    input_x = torch.randint(1, 100, (batch_size,))
    k = torch.randint(2, 5, (1,))
    return input_x, k


def compile_onnx_model(model, inputs, onnx_path, force_export=True):
    example_input = tuple(inputs)
    input_names = ["input_x", "k"]
    output_names = ["pred_x"]

    if force_export or not os.path.exists(onnx_path):
        with torch.no_grad():
            torch.onnx.export(
                model,
                example_input,
                onnx_path,
                input_names=input_names,
                output_names=output_names,
                export_params=True,
                do_constant_folding=True,
                opset_version=18)

            # Check the onnx model
            pred_x = model(*example_input)
            for i, name in enumerate(input_names):
                print(f"Calibration {name}[{inputs[i].shape}]:\n{inputs[i]}")

            # Wrap the single output in a tuple for uniform printing
            outputs = (pred_x,)
            for i, name in enumerate(output_names):
                print(f"Calibration {name}[{outputs[i].shape}]:\n{outputs[i]}")
        print(f"ONNX model is compiled and saved to {onnx_path}")
    else:
        print(f"ONNX model already exists at {onnx_path}")


TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
TRT_DTYPE_TO_TORCH = {
    trt.float32: torch.float32,
    trt.float16: torch.float16,
    trt.int32: torch.int32,
    trt.int64: torch.int64,
    trt.int8: torch.int8,
    trt.bool: torch.bool,
    trt.bfloat16: torch.bfloat16,
}


def compile_trt_model(onnx_path, trt_path=None):
    # Convert onnx model to tensorrt

    # Load the ONNX model
    onnx_model = onnx.load(onnx_path)

    # Create a TensorRT builder and network
    builder = trt.Builder(TRT_LOGGER)
    config = builder.create_builder_config()
    # Set cache
    cache = config.create_timing_cache(b"")
    config.set_timing_cache(cache, ignore_mismatch=False)
    config.set_flag(trt.BuilderFlag.FP16)  # or trt.BuilderFlag.BF16
    # config.set_flag(trt.BuilderFlag.STRICT_TYPES)
    config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
    config.runtime_platform = trt.RuntimePlatform.SAME_AS_BUILD
    # config.set_flag(trt.BuilderFlag.ExportLayerInfo)

    # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#build_engine_python
    # max_workspace = (1 << 30)
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace)

    network = builder.create_network(EXPLICIT_BATCH)

    # Create an ONNX parser and parse the ONNX model into the TensorRT network
    parser = trt.OnnxParser(network, TRT_LOGGER)
    # Parse the ONNX model from file
    # with open(onnx_path, "rb") as f:
    #     if not parser.parse(f.read()):
    #         print(f"ERROR: Failed to parse the ONNX file {onnx_path}")
    #         for error in range(parser.num_errors):
    #             print(parser.get_error(error))
    onnx_model_str = onnx_model.SerializeToString()
    if not parser.parse(onnx_model_str):
        print(f"ERROR: Failed to parse the ONNX file {onnx_model_str}")
        for error in range(parser.num_errors):
            print(parser.get_error(error))

    # Build the TensorRT engine
    engine_bytes = builder.build_serialized_network(network, config)
    if engine_bytes:
        # Save the TensorRT engine to a file
        trt_path = onnx_path[:-4] + "trt" if trt_path is None else trt_path
        with open(trt_path, "wb") as f:
            f.write(engine_bytes)

        print("TensorRT engine saved successfully!")
    else:
        print("ERROR: Failed to build the TensorRT engine!")


def verify_onnx_model(model, batch_size, onnx_path, dtype, device="cuda"):
    # Verify the ONNX model and TensorRT engine
    # Load the ONNX model
    onnx_model = onnx.load(onnx_path)
    # Create an inference session with the ONNX model
    session = onnxruntime.InferenceSession(onnx_model.SerializeToString())

    for i in range(2):
        input_x, k = generate_input(batch_size)
        inputs = (input_x.to(dtype).to(device), k.to(dtype).to(device))
        with torch.no_grad():
            pred_x = model(*inputs)

        print(f"input_x[{inputs[0].shape}]:\n{inputs}")
        print(f"pred_x[{pred_x.shape}]:\n{pred_x}")

        # Prepare the input data
        input_data = {"input_x": inputs[0].cpu().numpy(), "k": inputs[1].cpu().numpy()}
        # Run the model
        output = session.run(None, input_data)
        # Get the output tensor
        onnx_pred_x = torch.from_numpy(output[0]).to(pred_x.device)
        print(f"onnx run pred_x[{onnx_pred_x.shape}]:\n{onnx_pred_x}")

        # Verify if pred_x is equal to onnx_pred_x
        assert torch.allclose(pred_x, onnx_pred_x, atol=1e-3), "pred_x and onnx_pred_x are not equal"
    print("Onnx model checks pass!")


def verify_trt_model(model, batch_size, trt_path, dtype, device="cuda", dynamic_shape=False):
    # Verify TensorRT engine
    # Load the TRT model

    profile_idx = 0 if dynamic_shape else None
    runtime = trt.Runtime(TRT_LOGGER)

    with open(trt_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
        if engine is None:
            raise RuntimeError(f"Failed to reload TRT cuda engine from {trt_path}.")
        else:
            print(f"Finished loading TRT engine from {trt_path}")

    trt_stream = cudart.cudaStreamCreate()[1]
    trt_context = engine.create_execution_context()
    trt_inputs: dict[str, torch.Tensor] = {}
    trt_outputs: dict[str, torch.Tensor] = {}

    # Allocate buffers for inputs and outputs
    default_max_dim_size = 50
    with torch.cuda.device(device="cuda:0"):
        # current_device = torch.cuda.current_device()
        for i in range(engine.num_io_tensors):
            tensor_name = engine.get_tensor_name(i)

            trt_shape = engine.get_tensor_shape(tensor_name)
            if torch.any(torch.tensor(trt_shape) == -1):
                # Must define all dynamic shapes in the profile
                trt_shape_profile = engine.get_tensor_profile_shape(tensor_name, profile_idx)[-1]
                replace_negative_dims = False
                try:
                    if len(trt_shape_profile) == len(trt_shape):
                        trt_shape = trt_shape_profile
                    else:
                        replace_negative_dims = True
                except Exception as e:
                    replace_negative_dims = True

                if replace_negative_dims:
                    # breakpoint()
                    print(f"Profile shape {trt_shape_profile} does not match the engine shape {trt_shape}")
                    trt_shape = tuple([default_max_dim_size if dim == -1 else dim for dim in trt_shape])
                    print(f"Warnning: Using default max shape {trt_shape}")

            trt_tensor = torch.empty(tuple(trt_shape),
                                    dtype=TRT_DTYPE_TO_TORCH[engine.get_tensor_dtype(tensor_name)],
                                    device="cuda:0")

            if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                trt_inputs[tensor_name] = trt_tensor
            else:
                trt_outputs[tensor_name] = trt_tensor
            trt_context.set_tensor_address(tensor_name, trt_tensor.data_ptr())

    try:
        for i in range(2):
            input_x, k = generate_input(batch_size)
            inputs = (input_x.to(dtype).to(device), k.to(dtype).to(device))
            output_shapes = [torch.Size([1])]

            with torch.no_grad():
                pred_x = model(*inputs)

            print(f"input_x[{inputs[0].shape}]:\n{inputs}")
            print(f"pred_x[{pred_x.shape}]:\n{pred_x}")

            # Run TRT model
            input_data = {"input_x": inputs[0], "k": inputs[1]}
            with torch.cuda.device(device="cuda:0"):
                for key, value in input_data.items():
                    # Copy value to the input buffer
                    trt_inputs[key].copy_(value)
                    trt_context.set_input_shape(key, tuple(value.shape))
                trt_context.execute_async_v3(trt_stream)
                # Wait for TRT execution to finish before reading the output buffer
                cudart.cudaStreamSynchronize(trt_stream)

            trt_pred_x_ = trt_outputs["pred_x"]
            # Flatten the tensors
            trt_pred_x = trt_pred_x_.flatten()[:output_shapes[0].numel()].view(output_shapes[0])
            print(f"trt run pred_x[{trt_pred_x.shape}]:\n{trt_pred_x}")

            # Verify that the torch output matches the TRT output.
            # Default tolerances for bfloat16 are atol=1e-5 and rtol=1.6e-2.
            torch.testing.assert_close(pred_x, trt_pred_x.to(pred_x.device), msg="pred_x and trt_pred_x are not equal")
    finally:
        # Deallocate buffers
        for i in range(engine.num_io_tensors):
            binding = engine[i]
            if binding in trt_inputs.keys():
                del trt_inputs[binding]
            else:
                del trt_outputs[binding]
            print(f"Deleted {binding}")


def main():
    model = load_model()

    # To avoid accuracy issues with float16 between CPU and GPU, use float32 to test the ONNX path
    dtype = torch.float32  # torch.bfloat16 # torch.float16
    device = "cuda"
    model = model.to(dtype).to(device).eval()

    batch_size = 4
    input_x, k = generate_input(batch_size)
    inputs = (input_x.to(dtype).to(device), k.to(dtype).to(device))

    onnx_path="condition_k_mods.onnx"
    compile_onnx_model(model, inputs, onnx_path=onnx_path)
    trt_path = onnx_path[:-4]+ "trt"
    compile_trt_model(onnx_path=onnx_path, trt_path=trt_path)

    # ONNX Runtime verification does not work with torch.bfloat16
    if dtype == torch.float32:
        verify_onnx_model(model, batch_size, onnx_path, dtype, device)

    verify_trt_model(model, batch_size, trt_path, dtype, device, dynamic_shape=True)

    print("All checks pass!")


if __name__ == "__main__":
    main()

Reference

TensorRT diffusion demo code.

Flexible output allocation for dynamic shapes by subclassing the output allocator (see the sketch below).

TensorRT-LLM GitHub repository.
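
For the output-allocator approach referenced above, here is a minimal sketch, assuming the trt.IOutputAllocator interface and IExecutionContext.set_output_allocator exposed by recent TensorRT Python releases; the buffer handling is illustrative, so verify the exact signatures against the installed version:

import tensorrt as trt
import torch


class TorchOutputAllocator(trt.IOutputAllocator):
    """Backs dynamically shaped TRT outputs with torch CUDA buffers."""

    def __init__(self):
        super().__init__()
        self.buffers = {}  # tensor name -> flat uint8 device buffer
        self.shapes = {}   # tensor name -> actual output shape reported by TRT

    def reallocate_output(self, tensor_name, memory, size, alignment):
        # Called when the current buffer is too small; return a device pointer
        # to at least `size` bytes.
        buf = torch.empty(max(int(size), 1), dtype=torch.uint8, device="cuda")
        self.buffers[tensor_name] = buf
        return buf.data_ptr()

    def notify_shape(self, tensor_name, shape):
        # TensorRT reports the real output shape once inference has determined it.
        self.shapes[tensor_name] = tuple(shape)


# Usage sketch: register the allocator instead of pre-allocating max-profile-size output buffers.
# allocator = TorchOutputAllocator()
# trt_context.set_output_allocator("pred_x", allocator)
# trt_context.execute_async_v3(trt_stream)
# actual_shape = allocator.shapes["pred_x"]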