llvm-for-llvmta/examples/OrcV2Examples/LLJITWithOptimizingIRTransform/LLJITWithOptimizingIRTransf...

123 lines
3.6 KiB
C++
Raw Permalink Normal View History

2022-04-25 10:02:23 +02:00
//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// In this example we will use an IR transform to optimize a module as it
// passes through LLJIT's IRTransformLayer.
//
//===----------------------------------------------------------------------===//
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"
#include "../ExampleModules.h"
using namespace llvm;
using namespace llvm::orc;
// Shared error unwrapper for this example: every fallible LLJIT call in main
// is routed through ExitOnErr(...), and main sets its banner before first use.
ExitOnError ExitOnErr{};
// Example IR module.
//
// This IR contains a recursive definition of the factorial function:
//
//   fac(n) | n == 0    = 1
//          | otherwise = n * fac(n - 1)
//
// It also contains an entry function which calls the factorial function with
// an input value of 5.
//
// The optimization transform built below is intended to turn this into a
// non-recursive factorial function and an entry function that returns the
// constant 5!, i.e. 120. Whether this particular pass pipeline folds `entry`
// all the way down to a constant is not checked here -- compare the
// "BEFORE OPTIMIZATION" / "AFTER OPTIMIZATION" dumps printed by the transform
// to see what it actually produces. Either way, entry() must evaluate to 120.
//
// The text is parsed by parseExampleModule in main, so it must remain valid
// textual LLVM IR.
const llvm::StringRef MainMod =
R"(
define i32 @fac(i32 %n) {
entry:
%tobool = icmp eq i32 %n, 0
br i1 %tobool, label %return, label %if.then
if.then: ; preds = %entry
%arg = add nsw i32 %n, -1
%call_result = call i32 @fac(i32 %arg)
%result = mul nsw i32 %n, %call_result
br label %return
return: ; preds = %entry, %if.then
%final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
ret i32 %final_result
}
define i32 @entry() {
entry:
%result = call i32 @fac(i32 5)
ret i32 %result
}
)";
// A function object that creates a simple pass pipeline to apply to each
// module as it passes through the IRTransformLayer.
class MyOptimizationTransform {
public:
MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
PM->add(createTailCallEliminationPass());
PM->add(createFunctionInliningPass());
PM->add(createIndVarSimplifyPass());
PM->add(createCFGSimplificationPass());
}
Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
MaterializationResponsibility &R) {
TSM.withModuleDo([this](Module &M) {
dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
PM->run(M);
dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
});
return std::move(TSM);
}
private:
std::unique_ptr<legacy::PassManager> PM;
};
int main(int argc, char *argv[]) {
// Initialize LLVM.
InitLLVM X(argc, argv);
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
ExitOnErr.setBanner(std::string(argv[0]) + ": ");
// (1) Create LLJIT instance.
auto J = ExitOnErr(LLJITBuilder().create());
// (2) Install transform to optimize modules when they're materialized.
J->getIRTransformLayer().setTransform(MyOptimizationTransform());
// (3) Add modules.
ExitOnErr(J->addIRModule(ExitOnErr(parseExampleModule(MainMod, "MainMod"))));
// (4) Look up the JIT'd function and call it.
auto EntrySym = ExitOnErr(J->lookup("entry"));
auto *Entry = (int (*)())EntrySym.getAddress();
int Result = Entry();
outs() << "--- Result ---\n"
<< "entry() = " << Result << "\n";
return 0;
}