Training and Inference performance

Practice

Git repos

TorchTitan: a native PyTorch library for large model training.

Torchao: PyTorch native quantization and sparsity for training and inference.

Diffusers-torchao: End-to-end recipes for optimizing diffusion models with torchao and diffusers (inference and FP8 training).

Gloo: Collective communications library with various primitives for multi-machine training.

TensorRT open source: NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs.
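To make the quantization bullet above concrete, here is a minimal sketch using PyTorch's built-in dynamic quantization (`torch.ao.quantization`, used here as a stand-in for torchao's own APIs, which offer more quantization schemes):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

# Dynamic quantization: Linear weights are stored as int8, and
# activations are quantized on the fly at inference time.
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

out = qmodel(torch.randn(4, 128))
```

This is inference-only; for quantized or FP8 training, see the torchao and diffusers-torchao recipes above.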

PyTorch CUDA best practices

Best practices.
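A few of those best practices in a short sketch (assumed illustrative, not exhaustive; it falls back to CPU when no GPU is present):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(1024, 1024)
if device == "cuda":
    # Pinned (page-locked) host memory allows truly asynchronous
    # host-to-device copies.
    x = x.pin_memory()

# non_blocking=True only overlaps the copy with compute when the
# source tensor is pinned.
x_dev = x.to(device, non_blocking=True)

# Prefer one large transfer over many small ones, and construct
# tensors directly on the target device when possible:
y = torch.zeros(1024, device=device)
```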

Profiling tools

PyTorch profiler

Profiling to understand torch.compile performance. Example:

import torch
from torchvision.models import resnet18

model = resnet18().cuda()
inputs = [torch.randn((5, 3, 224, 224), device='cuda') for _ in range(10)]

model_c = torch.compile(model)

def fwd_bwd(inp):
    out = model_c(inp)
    out.sum().backward()

# Compile and run a trivial function first so that one-time warmup costs
# (e.g. starting compile workers) are not attributed to the resnet18 compile.
def warmup_compile():
    def fn(x):
        return x.sin().relu()

    x = torch.rand((2, 2), device='cuda', requires_grad=True)
    fn_c = torch.compile(fn)
    out = fn_c(x)
    out.sum().backward()

with torch.profiler.profile() as prof:
    with torch.profiler.record_function("warmup compile"):
        warmup_compile()

    with torch.profiler.record_function("resnet18 compile"):
        fwd_bwd(inputs[0])

prof.export_chrome_trace("trace_compile.json")

Why am I not seeing speedups? Graph breaks

Identify the cause of graph breaks:

import torch
import torch._dynamo as dynamo
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
    User Stack:
      <FrameSummary file foo.py, line 5 in toy_example>
  Break Reason 2:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
  ...
Out Guards:
  ...
"""

Google trace viewer

Perfetto, the successor to the now-deprecated "chrome://tracing" viewer.

Quickstart: SQL-based analysis and trace-based metrics.

TensorBoard

PyTorch Profiler With TensorBoard
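The profiler can hand finished traces directly to TensorBoard via `tensorboard_trace_handler`. A minimal sketch (the log directory `./log/profiler` is an arbitrary choice; viewing requires the torch-tb-profiler plugin):

```python
import torch
from torch.profiler import (
    ProfilerActivity, profile, schedule, tensorboard_trace_handler,
)

model = torch.nn.Linear(64, 64)
inp = torch.randn(8, 64)

# Skip 1 step, warm up for 1 step, then record 2 active steps and hand
# the finished trace to the TensorBoard handler.
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
) as prof:
    for _ in range(4):
        model(inp)
        prof.step()  # tells the profiler a step finished
```

Then launch `tensorboard --logdir ./log/profiler` to browse the trace.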

Reference

TensorRT-LLM GitHub repository.