4

ONNX Runtime 源码阅读:Graph::SetGraphInputsOutputs() 函数 - 虔诚的树

 2 years ago
source link: https://www.cnblogs.com/xxxxxxxxx/p/16219014.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

ONNX Runtime 源码阅读:Graph::SetGraphInputsOutputs() 函数

为了深入理解ONNX Runtime的底层机制,本文将对 Graph::SetGraphInputsOutputs() 的代码逐行分析。

首先判断Graph是否从ONNX文件中加载所得:

if (is_loaded_from_model_file_) return Status::OK();

如果是,可直接返回;如果不是,则需要解析Graph中的节点,从而设置模型的输入和输出。

Graph中的成员变量 value_info_graph_inputs_excluding_initializers_graph_inputs_including_initializers_ 以及 graph_outputs_ 全部清空:

value_info_.clear();

graph_inputs_excluding_initializers_.clear();

if (!graph_inputs_manually_set_) {
  graph_inputs_including_initializers_.clear();
} else {
  std::unordered_set<std::string> existing_names;
  for (auto arg : graph_inputs_including_initializers_) {
    const std::string& name = arg->Name();
    if (existing_names.count(name) == 0) {
      graph_inputs_excluding_initializers_.push_back(arg);
      existing_names.insert(name);
    }
  }
}

if (!graph_outputs_manually_set_) {
  graph_outputs_.clear();
}

设置一些局部变量,方便下面的使用分析:

std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
std::vector<const NodeArg*> output_node_args_in_order;
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};

统计所有节点的输出,添加到以上局部变量(output_name_to_node_arg_index 和 output_node_args_in_order)中:

for (const auto& node : Nodes()) {
  for (const auto* output_def : node.OutputDefs()) {
    if (output_def->Exists()) {
      output_node_args_in_order.push_back(output_def);
      output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1});
    }
  }
}
auto graph_output_args = output_name_to_node_arg_index;  // 拷贝一份输出节点map

然后遍历图中每个节点以及每个节点的输入:

for (const auto& node : Nodes()) {
  // Go thru all node's inputs.
  for (const auto* input_arg : node.InputDefs()) {
    ...
  }
}

在输出节点name列表中查找当前输入name

auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name());

如果没有找到,说明这个节点的输入就是图的输入,接下来还需要判断这个输入是否已经放在局部变量added_input_names中:

if (output_name_to_node_arg_index.end() == output_arg_iter) {
  // This input arg is not the output of another node so must come from either a graph input or an initializer.
  const std::string& name = input_arg->Name();
  if (added_input_names.end() == added_input_names.find(name)) {
    ...
  }
}

如果已经放到局部变量added_input_names中,就可以判断节点的下一个输入或者下一个节点的输入。如果没有放到局部变量added_input_names中:

bool is_initializer = name_to_initial_tensor_.find(name) != name_to_initial_tensor_.end();  // 判断当前input_arg是否已初始化过的tensor,如果是就不可以再放置到 graph_inputs_excluding_initializers_ 中
if (!graph_inputs_manually_set_) {   // 如果未主动调用 SetInputs() 方法
  // if IR version < 4 all initializers must have a matching graph input
  // (even though the graph input is not allowed to override the initializer).
  // if IR version >= 4 initializers are not required to have a matching graph input.
  // any graph inputs that are to override initializers must be specified by calling SetInputs.
  if (!is_initializer || ir_version_ < 4) {
    graph_inputs_including_initializers_.push_back(input_arg);
  }
  if (!is_initializer) {
    // If input_arg is not of an initializer, we add it into graph_inputs_excluding_initializers_.
    graph_inputs_excluding_initializers_.push_back(input_arg);
  }
} else {  // 如果主动调用了 SetInputs() 方法
  // graph_inputs_including_initializers_ has been manually populated by SetInputs.
  // Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
  if (!is_initializer) {
    const auto& inputs = graph_inputs_including_initializers_;
    bool in_inputs = std::find(inputs.begin(), inputs.end(), input_arg) != inputs.end();
    if (!in_inputs) {
      return Status(ONNXRUNTIME, FAIL,
                    name + " must be either specified in graph inputs or graph initializers.");
    }
  } else {
    // If arg_input is of an initializer, we remove it from graph_inputs_excluding_initializers_
    // whose initial content has both initializers and non-initializers.
    auto input_pos = std::find(graph_inputs_excluding_initializers_.begin(),
                                graph_inputs_excluding_initializers_.end(),
                                input_arg);
    if (input_pos != graph_inputs_excluding_initializers_.end()) {
      graph_inputs_excluding_initializers_.erase(input_pos);
    }
  }
}
added_input_names.insert(name);

可以看到,这里会把当前的 input_arg 分别放到 graph_inputs_including_initializers_graph_inputs_excluding_initializers_ 中,并将name放在added_input_names中。

如果该输入的name已经在输出节点name列表中,说明这个节点是中间输出结果,而非整个图的输出,因此应该将其从图的输出(graph_output_args)中删除,并放在 value_info_ 中:

if (output_name_to_node_arg_index.end() == output_arg_iter) {
  ...
}else if(graph_output_args.erase(output_arg_iter->first) >= 1){
  value_info_.insert(input_arg);
}

以上我们对Graph的三个成员变量:graph_inputs_including_initializers_graph_inputs_excluding_initializers_value_info_分别进行了赋值,其中前两者存储输入,后者存储中间结果。我们还需要处理图的输出结果:`graph_outputs_`:

if (!graph_outputs_manually_set_) {
  // Set graph outputs in order.
  std::vector<size_t> graph_output_args_index;
  graph_output_args_index.reserve(graph_output_args.size());
  for (const auto& output_arg : graph_output_args) {          // graph_output_args原本存储了所有节点的输出,但是前面的代码已经把中间节点的输出给移除了,因此剩下的就是整个Graph的输出
    graph_output_args_index.push_back(output_arg.second);
  }

  std::sort(graph_output_args_index.begin(), graph_output_args_index.end());
  for (auto& output_arg_index : graph_output_args_index) {
    graph_outputs_.push_back(output_node_args_in_order[output_arg_index]);
  }
}

最后,还需要对 graph_overridable_initializers_ 进行处理:

ComputeOverridableInitializers();

进入这个函数内部:

void Graph::ComputeOverridableInitializers() {
  graph_overridable_initializers_.clear();
  if (CanOverrideInitializer()) {
    // graph_inputs_excluding_initializers_ and graph_inputs_including_initializers_
    // are inserted in the same order. So we walk and compute the difference.
    auto f_incl = graph_inputs_including_initializers_.cbegin();
    const auto l_incl = graph_inputs_including_initializers_.cend();
    auto f_excl = graph_inputs_excluding_initializers_.cbegin();
    const auto l_excl = graph_inputs_excluding_initializers_.cend();

    while (f_incl != l_incl) {
      // Equal means not an initializer
      if (f_excl != l_excl && *f_incl == *f_excl) {
        ++f_incl;
        ++f_excl;
        continue;
      }
      graph_overridable_initializers_.push_back(*f_incl);
      ++f_incl;
    }
  }
}

这是一个很简单的算法,通过比较 graph_inputs_including_initializers_graph_inputs_excluding_initializers_,提取出 initializer 并放置到 graph_overridable_initializers_ 中。

至此,我们完成了对 Graph::SetGraphInputsOutputs() 函数的解析。

针对这个函数的解析不仅理解了如何从Graph的nodes中分析出graph的输入和输出,而且懂得了graph_overridable_initializers_以及value_info_的作用。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK