@@ -19,6 +19,8 @@ const std::string GPU_DL =
1919llvm::cl::opt<std::string>
2020 libdevice (" libdevice" , llvm::cl::desc(" libdevice path for GPU kernels" ),
2121 llvm::cl::init (" /usr/local/cuda/nvvm/libdevice/libdevice.10.bc" ));
22+ llvm::cl::opt<std::string> ptxOutput (" ptx" ,
23+ llvm::cl::desc (" Output PTX to specified file" ));
2224
2325// Adapted from LLVM's GVExtractorPass, which is not externally available
2426// as a pass for the new pass manager.
@@ -684,10 +686,9 @@ getRequiredGVs(const std::vector<llvm::GlobalValue *> &kernels) {
684686 return std::vector<llvm::GlobalValue *>(keep.begin (), keep.end ());
685687}
686688
687- void moduleToPTX (llvm::Module *M, const std::string &filename,
688- std::vector<llvm::GlobalValue *> &kernels,
689- const std::string &cpuStr = " sm_30" ,
690- const std::string &featuresStr = " +ptx42" ) {
689+ std::string moduleToPTX (llvm::Module *M, std::vector<llvm::GlobalValue *> &kernels,
690+ const std::string &cpuStr = " sm_30" ,
691+ const std::string &featuresStr = " +ptx42" ) {
691692 llvm::Triple triple (llvm::Triple::normalize (GPU_TRIPLE));
692693 llvm::TargetLibraryInfoImpl tlii (triple);
693694
@@ -792,56 +793,25 @@ void moduleToPTX(llvm::Module *M, const std::string &filename,
792793 }
793794 }
794795
795- // Generate PTX file .
796+ // Generate PTX code .
796797 {
797- std::error_code errcode;
798- auto out = std::make_unique<llvm::ToolOutputFile>(filename, errcode,
799- llvm::sys::fs::OF_Text);
800- if (errcode)
801- compilationError (errcode.message ());
802- llvm::raw_pwrite_stream *os = &out->os ();
798+ llvm::SmallVector<char , 1024 > ptx;
799+ llvm::raw_svector_ostream os (ptx);
803800
804801 auto *mmiwp = new llvm::MachineModuleInfoWrapperPass (machine.get ());
805802 llvm::legacy::PassManager pm;
806803
807804 pm.add (new llvm::TargetLibraryInfoWrapperPass (tlii));
808- seqassertn (! machine->addPassesToEmitFile (pm, * os, nullptr ,
805+ bool fail = machine->addPassesToEmitFile (pm, os, nullptr ,
809806 llvm::CodeGenFileType::AssemblyFile,
810- /* DisableVerify=*/ false , mmiwp),
811- " could not add passes" );
807+ /* DisableVerify=*/ false , mmiwp);
808+ seqassertn (!fail, " could not add passes" );
809+
812810 const_cast <llvm::TargetLoweringObjectFile *>(machine->getObjFileLowering ())
813811 ->Initialize (mmiwp->getMMI ().getContext (), *machine);
814- pm.run (*M);
815- out->keep ();
816- }
817- }
818-
819- void addInitCall (llvm::Module *M, const std::string &filename) {
820- llvm::LLVMContext &context = M->getContext ();
821- llvm::IRBuilder<> B (context);
822- auto f = M->getOrInsertFunction (" seq_nvptx_load_module" , B.getVoidTy (), B.getPtrTy ());
823- auto *g = llvm::cast<llvm::Function>(f.getCallee ());
824- g->setDoesNotThrow ();
825-
826- auto *filenameVar = new llvm::GlobalVariable (
827- *M, llvm::ArrayType::get (llvm::Type::getInt8Ty (context), filename.length () + 1 ),
828- /* isConstant=*/ true , llvm::GlobalValue::PrivateLinkage,
829- llvm::ConstantDataArray::getString (context, filename), " .nvptx.filename" );
830- filenameVar->setUnnamedAddr (llvm::GlobalValue::UnnamedAddr::Global);
831-
832- if (auto *init = M->getFunction (" seq_init" )) {
833- seqassertn (init->hasOneUse (), " seq_init used more than once" );
834- auto *use = llvm::dyn_cast<llvm::CallBase>(init->use_begin ()->getUser ());
835- seqassertn (use, " seq_init use was not a call" );
836- B.SetInsertPoint (use->getNextNode ());
837- B.CreateCall (g, B.CreateBitCast (filenameVar, B.getPtrTy ()));
838- }
839812
840- for (auto &F : M->functions ()) {
841- if (F.hasFnAttribute (" jit" )) {
842- B.SetInsertPoint (F.getEntryBlock ().getFirstNonPHI ());
843- B.CreateCall (g, B.CreateBitCast (filenameVar, B.getPtrTy ()));
844- }
813+ pm.run (*M);
814+ return std::string (ptx.data (), ptx.size ());
845815 }
846816}
847817
@@ -894,16 +864,58 @@ void applyGPUTransformations(llvm::Module *M, const std::string &ptxFilename) {
894864 if (kernels.empty ())
895865 return ;
896866
897- std::string filename = ptxFilename.empty () ? M->getSourceFileName () : ptxFilename;
898- if (filename.empty () || filename[0 ] == ' <' )
899- filename = " kernel" ;
900- llvm::SmallString<128 > path (filename);
901- llvm::sys::path::replace_extension (path, " ptx" );
902- filename = path.str ();
903-
904- moduleToPTX (clone.get (), filename, kernels);
867+ auto ptx = moduleToPTX (clone.get (), kernels);
905868 cleanUpIntrinsics (M);
906- addInitCall (M, filename);
869+
870+ if (ptxOutput.getNumOccurrences () > 0 ) {
871+ std::error_code err;
872+ llvm::ToolOutputFile out (ptxOutput, err, llvm::sys::fs::OF_Text);
873+ seqassertn (!err, " Could not open file: {}" , err.message ());
874+ llvm::raw_ostream &os = out.os ();
875+ os << ptx;
876+ os.flush ();
877+ out.keep ();
878+ }
879+
880+ // Add ptx code as a global var
881+ auto *ptxVar = new llvm::GlobalVariable (
882+ *M, llvm::ArrayType::get (llvm::Type::getInt8Ty (context), ptx.length () + 1 ),
883+ /* isConstant=*/ true , llvm::GlobalValue::PrivateLinkage,
884+ llvm::ConstantDataArray::getString (context, ptx), " .ptx" );
885+
886+ ptxVar->setUnnamedAddr (llvm::GlobalValue::UnnamedAddr::Global);
887+
888+ // Find and patch direct calls to cuModuleLoadData()
889+ const std::string ptxTarget = " __codon_ptx__" ; // must match gpu.codon name
890+ llvm::SmallVector<llvm::Instruction *, 1 > callsToReplace;
891+ for (auto &F : *M) {
892+ for (auto &BB : F) {
893+ for (auto &I : BB) {
894+ auto *call = llvm::dyn_cast<llvm::CallBase>(&I);
895+ if (!call)
896+ continue ;
897+
898+ auto *callee = call->getCalledFunction ();
899+ if (!callee)
900+ continue ;
901+
902+ if (callee->getName () == ptxTarget && call->arg_size () == 0 )
903+ callsToReplace.push_back (call);
904+ }
905+ }
906+ }
907+
908+ for (auto *call : callsToReplace) {
909+ call->replaceAllUsesWith (ptxVar);
910+ call->dropAllReferences ();
911+ call->eraseFromParent ();
912+ }
913+
914+ // Delete __codon_ptx__() stub
915+ if (auto *F = M->getFunction (ptxTarget)) {
916+ seqassertn (F->use_empty (), " some __codon_ptx__() calls not replaced in module" );
917+ F->eraseFromParent ();
918+ }
907919}
908920
909921} // namespace ir
0 commit comments