1 #ifndef INCLUDE_UTILS_GRAPH_H_
2 #define INCLUDE_UTILS_GRAPH_H_
#include <algorithm>
#include <utility>

#include <glog/logging.h>
14 using std::shared_ptr;
15 using std::make_shared;
// NOTE(review): presumably the dimension along which inputs are concatenated
// for the node carrying this value — confirm against its users.
24 int concate_dimension;
30 typedef shared_ptr<Node> SNode;
34 typedef shared_ptr<Node> SNode;
35 Node(
string name): name_(name){}
36 Node(
string name,
const V& v):
37 name_(name), val_(v){}
39 void AddDstNode(SNode dstnode){
40 dstnodes_.push_back(dstnode);
42 void AddSrcNode(SNode srcnode){
43 srcnodes_.push_back(srcnode);
46 void RemoveDstNode(SNode dst){
47 auto iter=dstnodes_.begin();
48 while((*iter)->name_!=dst->name_&&iter!=dstnodes_.end()) iter++;
49 CHECK((*iter)->name_==dst->name_);
50 dstnodes_.erase(iter);
52 void RemoveSrcNode(SNode src){
53 auto iter=srcnodes_.begin();
54 while((*iter)->name_!=src->name_&&iter!=srcnodes_.end()) iter++;
55 CHECK((*iter)->name_==src->name_);
56 srcnodes_.erase(iter);
58 const string& name()
const {
return name_;}
59 const V& val()
const {
return val_;}
60 const SNode srcnodes(
int k)
const {
return srcnodes_[k]; }
61 const SNode dstnodes(
int k)
const {
return dstnodes_[k]; }
62 const vector<SNode>& srcnodes()
const {
return srcnodes_; }
63 const vector<SNode>& dstnodes()
const {
return dstnodes_; }
64 int dstnodes_size()
const {
return dstnodes_.size(); }
65 int srcnodes_size()
const {
return srcnodes_.size(); }
// Source (incoming) neighbours; appended by AddSrcNode, erased by RemoveSrcNode.
69 vector<SNode> srcnodes_;
// Destination (outgoing) neighbours; appended by AddDstNode, erased by RemoveDstNode.
70 vector<SNode> dstnodes_;
// NOTE(review): not read or written anywhere in this view; presumably display
// attributes used when rendering the graph — confirm.
74 string color_, weight_, shape_;
// Creates a node named `name` carrying value `origin`, appends it to nodes_,
// and indexes it in name2node_. If a node with the same name already exists,
// its name2node_ entry is overwritten while the old node stays in nodes_.
// NOTE(review): the return statement is not visible in this view. If it
// returns a reference into nodes_ (e.g. nodes_.back()), that reference is
// invalidated by the next push_back into nodes_. Returning the
// name2node_[name] slot (std::map references are stable) would avoid that —
// TODO confirm.
85 const SNode& AddNode(
string name,
V origin){
86 nodes_.push_back(make_shared<Node>(name, origin));
87 name2node_[name]=nodes_.back();
// Creates a node named `name` with no explicit value, appends it to nodes_,
// and indexes it in name2node_. If a node with the same name already exists,
// its name2node_ entry is overwritten while the old node stays in nodes_.
// NOTE(review): the return statement is not visible in this view. If it
// returns a reference into nodes_ (e.g. nodes_.back()), that reference is
// invalidated by the next push_back into nodes_ — TODO confirm.
90 const SNode& AddNode(
string name){
91 nodes_.push_back(make_shared<Node>(name));
92 name2node_[name]=nodes_.back();
96 void AddEdge(SNode srcnode, SNode dstnode){
97 srcnode->AddDstNode(dstnode);
98 dstnode->AddSrcNode(srcnode);
101 void AddEdge(
const string& src,
const string& dst){
102 CHECK(name2node_.find(src)!=name2node_.end())<<
"can't find src node "<<src;
103 CHECK(name2node_.find(dst)!=name2node_.end())<<
"can't find dst node "<<dst;
105 SNode srcnode=name2node_[src], dstnode=name2node_[dst];
106 AddEdge(srcnode, dstnode);
109 void RemoveEdge(
const string &src,
const string& dst){
110 CHECK(name2node_.find(src)!=name2node_.end())<<
"can't find src node "<<src;
111 CHECK(name2node_.find(dst)!=name2node_.end())<<
"can't find dst node "<<dst;
113 SNode srcnode=name2node_[src], dstnode=name2node_[dst];
114 RemoveEdge(srcnode, dstnode);
117 void RemoveEdge(SNode src, SNode dst){
118 src->RemoveDstNode(dst);
119 dst->RemoveSrcNode(src);
// Read-only accessor for the graph's node list.
// NOTE(review): the return statement is not visible in this view; presumably
// it returns nodes_, which AddNode fills in insertion order — confirm.
122 const vector<SNode>& nodes()
const{
126 const SNode& node(
string name)
const{
127 CHECK(name2node_.find(name)!= name2node_.end())
128 <<
"can't find dst node "<<name;
129 return name2node_.at(name);
// Serializes the graph to a string (defined out of line).
// The second overload also takes a string->string map `info`.
// NOTE(review): presumably per-node annotations keyed by node name — confirm
// in the out-of-line definition.
132 const string ToString()
const;
133 const string ToString(
const map<string, string>& info)
const ;
// Graph-rewriting helpers; all are defined out of line.
// NOTE(review): judging by their names, these presumably splice an
// intermediate node (slice / concate / split / bridge) between existing
// nodes and return the inserted node(s) — confirm against the implementation.
// NOTE(review): the tail of InsertConcateNode's parameter list is not visible
// in this view.
// topology_sort_inner is presumably the recursive step of a topological sort,
// marking names in `visited` and pushing names onto `stack` — confirm.
137 SNode InsertSliceNode(SNode srcnode,
const vector<SNode>& dstnodes,
138 const V& info,
bool connect_dst=
true);
139 SNode InsertConcateNode(
const vector<SNode>&srcnodes, SNode dstnode,
141 SNode InsertSplitNode(SNode srcnode,
const vector<SNode>& dstnodes);
142 std::pair<SNode, SNode> InsertBridgeNode(SNode srcnode, SNode dstnode);
143 void topology_sort_inner(SNode node, map<string, bool> *visited,
144 std::stack<string> *stack);
// Every node ever added, in insertion order (appended by AddNode).
147 vector<SNode> nodes_;
// Name -> node index. AddNode overwrites an existing entry with the same name.
148 map<string, SNode> name2node_;
150 #endif // INCLUDE_UTILS_GRAPH_H_
Used to partition the neural net and to display the neural net's structure.
Definition: graph.h:81