5

使用 LLVM 实现一个简单编译器(一)

 3 years ago
source link: https://zhuanlan.zhihu.com/p/407854583
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

使用 LLVM 实现一个简单编译器(一)

已认证的官方帐号

作者:tomoyazhang,腾讯 PCG 后台开发工程师

1. 目标

这个系列来自 LLVM 的Kaleidoscope 教程,增加了我对代码的注释以及一些理解,修改了部分代码。现在开始我们要使用 LLVM 实现一个编译器,完成对如下代码的编译运行。

# 斐波那契数列函数定义
def fib(x)
    if x < 3 then
        1
    else
        fib(x - 1) + fib(x - 2)

fib(40)

# 函数声明
extern sin(arg)
extern cos(arg)
extern atan2(arg1 arg2)

# 声明后的函数可调用
atan2(sin(.4), cos(42))

这个语言称为 Kaleidoscope, 从代码可以看出,Kaleidoscope 支持函数、条件分支、数值计算等语言特性。为了方便,Kaleidoscope 唯一支持的数据类型为 float64, 所以示例中的所有数值都是 float64。

2. Lex

编译的第一个步骤称为 Lex, 词法分析,其功能是将文本输入转为多个 tokens, 比如对于如下代码:

atan2(sin(.4), cos(42))

就应该转为:

tokens = ["atan2", "(", "sin", "(", .4, ")", ",", "cos", "(", 42, ")", ")"]

接下来我们使用 C++来写这个 Lexer, 由于这是教程代码,所以并没有使用工程项目应有的设计:

// 如果不是以下5种情况,Lexer返回[0-255]的ASCII值,否则返回以下枚举值
enum Token {
  TOKEN_EOF = -1,         // 文件结束标识符
  TOKEN_DEF = -2,         // 关键字def
  TOKEN_EXTERN = -3,      // 关键字extern
  TOKEN_IDENTIFIER = -4,  // 名字
  TOKEN_NUMBER = -5       // 数值
};

std::string g_identifier_str;  // Filled in if TOKEN_IDENTIFIER
double g_number_val;           // Filled in if TOKEN_NUMBER

// 从标准输入解析一个Token并返回
int GetToken() {
  static int last_char = ' ';
  // 忽略空白字符
  while (isspace(last_char)) {
    last_char = getchar();
  }
  // 识别字符串
  if (isalpha(last_char)) {
    g_identifier_str = last_char;
    while (isalnum((last_char = getchar()))) {
      g_identifier_str += last_char;
    }
    if (g_identifier_str == "def") {
      return TOKEN_DEF;
    } else if (g_identifier_str == "extern") {
      return TOKEN_EXTERN;
    } else {
      return TOKEN_IDENTIFIER;
    }
  }
  // 识别数值
  if (isdigit(last_char) || last_char == '.') {
    std::string num_str;
    do {
      num_str += last_char;
      last_char = getchar();
    } while (isdigit(last_char) || last_char == '.');
    g_number_val = strtod(num_str.c_str(), nullptr);
    return TOKEN_NUMBER;
  }
  // 忽略注释
  if (last_char == '#') {
    do {
      last_char = getchar();
    } while (last_char != EOF && last_char != '\n' && last_char != '\r');
    if (last_char != EOF) {
      return GetToken();
    }
  }
  // 识别文件结束
  if (last_char == EOF) {
    return TOKEN_EOF;
  }
  // 直接返回ASCII
  int this_char = last_char;
  last_char = getchar();
  return this_char;
}

使用 Lexer 对之前的代码处理结果为(使用空格分隔 tokens):

def fib ( x ) if x < 3 then 1 else fib ( x - 1 ) + fib ( x - 2 ) fib ( 40 ) extern sin ( arg )
extern cos ( arg ) extern atan2 ( arg1 arg2 ) atan2 ( sin ( 0.4 ) , cos ( 42 ) )

Lexer 的输入是代码文本,输出是有序的一个个 Token。

3. Parser

编译的第二个步骤称为 Parse, 其功能是将 Lexer 输出的 tokens 转为 AST (Abstract Syntax Tree)。我们首先定义表达式的 AST Node:

// 所有 `表达式` 节点的基类
class ExprAST {
 public:
  virtual ~ExprAST() {}
};

// 字面值表达式
class NumberExprAST : public ExprAST {
 public:
  NumberExprAST(double val) : val_(val) {}

 private:
  double val_;
};

// 变量表达式
class VariableExprAST : public ExprAST {
 public:
  VariableExprAST(const std::string& name) : name_(name) {}

 private:
  std::string name_;
};

// 二元操作表达式
class BinaryExprAST : public ExprAST {
 public:
  BinaryExprAST(char op, std::unique_ptr<ExprAST> lhs,
                std::unique_ptr<ExprAST> rhs)
      : op_(op), lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}

 private:
  char op_;
  std::unique_ptr<ExprAST> lhs_;
  std::unique_ptr<ExprAST> rhs_;
};

// 函数调用表达式
class CallExprAST : public ExprAST {
 public:
  CallExprAST(const std::string& callee,
              std::vector<std::unique_ptr<ExprAST>> args)
      : callee_(callee), args_(std::move(args)) {}

 private:
  std::string callee_;
  std::vector<std::unique_ptr<ExprAST>> args_;
};

为了便于理解,关于条件表达式的内容放在后面,这里暂不考虑。接着我们定义函数声明和函数的 AST Node:

// 函数接口
class PrototypeAST {
 public:
  PrototypeAST(const std::string& name, std::vector<std::string> args)
      : name_(name), args_(std::move(args)) {}

  const std::string& name() const { return name_; }

 private:
  std::string name_;
  std::vector<std::string> args_;
};

// 函数
class FunctionAST {
 public:
  FunctionAST(std::unique_ptr<PrototypeAST> proto,
              std::unique_ptr<ExprAST> body)
      : proto_(std::move(proto)), body_(std::move(body)) {}

 private:
  std::unique_ptr<PrototypeAST> proto_;
  std::unique_ptr<ExprAST> body_;
};

接下来我们要进行 Parse, 在正式 Parse 前,定义如下函数方便后续处理:

int g_current_token;  // 当前待处理的Token
int GetNextToken() {
  return g_current_token = GetToken();
}

首先我们处理最简单的字面值:

// numberexpr ::= number
std::unique_ptr<ExprAST> ParseNumberExpr() {
  auto result = std::make_unique<NumberExprAST>(g_number_val);
  GetNextToken();
  return std::move(result);
}

这段程序非常简单,当前 Token 为 TOKEN_NUMBER 时被调用,使用 g_number_val,创建一个 NumberExprAST, 因为当前 Token 处理完毕,让 Lexer 前进一个 Token, 最后返回。接着我们处理圆括号操作符、变量、函数调用:

// parenexpr ::= ( expression )
std::unique_ptr<ExprAST> ParseParenExpr() {
  GetNextToken();  // eat (
  auto expr = ParseExpression();
  GetNextToken();  // eat )
  return expr;
}

/// identifierexpr
///   ::= identifier
///   ::= identifier ( expression, expression, ..., expression )
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
  std::string id = g_identifier_str;
  GetNextToken();
  if (g_current_token != '(') {
    return std::make_unique<VariableExprAST>(id);
  } else {
    GetNextToken();  // eat (
    std::vector<std::unique_ptr<ExprAST>> args;
    while (g_current_token != ')') {
      args.push_back(ParseExpression());
      if (g_current_token == ')') {
        break;
      } else {
        GetNextToken();  // eat ,
      }
    }
    GetNextToken();  // eat )
    return std::make_unique<CallExprAST>(id, std::move(args));
  }
}

上面代码中的 ParseExpression 与 ParseParenExpr 等存在循环依赖,这里按照其名字理解意思即可,具体实现在后面。我们将 NumberExpr、ParenExpr、IdentifierExpr 视为 PrimaryExpr, 封装 ParsePrimary 方便后续调用:

/// primary
///   ::= identifierexpr
///   ::= numberexpr
///   ::= parenexpr
std::unique_ptr<ExprAST> ParsePrimary() {
  switch (g_current_token) {
    case TOKEN_IDENTIFIER: return ParseIdentifierExpr();
    case TOKEN_NUMBER: return ParseNumberExpr();
    case '(': return ParseParenExpr();
    default: return nullptr;
  }
}

接下来我们考虑如何处理二元操作符,为了方便,Kaleidoscope 只支持 4 种二元操作符,优先级为:

'<' < '+' = '-' < '*'

即'<'的优先级最低,而'*'的优先级最高,在代码中实现为:

// 定义优先级
const std::map<char, int> g_binop_precedence = {
    {'<', 10}, {'+', 20}, {'-', 20}, {'*', 40}};

// 获得当前Token的优先级
int GetTokenPrecedence() {
  auto it = g_binop_precedence.find(g_current_token);
  if (it != g_binop_precedence.end()) {
    return it->second;
  } else {
    return -1;
  }
}

对于带优先级的二元操作符的解析,我们会将其分成多个片段。比如一个表达式:

a + b + (c + d) * e * f + g

首先解析 a, 然后处理多个二元组:

[+, b], [+, (c+d)], [*, e], [*, f], [+, g]

即,复杂表达式可以抽象为一个 PrimaryExpr 跟着多个[binop, PrimaryExpr]二元组,注意由于圆括号属于 PrimaryExpr, 所以这里不需要考虑怎么特殊处理(c+d),因为会被 ParsePrimary 自动处理。

// parse
//   lhs [binop primary] [binop primary] ...
// 如遇到优先级小于min_precedence的操作符,则停止
std::unique_ptr<ExprAST> ParseBinOpRhs(int min_precedence,
                                       std::unique_ptr<ExprAST> lhs) {
  while (true) {
    int current_precedence = GetTokenPrecedence();
    if (current_precedence < min_precedence) {
      // 如果当前token不是二元操作符,current_precedence为-1, 结束任务
      // 如果遇到优先级更低的操作符,也结束任务
      return lhs;
    }
    int binop = g_current_token;
    GetNextToken();  // eat binop
    auto rhs = ParsePrimary();
    // 现在我们有两种可能的解析方式
    //    * (lhs binop rhs) binop unparsed
    //    * lhs binop (rhs binop unparsed)
    int next_precedence = GetTokenPrecedence();
    if (current_precedence < next_precedence) {
      // 将高于current_precedence的右边的操作符处理掉返回
      rhs = ParseBinOpRhs(current_precedence + 1, std::move(rhs));
    }
    lhs =
        std::make_unique<BinaryExprAST>(binop, std::move(lhs), std::move(rhs));
    // 继续循环
  }
}

// expression
//   ::= primary [binop primary] [binop primary] ...
std::unique_ptr<ExprAST> ParseExpression() {
  auto lhs = ParsePrimary();
  return ParseBinOpRhs(0, std::move(lhs));
}

最复杂的部分完成后,按部就班把 function 写完:

// prototype
//   ::= id ( id id ... id)
std::unique_ptr<PrototypeAST> ParsePrototype() {
  std::string function_name = g_identifier_str;
  GetNextToken();
  std::vector<std::string> arg_names;
  while (GetNextToken() == TOKEN_IDENTIFIER) {
    arg_names.push_back(g_identifier_str);
  }
  GetNextToken();  // eat )
  return std::make_unique<PrototypeAST>(function_name, std::move(arg_names));
}

// definition ::= def prototype expression
std::unique_ptr<FunctionAST> ParseDefinition() {
  GetNextToken();  // eat def
  auto proto = ParsePrototype();
  auto expr = ParseExpression();
  return std::make_unique<FunctionAST>(std::move(proto), std::move(expr));
}

// external ::= extern prototype
std::unique_ptr<PrototypeAST> ParseExtern() {
  GetNextToken();  // eat extern
  return ParsePrototype();
}

最后,我们为顶层的代码实现匿名 function:

// toplevelexpr ::= expression
std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
  auto expr = ParseExpression();
  auto proto = std::make_unique<PrototypeAST>("", std::vector<std::string>());
  return std::make_unique<FunctionAST>(std::move(proto), std::move(expr));
}

顶层代码的意思是放在全局而不放在 function 内定义的一些执行语句比如变量赋值,函数调用等。编写一个 main 函数:

int main() {
  GetNextToken();
  while (true) {
    switch (g_current_token) {
      case TOKEN_EOF: return 0;
      case TOKEN_DEF: {
        ParseDefinition();
        std::cout << "parsed a function definition" << std::endl;
        break;
      }
      case TOKEN_EXTERN: {
        ParseExtern();
        std::cout << "parsed a extern" << std::endl;
        break;
      }
      default: {
        ParseTopLevelExpr();
        std::cout << "parsed a top level expr" << std::endl;
        break;
      }
    }
  }
  return 0;
}
clang++ main.cpp `llvm-config --cxxflags --ldflags --libs`

输入如下代码进行测试:

def foo(x y)
    x + foo(y, 4)

def foo(x y)
    x + y

y

extern sin(a)

得到输出:

parsed a function definition
parsed a function definition
parsed a top level expr
parsed a extern

至此成功将 Lexer 输出的 tokens 转为 AST。

4. Code Generation to LLVM IR

终于开始 codegen 了,首先我们 include 一些 LLVM 头文件,定义一些全局变量:

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"

// 记录了LLVM的核心数据结构,比如类型和常量表,不过我们不太需要关心它的内部
llvm::LLVMContext g_llvm_context;
// 用于创建LLVM指令
llvm::IRBuilder<> g_ir_builder(g_llvm_context);
// 用于管理函数和全局变量,可以粗浅地理解为类c++的编译单元(单个cpp文件)
llvm::Module g_module("my cool jit", g_llvm_context);
// 用于记录函数的变量参数
std::map<std::string, llvm::Value*> g_named_values;

然后给每个 AST Class 增加一个 CodeGen 接口:

// 所有 `表达式` 节点的基类
class ExprAST {
 public:
  virtual ~ExprAST() {}
  virtual llvm::Value* CodeGen() = 0;
};

// 字面值表达式
class NumberExprAST : public ExprAST {
 public:
  NumberExprAST(double val) : val_(val) {}
  llvm::Value* CodeGen() override;

 private:
  double val_;
};

首先实现 NumberExprAST 的 CodeGen:

llvm::Value* NumberExprAST::CodeGen() {
  return llvm::ConstantFP::get(g_llvm_context, llvm::APFloat(val_));
}

由于 Kaleidoscope 只有一种数据类型 FP64, 所以直接调用 ConstantFP 传入即可,APFloat 是 llvm 内部的数据结构,用于存储 Arbitrary Precision Float. 在 LLVM IR 中,所有常量是唯一且共享的,所以这里使用的 get 而不是 new/create。

然后实现 VariableExprAST 的 CodeGen:

llvm::Value* VariableExprAST::CodeGen() {
  return g_named_values.at(name_);
}

由于 Kaleidoscope 的 VariableExpr 只存在于函数内对函数参数的引用,我们假定函数参数已经被注册到 g_name_values 中,所以 VariableExpr 直接查表返回即可。

接着实现 BinaryExprAST, 分别 codegen lhs, rhs 然后创建指令处理 lhs, rhs 即可:

llvm::Value* BinaryExprAST::CodeGen() {
  llvm::Value* lhs = lhs_->CodeGen();
  llvm::Value* rhs = rhs_->CodeGen();
  switch (op_) {
    case '<': {
      llvm::Value* tmp = g_ir_builder.CreateFCmpULT(lhs, rhs, "cmptmp");
      // 把 0/1 转为 0.0/1.0
      return g_ir_builder.CreateUIToFP(
          tmp, llvm::Type::getDoubleTy(g_llvm_context), "booltmp");
    }
    case '+': return g_ir_builder.CreateFAdd(lhs, rhs, "addtmp");
    case '-': return g_ir_builder.CreateFSub(lhs, rhs, "subtmp");
    case '*': return g_ir_builder.CreateFMul(lhs, rhs, "multmp");
    default: return nullptr;
  }
}

实现 CallExprAST:

llvm::Value* CallExprAST::CodeGen() {
  // g_module中存储了全局变量/函数等
  llvm::Function* callee = g_module.getFunction(callee_);

  std::vector<llvm::Value*> args;
  for (std::unique_ptr<ExprAST>& arg_expr : args_) {
    args.push_back(arg_expr->CodeGen());
  }
  return g_ir_builder.CreateCall(callee, args, "calltmp");
}

实现 ProtoTypeAST:

llvm::Value* PrototypeAST::CodeGen() {
  // 创建kaleidoscope的函数类型 double (doube, double, ..., double)
  std::vector<llvm::Type*> doubles(args_.size(),
                                   llvm::Type::getDoubleTy(g_llvm_context));
  // 函数类型是唯一的,所以使用get而不是new/create
  llvm::FunctionType* function_type = llvm::FunctionType::get(
      llvm::Type::getDoubleTy(g_llvm_context), doubles, false);
  // 创建函数, ExternalLinkage意味着函数可能不在当前module中定义,在当前module
  // 即g_module中注册名字为name_, 后面可以使用这个名字在g_module中查询
  llvm::Function* func = llvm::Function::Create(
      function_type, llvm::Function::ExternalLinkage, name_, &g_module);
  // 增加IR可读性,设置function的argument name
  int index = 0;
  for (auto& arg : func->args()) {
    arg.setName(args_[index++]);
  }
  return func;
}

实现 FunctionAST:

llvm::Value* FunctionAST::CodeGen() {
  // 检查函数声明是否已完成codegen(比如之前的extern声明), 如果没有则执行codegen
  llvm::Function* func = g_module.getFunction(proto_->name());
  if (func == nullptr) {
    func = proto_->CodeGen();
  }
  // 创建一个Block并且设置为指令插入位置。
  // llvm block用于定义control flow graph, 由于我们暂不实现control flow, 创建
  // 一个单独的block即可
  llvm::BasicBlock* block =
      llvm::BasicBlock::Create(g_llvm_context, "entry", func);
  g_ir_builder.SetInsertPoint(block);
  // 将函数参数注册到g_named_values中,让VariableExprAST可以codegen
  g_named_values.clear();
  for (llvm::Value& arg : func->args()) {
    g_named_values[arg.getName()] = &arg;
  }
  // codegen body然后return
  llvm::Value* ret_val = body_->CodeGen();
  g_ir_builder.CreateRet(ret_val);
  llvm::verifyFunction(*func);
  return func;
}

至此,所有 codegen 都已完成,修改 main:

int main() {
  GetNextToken();
  while (true) {
    switch (g_current_token) {
      case TOKEN_EOF: return 0;
      case TOKEN_DEF: {
        auto ast = ParseDefinition();
        std::cout << "parsed a function definition" << std::endl;
        ast->CodeGen()->print(llvm::errs());
        std::cerr << std::endl;
        break;
      }
      case TOKEN_EXTERN: {
        auto ast = ParseExtern();
        std::cout << "parsed a extern" << std::endl;
        ast->CodeGen()->print(llvm::errs());
        std::cerr << std::endl;
        break;
      }
      default: {
        auto ast = ParseTopLevelExpr();
        std::cout << "parsed a top level expr" << std::endl;
        ast->CodeGen()->print(llvm::errs());
        std::cerr << std::endl;
        break;
      }
    }
  }
  return 0;
}

输入测试:

4 + 5

def foo(a b)
    a*a + 2*a*b + b*b

foo(2, 3)

def bar(a)
    foo(a, 4) + bar(31337)

extern cos(x)

cos(1.234)

得到输出:

parsed a top level expr
define double @0() {
entry:
  ret double 9.000000e+00
}

parsed a function definition
define double @foo(double %a, double %b) {
entry:
  %multmp = fmul double %a, %a
  %multmp1 = fmul double 2.000000e+00, %a
  %multmp2 = fmul double %multmp1, %b
  %addtmp = fadd double %multmp, %multmp2
  %multmp3 = fmul double %b, %b
  %addtmp4 = fadd double %addtmp, %multmp3
  ret double %addtmp4
}

parsed a top level expr
define double @1() {
entry:
  %calltmp = call double @foo(double 2.000000e+00, double 3.000000e+00)
  ret double %calltmp
}

parsed a function definition
define double @bar(double %a) {
entry:
  %calltmp = call double @foo(double %a, double 4.000000e+00)
  %calltmp1 = call double @bar(double 3.133700e+04)
  %addtmp = fadd double %calltmp, %calltmp1
  ret double %addtmp
}

parsed a extern
declare double @cos(double)

parsed a top level expr
define double @2() {
entry:
  %calltmp = call double @cos(double 1.234000e+00)
  ret double %calltmp
}

至此,我们已成功将 Parser 输出的 AST 转为 LLVM IR。

5. Optimizer

我们使用上一节的程序处理如下代码:

def test(x)
    1 + 2 + x

可以得到:

parsed a function definition
define double @test(double %x) {
entry:
  %addtmp = fadd double 3.000000e+00, %x
  ret double %addtmp
}

可以看到,生成的指令直接是 1+2 的结果,而没有 1 + 2 的指令,这种自动把常量计算完毕而不是生成加法指令的优化称为 Constant Folding。

在大部分时候仅有这个优化仍然不够,比如如下代码:

def test(x)
    (1 + 2 + x) * (x + (1 + 2))

可以得到编译结果:

parsed a function definition
define double @test(double %x) {
entry:
  %addtmp = fadd double 3.000000e+00, %x
  %addtmp1 = fadd double %x, 3.000000e+00
  %multmp = fmul double %addtmp, %addtmp1
  ret double %multmp
}

生成了两个加法指令,但最优做法只需要一个加法即可,因为乘法的两边 lhs 和 rhs 是相等的。

这需要其他的优化技术,llvm 以"passes"的形式提供,llvm 中的 passes 可以选择是否启用,可以设置 passes 的顺序。

这里我们对每个函数单独做优化,定义 g_fpm, 增加几个 passes:

llvm::legacy::FunctionPassManager g_fpm(&g_module);

int main() {
  g_fpm.add(llvm::createInstructionCombiningPass());
  g_fpm.add(llvm::createReassociatePass());
  g_fpm.add(llvm::createGVNPass());
  g_fpm.add(llvm::createCFGSimplificationPass());
  g_fpm.doInitialization();
  ...
}

在 FunctionAST 的 CodeGen 中增加一句:

  llvm::Value* ret_val = body_->CodeGen();
  g_ir_builder.CreateRet(ret_val);
  llvm::verifyFunction(*func);
  g_fpm.run(*func); // 增加这句
  return func;

即启动了对每个 function 的优化,接下来测试之前的代码:

parsed a function definition
define double @test(double %x) {
entry:
  %addtmp = fadd double %x, 3.000000e+00
  %multmp = fmul double %addtmp, %addtmp
  ret double %multmp
}

可以看到,和我们期望的一样,加法指令减少到一个。

6. Adding a JIT Compiler

由于 JIT 模式中我们需要反复创建新的 module, 所以我们将全局变量 g_module 改为 unique_ptr。

// 用于管理函数和全局变量,可以粗浅地理解为类c++的编译单元(单个cpp文件)
std::unique_ptr<llvm::Module> g_module =
    std::make_unique<llvm::Module>("my cool jit", g_llvm_context);

为了专注于 JIT,我们可以把优化的 passes 删掉。

修改 ParseTopLevelExpr,给 PrototypeAST 命名为__anon_expr, 让我们后面可以通过这个名字找到它。

// toplevelexpr ::= expression
std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
  auto expr = ParseExpression();
  auto proto =
      std::make_unique<PrototypeAST>("__anon_expr", std::vector<std::string>());
  return std::make_unique<FunctionAST>(std::move(proto), std::move(expr));
}

然后我们从 llvm-project 中拷贝一份代码 llvm/examples/Kaleidoscope/include/KaleidoscopeJIT.h 到本地再 include, 其定义了 KaleidoscopeJIT 类,关于这个类,在后面会做解读,这里先不管。

定义全局变量 g_jit, 并使用 InitializeNativeTarget*函数初始化环境。

#include "KaleidoscopeJIT.h"

std::unique_ptr<llvm::orc::KaleidoscopeJIT> g_jit;

int main() {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();
  g_jit.reset(new llvm::orc::KaleidoscopeJIT);
  g_module->setDataLayout(g_jit->getTargetMachine().createDataLayout());
  ...
}

修改 main 处理 top level expr 的代码为:

        auto ast = ParseTopLevelExpr();
        std::cout << "parsed a top level expr" << std::endl;
        ast->CodeGen()->print(llvm::errs());
        std::cout << std::endl;
        auto h = g_jit->addModule(std::move(g_module));
        // 重新创建g_module在下次使用
        g_module =
            std::make_unique<llvm::Module>("my cool jit", g_llvm_context);
        g_module->setDataLayout(g_jit->getTargetMachine().createDataLayout());
        // 通过名字找到编译的函数符号
        auto symbol = g_jit->findSymbol("__anon_expr");
        // 强转为C函数指针
        double (*fp)() = (double (*)())(symbol.getAddress().get());
        // 执行输出
        std::cout << fp() << std::endl;
        g_jit->removeModule(h);
        break;
4 + 5

def foo(a b)
    a*a + 2*a*b + b*b

foo(2, 3)

得到输出:

parsed a top level expr
define double @__anon_expr() {
entry:
  ret double 9.000000e+00
}

9
parsed a function definition
define double @foo(double %a, double %b) {
entry:
  %multmp = fmul double %a, %a
  %multmp1 = fmul double 2.000000e+00, %a
  %multmp2 = fmul double %multmp1, %b
  %addtmp = fadd double %multmp, %multmp2
  %multmp3 = fmul double %b, %b
  %addtmp4 = fadd double %addtmp, %multmp3
  ret double %addtmp4
}

parsed a top level expr
define double @__anon_expr() {
entry:
  %calltmp = call double @foo(double 2.000000e+00, double 3.000000e+00)
  ret double %calltmp
}

25

可以看到代码已经顺利执行,但现在的实现仍然是有问题的,比如上面的输入,foo 函数的定义和调用是被归在同一个 module 中,当第一次调用完成后,由于我们 removeModule, 第二次调用 foo 会失败。

在解决这个问题之前,我们先把 main 函数内对不同 TOKEN 的处理拆成多个函数,如下:

void ReCreateModule() {
  g_module = std::make_unique<llvm::Module>("my cool jit", g_llvm_context);
  g_module->setDataLayout(g_jit->getTargetMachine().createDataLayout());
}

void ParseDefinitionToken() {
  auto ast = ParseDefinition();
  std::cout << "parsed a function definition" << std::endl;
  ast->CodeGen()->print(llvm::errs());
  std::cerr << std::endl;
}

void ParseExternToken() {
  auto ast = ParseExtern();
  std::cout << "parsed a extern" << std::endl;
  ast->CodeGen()->print(llvm::errs());
  std::cerr << std::endl;
}

void ParseTopLevel() {
  auto ast = ParseTopLevelExpr();
  std::cout << "parsed a top level expr" << std::endl;
  ast->CodeGen()->print(llvm::errs());
  std::cout << std::endl;
  auto h = g_jit->addModule(std::move(g_module));
  // 重新创建g_module在下次使用
  ReCreateModule();
  // 通过名字找到编译的函数符号
  auto symbol = g_jit->findSymbol("__anon_expr");
  // 强转为C函数指针
  double (*fp)() = (double (*)())(symbol.getAddress().get());
  // 执行输出
  std::cout << fp() << std::endl;
  g_jit->removeModule(h);
}

int main() {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();
  g_jit.reset(new llvm::orc::KaleidoscopeJIT);
  g_module->setDataLayout(g_jit->getTargetMachine().createDataLayout());

  GetNextToken();
  while (true) {
    switch (g_current_token) {
      case TOKEN_EOF: return 0;
      case TOKEN_DEF: ParseDefinitionToken(); break;
      case TOKEN_EXTERN: ParseExternToken(); break;
      default: ParseTopLevel(); break;
    }
  }
  return 0;
}

为了解决第二次调用 foo 失败的问题,我们需要让 function 和 top level expr 处于不同的 Module, 而处于不同 Module 的话,CallExprAST 的 CodeGen 在当前 module 会找不到 function, 所以需要自动在 CallExprAST 做 CodeGen 时在当前 Module 声明这个函数,即自动地增加 extern, 也就是在当前 Module 自动做对应 PrototypeAST 的 CodeGen.

首先,增加一个全局变量存储从函数名到函数接口的映射,并增加一个查询函数。

std::map<std::string, std::unique_ptr<PrototypeAST>> name2proto_ast;

llvm::Function* GetFunction(const std::string& name) {
  llvm::Function* callee = g_module->getFunction(name);
  if (callee != nullptr) {  // 当前module存在函数定义
    return callee;
  } else {
    // 声明函数
    return name2proto_ast.at(name)->CodeGen();
  }
}

更改 CallExprAST 的 CodeGen, 让其使用上面定义的 GetFuntion:

llvm::Value* CallExprAST::CodeGen() {
  llvm::Function* callee = GetFunction(callee_);

  std::vector<llvm::Value*> args;
  for (std::unique_ptr<ExprAST>& arg_expr : args_) {
    args.push_back(arg_expr->CodeGen());
  }
  return g_ir_builder.CreateCall(callee, args, "calltmp");
}

更改 FunctionAST 的 CodeGen, 让其将结果写入 name2proto_ast:

llvm::Value* FunctionAST::CodeGen() {
  PrototypeAST& proto = *proto_;
  name2proto_ast[proto.name()] = std::move(proto_);  // transfer ownership
  llvm::Function* func = GetFunction(proto.name());
  // 创建一个Block并且设置为指令插入位置。
  // llvm block用于定义control flow graph, 由于我们暂不实现control flow, 创建
  // 一个单独的block即可
  llvm::BasicBlock* block =
      llvm::BasicBlock::Create(g_llvm_context, "entry", func);
  g_ir_builder.SetInsertPoint(block);
  // 将函数参数注册到g_named_values中,让VariableExprAST可以codegen
  g_named_values.clear();
  for (llvm::Value& arg : func->args()) {
    g_named_values[arg.getName()] = &arg;
  }
  // codegen body然后return
  llvm::Value* ret_val = body_->CodeGen();
  g_ir_builder.CreateRet(ret_val);
  llvm::verifyFunction(*func);
  return func;
}

修改 ParseExternToken 将结果写入 name2proto_ast:

void ParseExternToken() {
  auto ast = ParseExtern();
  std::cout << "parsed a extern" << std::endl;
  ast->CodeGen()->print(llvm::errs());
  std::cerr << std::endl;
  name2proto_ast[ast->name()] = std::move(ast);
}

修改 ParseDefinitionToken 让其使用独立 Module:

void ParseDefinitionToken() {
  auto ast = ParseDefinition();
  std::cout << "parsed a function definition" << std::endl;
  ast->CodeGen()->print(llvm::errs());
  std::cerr << std::endl;
  g_jit->addModule(std::move(g_module));
  ReCreateModule();
}

修改完毕,输入测试:

def foo(x)
    x + 1

foo(2)

def foo(x)
    x + 2

foo(2)

extern sin(x)
extern cos(x)

sin(1.0)

def foo(x)
    sin(x) * sin(x) + cos(x) * cos(x)

foo(4)
foo(3)

得到输出:

parsed a function definition
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 1.000000e+00
  ret double %addtmp
}

parsed a top level expr
define double @__anon_expr() {
entry:
  %calltmp = call double @foo(double 2.000000e+00)
  ret double %calltmp
}

3
parsed a function definition
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 2.000000e+00
  ret double %addtmp
}

parsed a top level expr
define double @__anon_expr() {
entry:
  %calltmp = call double @foo(double 2.000000e+00)
  ret double %calltmp
}

4
parsed a extern
declare double @sin(double)

parsed a extern
declare double @cos(double)

parsed a top level expr
define double @__anon_expr() {
entry:
  %calltmp = call double @sin(double 1.000000e+00)
  ret double %calltmp
}

0.841471
parsed a function definition
define double @foo(double %x) {
entry:
  %calltmp = call double @sin(double %x)
  %calltmp1 = call double @sin(double %x)
  %multmp = fmul double %calltmp, %calltmp1
  %calltmp2 = call double @cos(double %x)
  %calltmp3 = call double @cos(double %x)
  %multmp4 = fmul double %calltmp2, %calltmp3
  %addtmp = fadd double %multmp, %multmp4
  ret double %addtmp
}

parsed a top level expr
define double @__anon_expr() {
entry:
  %calltmp = call double @foo(double 4.000000e+00)
  ret double %calltmp
}

1
parsed a top level expr
define double @__anon_expr() {
entry:
  %calltmp = call double @foo(double 3.000000e+00)
  ret double %calltmp
}

1

成功运行,执行正确! 代码可以正确解析 sin, cos 的原因在 KaleidoscopeJIT.h 中,截取其寻找符号的代码。

  JITSymbol findMangledSymbol(const std::string &Name) {
#ifdef _WIN32
    // The symbol lookup of ObjectLinkingLayer uses the SymbolRef::SF_Exported
    // flag to decide whether a symbol will be visible or not, when we call
    // IRCompileLayer::findSymbolIn with ExportedSymbolsOnly set to true.
    //
    // But for Windows COFF objects, this flag is currently never set.
    // For a potential solution see: https://reviews.llvm.org/rL258665
    // For now, we allow non-exported symbols on Windows as a workaround.
    const bool ExportedSymbolsOnly = false;
#else
    const bool ExportedSymbolsOnly = true;
#endif

    // Search modules in reverse order: from last added to first added.
    // This is the opposite of the usual search order for dlsym, but makes more
    // sense in a REPL where we want to bind to the newest available definition.
    for (auto H : make_range(ModuleKeys.rbegin(), ModuleKeys.rend()))
      if (auto Sym = CompileLayer.findSymbolIn(H, Name, ExportedSymbolsOnly))
        return Sym;

    // If we can't find the symbol in the JIT, try looking in the host process.
    if (auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name))
      return JITSymbol(SymAddr, JITSymbolFlags::Exported);

#ifdef _WIN32
    // For Windows retry without "_" at beginning, as RTDyldMemoryManager uses
    // GetProcAddress and standard libraries like msvcrt.dll use names
    // with and without "_" (for example "_itoa" but "sin").
    if (Name.length() > 2 && Name[0] == '_')
      if (auto SymAddr =
              RTDyldMemoryManager::getSymbolAddressInProcess(Name.substr(1)))
        return JITSymbol(SymAddr, JITSymbolFlags::Exported);
#endif

    return null

可以看到,在之前定义的 Module 找不到后会在 host process 中寻找这个符号。

7. SSA

继续给我们的 Kaleidoscope 添加功能之前,需要先介绍 SSA, Static Single Assignment,考虑下面代码:

y := 1
y := 2
x := y

我们可以发现第一个赋值是不必须的,而且第三行使用的 y 来自第二行的赋值,改成 SSA 格式为

y_1 = 1
y_2 = 2
x_1 = y_2

改完可以方便编译器进行优化,比如把第一个赋值删去,于是我们可以给出 SSA 的定义:

  • 每个变量仅且必须被赋值一次,原本代码中的多次变量赋值会被赋予版本号然后视为不同变量;
  • 每个变量在被使用之前必须被定义。

考虑如下 Control Flow Graph:

加上版本号:

可以看到,这里遇到一个问题,最下面的 block 里面的 y 应该使用 y1 还是 y2, 为了解决这个问题,插入一个特殊语句称为 phi function, 其会根据 control flow 从 y1 和 y2 中选择一个值作为 y3, 如下:

可以看到,对于 x 不需要 phi function, 因为两个分支到最后的都是 x2。

腾讯技术交流群已建立,交流讨论可加QQ 群:160315980(备注腾讯技术) ,微信交流群加:teg_helper


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK