1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
| Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, const Context& default_ctx, const std::map<std::string, Context>& ctx_map, const std::vector<NDArray>& in_args, const std::vector<NDArray>& arg_grad_store, const std::vector<OpReqType>& grad_req_type, const std::vector<NDArray>& aux_states, const nnvm::NodeEntryMap<NDArray>& feed_dict) { nnvm::Graph g = InitFullGraph(symbol, grad_req_type, arg_grad_store); g = AssignContext(g, default_ctx, ctx_map, in_args, grad_store_, aux_states, num_forward_inputs_, num_forward_outputs_); const auto& idx = g.indexed_graph(); num_forward_nodes_ = 0; for (size_t i = 0; i < num_forward_outputs_; ++i) { num_forward_nodes_ = std::max( num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1)); } data_entry_.resize(idx.num_node_entries()); auto mutable_nodes = idx.mutable_input_nodes(); nnvm::ShapeVector arg_shapes; nnvm::DTypeVector arg_types; size_t arg_top = 0, aux_top = 0; for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); if (mutable_nodes.count(nid)) { CHECK_LT(aux_top, aux_states.size()); data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top]; arg_shapes.push_back(aux_states[aux_top].shape()); arg_types.push_back(aux_states[aux_top].dtype()); ++aux_top; } else { CHECK_LT(arg_top, in_args.size()); data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top]; arg_shapes.push_back(in_args[arg_top].shape()); arg_types.push_back(in_args[arg_top].dtype()); ++arg_top; } } for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second; } arg_shapes.resize(idx.input_nodes().size(), TShape()); arg_types.resize(idx.input_nodes().size(), -1); g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); g = nnvm::pass::InferType(g, arg_types, "__dtype__");
{ const int kBadStorageID = -1; const int kExternalStorageID = -2; nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; } for (const auto& kv : feed_dict) { uint32_t eid = idx.entry_id(kv.first); data_entry_[eid] = kv.second; arg_storage_id[eid] = kExternalStorageID; } g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(arg_storage_id)); g = nnvm::ApplyPass(g, "PlanMemory"); } g = DetectInplaceAddTo(g); return g; }
|