123 lines
3.6 KiB
C++
123 lines
3.6 KiB
C++
|
//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// In this example we will use an IR transform to optimize a module as it
|
||
|
// passes through LLJIT's IRTransformLayer.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"

#include "../ExampleModules.h"
|
||
|
|
||
|
using namespace llvm;
|
||
|
using namespace llvm::orc;
|
||
|
|
||
|
// File-wide error handler. Every llvm::Error / llvm::Expected result in main()
// is unwrapped through ExitOnErr(...); on failure llvm::ExitOnError reports the
// error (prefixed by the banner that main() installs) and exits the process.
ExitOnError ExitOnErr{};
|
||
|
|
||
|
// Example IR module.
|
||
|
//
|
||
|
// This IR contains a recursive definition of the factorial function:
|
||
|
//
|
||
|
// fac(n) | n == 0    = 1
//        | otherwise = n * fac(n - 1)
|
||
|
//
|
||
|
// It also contains an entry function which calls the factorial function with
|
||
|
// an input value of 5.
|
||
|
//
|
||
|
// We expect the IR optimization transform that we build below to transform
|
||
|
// this into a non-recursive factorial function and an entry function that
|
||
|
// returns a constant value of 5!, or 120.
|
||
|
|
||
|
// Textual LLVM IR for the example module. main() passes this string to
// parseExampleModule(MainMod, "MainMod"), so everything between R"( and )"
// must be valid IR text: no stray non-IR lines may appear inside the literal.
const llvm::StringRef MainMod =
    R"(

  define i32 @fac(i32 %n) {
  entry:
    %tobool = icmp eq i32 %n, 0
    br i1 %tobool, label %return, label %if.then

  if.then:                                          ; preds = %entry
    %arg = add nsw i32 %n, -1
    %call_result = call i32 @fac(i32 %arg)
    %result = mul nsw i32 %n, %call_result
    br label %return

  return:                                           ; preds = %entry, %if.then
    %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
    ret i32 %final_result
  }

  define i32 @entry() {
  entry:
    %result = call i32 @fac(i32 5)
    ret i32 %result
  }

)";
|
||
|
|
||
|
// A function object that creates a simple pass pipeline to apply to each
|
||
|
// module as it passes through the IRTransformLayer.
|
||
|
class MyOptimizationTransform {
|
||
|
public:
|
||
|
MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
|
||
|
PM->add(createTailCallEliminationPass());
|
||
|
PM->add(createFunctionInliningPass());
|
||
|
PM->add(createIndVarSimplifyPass());
|
||
|
PM->add(createCFGSimplificationPass());
|
||
|
}
|
||
|
|
||
|
Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
|
||
|
MaterializationResponsibility &R) {
|
||
|
TSM.withModuleDo([this](Module &M) {
|
||
|
dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
|
||
|
PM->run(M);
|
||
|
dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
|
||
|
});
|
||
|
return std::move(TSM);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
std::unique_ptr<legacy::PassManager> PM;
|
||
|
};
|
||
|
|
||
|
int main(int argc, char *argv[]) {
|
||
|
// Initialize LLVM.
|
||
|
InitLLVM X(argc, argv);
|
||
|
|
||
|
InitializeNativeTarget();
|
||
|
InitializeNativeTargetAsmPrinter();
|
||
|
|
||
|
ExitOnErr.setBanner(std::string(argv[0]) + ": ");
|
||
|
|
||
|
// (1) Create LLJIT instance.
|
||
|
auto J = ExitOnErr(LLJITBuilder().create());
|
||
|
|
||
|
// (2) Install transform to optimize modules when they're materialized.
|
||
|
J->getIRTransformLayer().setTransform(MyOptimizationTransform());
|
||
|
|
||
|
// (3) Add modules.
|
||
|
ExitOnErr(J->addIRModule(ExitOnErr(parseExampleModule(MainMod, "MainMod"))));
|
||
|
|
||
|
// (4) Look up the JIT'd function and call it.
|
||
|
auto EntrySym = ExitOnErr(J->lookup("entry"));
|
||
|
auto *Entry = (int (*)())EntrySym.getAddress();
|
||
|
|
||
|
int Result = Entry();
|
||
|
outs() << "--- Result ---\n"
|
||
|
<< "entry() = " << Result << "\n";
|
||
|
|
||
|
return 0;
|
||
|
}
|