
[Bug] Segfault in LLVM COFFOptTable when combining TVM CUDA target with PyTorch Lightning + torchmetrics in the same process #18651

@tinywisdom

Description


Summary

When TVM is used in the same Python process as a PyTorch Lightning model that depends on torchmetrics.Accuracy, the process consistently segfaults inside LLVM initialization code.

The minimal pattern is:

  1. Import TVM and create a CUDA target (which initializes TVM’s LLVM/CUDA stack).
  2. Then import torchmetrics and pytorch_lightning, define a simple LightningModule, and run a single forward pass.
  3. The process crashes with a segmentation fault during shared library loading, with the top of the native backtrace pointing at llvm::opt::OptTable::buildPrefixChars() and COFFOptTable::COFFOptTable().

Without step (1), the same PyTorch Lightning + torchmetrics code runs normally on this environment.

Environment

From the script output:

  • OS: Linux x86_64 (glibc-based, from backtrace paths such as ./elf/dl-open.c)

  • Python: 3.10.16 | packaged by conda-forge | (main, Apr 8 2025, 20:53:32) [GCC 13.3.0]

  • NumPy: 2.2.6

  • PyTorch: 2.9.0+cu128

  • TVM:

    • Version: 0.22.0
    • LLVM version (reported by tvm.support.libinfo()): 17.0.6
    • GIT_COMMIT_HASH: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
  • GPU / CUDA:

    • TVM target: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
    • CUDA toolkit likely 12.8 (from PyTorch build tag +cu128)

The exact torchmetrics and pytorch_lightning versions are not listed. Both were installed as standard pip/conda packages compatible with the PyTorch release above.
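The exact torchmetrics and pytorch_lightning versions matter for triage. This small sketch reads them from installed package metadata without importing the packages, so it cannot hit the crash itself. The distribution names in PACKAGES are assumptions. Lightning is also published as "lightning", so both names are queried.

```python
from importlib import metadata

# Distribution names as published on PyPI (assumption: either Lightning
# distribution may be the one installed, so both are queried).
PACKAGES = ("torch", "torchmetrics", "pytorch-lightning", "lightning")


def installed_versions(names=PACKAGES):
    """Return {distribution: version or 'not installed'} without importing it."""
    versions = {}
    for name in names:
        try:
            # Reads *.dist-info metadata only; no package code is executed.
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions
```

For example, `print(installed_versions())` produces a dict that can be pasted into the Environment list.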

Minimal Reproduction

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Minimal reproducer for a segfault when using TVM together with
PyTorch Lightning + torchmetrics (issue 70426 model).

Reproduction pattern:
  1) Import TVM and create a CUDA target (so LLVM/CUDA libraries are loaded).
  2) Import torchmetrics / pytorch_lightning.
  3) Define the following LightningModule (MyModel) and run a forward pass
     with input torch.rand(B, 1, 28, 28, dtype=torch.float32).

On my environment this script crashes with a segmentation fault
(before printing the final success message).
"""

import sys
import numpy as np
import torch

import tvm
from tvm import relax, tir  # keep the same style as in my TVM-based tools


def print_env_info():
    print("==== Environment ====")
    print("Python:", sys.version)
    print("NumPy version:", np.__version__)
    print("Torch version:", torch.__version__)
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    try:
        from tvm import support
        info = support.libinfo()
        print("TVM LLVM version:", info.get("LLVM_VERSION", "unknown"))
        print("TVM GIT_COMMIT_HASH:", info.get("GIT_COMMIT_HASH", "unknown"))
    except Exception as e:
        print("TVM libinfo not available:", repr(e))
    print("=====================\n")


def main():
    print_env_info()

    # 1) Force TVM to initialize CUDA / LLVM stack in this process
    print("[REPRO] Creating TVM cuda target to load LLVM/CUDA libraries ...")
    try:
        target = tvm.target.Target("cuda")
        print("[REPRO] TVM Target:", target)
    except Exception as e:
        print("[REPRO] Failed to create cuda target:", repr(e))
        # Even if this fails, we still continue to import the model stack.

    # 2) Now import torchmetrics / pytorch_lightning and define the model
    print("[REPRO] Importing torchmetrics / pytorch_lightning and defining MyModel ...")
    from torch import nn
    from torchmetrics import Accuracy
    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Linear(28 * 28, 64),
                nn.ReLU(),
                nn.Linear(64, 3),
            )
            self.decoder = nn.Sequential(
                nn.Linear(3, 64),
                nn.ReLU(),
                nn.Linear(64, 28 * 28),
            )

        def forward(self, x):
            # original forward from issue 70426: return embedding
            embedding = self.encoder(x.view(x.size(0), -1))
            return embedding

        def training_step(self, batch, batch_idx):
            # problematic code from the original Lightning training snippet
            device = self.device
            num_samples = 1000
            num_classes = 34
            Y = torch.ones(num_samples, dtype=torch.long, device=device)
            X = torch.zeros(num_samples, num_classes, device=device)
            # torchmetrics >= 0.11 requires `task`; this step is never run by the repro
            accuracy = Accuracy(task="multiclass", average="none", num_classes=num_classes).to(device)
            accuracy(X, Y)  # triggers computation during step

            # Original autoencoder training logic
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            loss = nn.MSELoss()(x_hat, x)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    def GetInput():
        # same input shape as in the original issue:
        # torch.rand(B, 1, 28, 28, dtype=torch.float32)
        return torch.rand(32, 1, 28, 28, dtype=torch.float32)

    # 3) Instantiate the model and run a simple forward pass
    print("[REPRO] Instantiating MyModel and running a forward pass ...")
    model = MyModel()
    x = GetInput()
    with torch.no_grad():
        out = model(x)
    print("[REPRO] Forward output shape:", tuple(out.shape))

    print("[REPRO] Script finished without segfault.")


if __name__ == "__main__":
    main()

On my machine, the script crashes with a segmentation fault right after:

[REPRO] Importing torchmetrics / pytorch_lightning and defining MyModel ...

The final [REPRO] Script finished without segfault. line is never printed.
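The crash is reported right after the import message, so the model definition and forward pass are apparently never reached. The sketch below is a hypothetical triage harness and is not part of the original report. It runs each import order in a fresh interpreter, so a segfault only kills the child process. It then reports how each child ended. The variant names and snippets in VARIANTS are illustrative. On POSIX systems, subprocess reports a child killed by signal N as return code -N.

```python
import signal
import subprocess
import sys

# Hypothetical variants; each one runs in its own interpreter.
VARIANTS = {
    "no_tvm": "import torchmetrics, pytorch_lightning",
    "tvm_target_then_torchmetrics": (
        "import tvm; tvm.target.Target('cuda'); import torchmetrics"
    ),
    "tvm_target_then_lightning": (
        "import tvm; tvm.target.Target('cuda'); import pytorch_lightning"
    ),
    "lightning_first_then_tvm": (
        "import torchmetrics, pytorch_lightning; "
        "import tvm; tvm.target.Target('cuda')"
    ),
}


def classify(returncode):
    """Map a child's return code to a short label (POSIX semantics)."""
    if returncode == 0:
        return "ok"
    if returncode < 0:
        # Death by signal N is reported as -N.
        try:
            return f"killed by {signal.Signals(-returncode).name}"
        except ValueError:
            return f"killed by signal {-returncode}"
    return f"exited with status {returncode}"


def run_variant(code, timeout=300.0):
    """Run one snippet in a fresh interpreter and classify how it ended."""
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return classify(proc.returncode)


def main():
    for name, code in VARIANTS.items():
        print(f"{name:32s} {run_variant(code)}")
```

Calling main() prints one line per variant. A "killed by SIGSEGV" result only for the TVM-first variants would support the ordering described in the Summary.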

Actual Behavior

Console output (truncated):

==== Environment ====
Python: 3.10.16 | packaged by conda-forge | (main, Apr  8 2025, 20:53:32) [GCC 13.3.0]
NumPy version: 2.2.6
Torch version: 2.9.0+cu128
TVM version: 0.22.0
TVM LLVM version: 17.0.6
TVM GIT_COMMIT_HASH: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
=====================

[REPRO] Creating TVM cuda target to load LLVM/CUDA libraries ...
[REPRO] TVM Target: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
[REPRO] Importing torchmetrics / pytorch_lightning and defining MyModel ...
!!!!!!! Segfault encountered !!!!!!!
  File "./signal/../sysdeps/unix/sysv/linux/x86_64/libc_sigaction.c", line 0, in 0x00007998d2c4251f
  File "<unknown>", line 0, in llvm::opt::OptTable::buildPrefixChars()
  File "<unknown>", line 0, in COFFOptTable::COFFOptTable()
  File "<unknown>", line 0, in _GLOBAL__sub_I_COFFDirectiveParser.cpp
  File "./elf/dl-init.c", line 70, in call_init
  File "./elf/dl-init.c", line 33, in call_init
  File "./elf/dl-init.c", line 117, in _dl_init
  File "./elf/dl-error-skeleton.c", line 182, in __GI__dl_catch_exception
  File "./elf/dl-open.c", line 808, in dl_open_worker
  File "./elf/dl-open.c", line 771, in dl_open_worker
  File "./elf/dl-error-skeleton.c", line 208, in __GI__dl_catch_exception
  File "./elf/dl-open.c", line 883, in _dl_open
  File "./dlfcn/dlopen.c", line 56, in dlopen_doit
  File "./elf/dl-error-skeleton.c", line 208, in __GI__dl_catch_exception
  File "./elf/dl-error-skeleton.c", line 227, in __GI__dl_catch_error
  File "./dlfcn/dlerror.c", line 138, in _dlerror_run
  File "./dlfcn/dlopen.c", line 71, in dlopen_implementation
  File "./dlfcn/dlopen.c", line 81, in ___dlopen
  File "/usr/local/src/conda/python-3.10.16/Python/dynload_shlib.c", line 100, in _PyImport_FindSharedFuncptr
  File "/usr/local/src/conda/python-3.10.16/Python/importdl.c", line 137, in _PyImport_LoadDynamicModuleWithSpec
  ...

Segmentation fault (core dumped)
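The native backtrace is cut off at "...", so it does not show which Python module was being imported when dlopen crashed. One way to narrow this down is to install a logging-only meta path finder before step 2. This is a suggestion and does not come from the report. The last "[import]" line on stderr before the segfault should name the extension module (or one of its parents) being loaded. Only modules that are not already in sys.modules get logged.

```python
import sys


class ImportLogger:
    """Meta path finder that only logs lookups and never claims a module."""

    def find_spec(self, fullname, path=None, target=None):
        # Flush immediately so the line survives a crash in the next dlopen.
        print(f"[import] {fullname}", file=sys.stderr, flush=True)
        return None  # defer to the regular finders


def install_import_logger():
    """Put a single ImportLogger at the front of sys.meta_path."""
    if not any(isinstance(f, ImportLogger) for f in sys.meta_path):
        sys.meta_path.insert(0, ImportLogger())
```

For example, call install_import_logger() in main() right after the CUDA target is created, then run the script with stderr visible.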

The key part of the native backtrace is the LLVM initialization:

  in llvm::opt::OptTable::buildPrefixChars()
  in COFFOptTable::COFFOptTable()
  in _GLOBAL__sub_I_COFFDirectiveParser.cpp

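The segfault happens while dlopen runs the static initializers of a shared library that Python is loading as an extension module. The frames go from _PyImport_LoadDynamicModuleWithSpec through dlopen to _dl_init and call_init. At that point, TVM has already created a CUDA Target.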
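One hypothesis fits this backtrace, although the report does not confirm it. The newly loaded library may carry its own copy of LLVM, and its static initializers may end up using symbols from the LLVM that TVM already loaded (17.0.6). The sketch below checks which LLVM-related libraries are mapped into the process by reading /proc/self/maps, which exists only on Linux. The name fragments in SUSPECT_PATTERNS are guesses. A statically linked LLVM will not show up by name at all.

```python
# Name fragments that *might* indicate a library carrying LLVM code.
# These are guesses (assumption); adjust them for the installed packages.
SUSPECT_PATTERNS = ("llvm", "libtvm", "triton")


def loaded_libraries(maps_path="/proc/self/maps"):
    """Return sorted, de-duplicated file-backed mappings (Linux-only)."""
    libs = set()
    with open(maps_path) as fh:
        for line in fh:
            # Columns: address perms offset dev inode [pathname]
            parts = line.split(maxsplit=5)
            if len(parts) == 6 and parts[5].startswith("/"):
                libs.add(parts[5].rstrip("\n"))
    return sorted(libs)


def suspect_libraries(patterns=SUSPECT_PATTERNS, maps_path="/proc/self/maps"):
    """Return mapped libraries whose path contains any pattern (case-insensitive)."""
    lowered = [p.lower() for p in patterns]
    return [
        lib
        for lib in loaded_libraries(maps_path)
        if any(p in lib.lower() for p in lowered)
    ]
```

Call suspect_libraries() right after tvm.target.Target("cuda"). Then call it again in a separate process that imports only torchmetrics and pytorch_lightning. Comparing the two lists may point at a second LLVM copy.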

Triage


  • needs-triage
  • bug

Metadata


Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
