Apache SINGA
A distributed deep learning platform .
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Macros
graph.h
1 #ifndef INCLUDE_UTILS_GRAPH_H_
2 #define INCLUDE_UTILS_GRAPH_H_
3 #include <glog/logging.h>
4 #include <vector>
5 #include <string>
6 #include <map>
7 #include <stack>
8 #include <memory>
9 
10 using std::vector;
11 using std::string;
12 using std::map;
13 using std::pair;
14 using std::shared_ptr;
15 using std::make_shared;
16 
17 
18 typedef struct _LayerInfo{
19  // origin identifies the origin of this node, i.e., the corresponding layer
20  string origin;
21  //int locationid;// locationidation id;
22  int partitionid;
23  int slice_dimension;
24  int concate_dimension;
25 }LayerInfo;
26 typedef LayerInfo V;
27 
28 
29 class Node;
30 typedef shared_ptr<Node> SNode;
31 
32 class Node{
33  public:
34  typedef shared_ptr<Node> SNode;
35  Node(string name): name_(name){}
36  Node(string name, const V& v):
37  name_(name), val_(v){}
38 
39  void AddDstNode(SNode dstnode){
40  dstnodes_.push_back(dstnode);
41  }
42  void AddSrcNode(SNode srcnode){
43  srcnodes_.push_back(srcnode);
44  }
45 
46  void RemoveDstNode(SNode dst){
47  auto iter=dstnodes_.begin();
48  while((*iter)->name_!=dst->name_&&iter!=dstnodes_.end()) iter++;
49  CHECK((*iter)->name_==dst->name_);
50  dstnodes_.erase(iter);
51  }
52  void RemoveSrcNode(SNode src){
53  auto iter=srcnodes_.begin();
54  while((*iter)->name_!=src->name_&&iter!=srcnodes_.end()) iter++;
55  CHECK((*iter)->name_==src->name_);
56  srcnodes_.erase(iter);
57  }
58  const string& name() const {return name_;}
59  const V& val() const {return val_;}
60  const SNode srcnodes(int k) const {return srcnodes_[k]; }
61  const SNode dstnodes(int k) const {return dstnodes_[k]; }
62  const vector<SNode>& srcnodes() const {return srcnodes_; }
63  const vector<SNode>& dstnodes() const {return dstnodes_; }
64  int dstnodes_size() const {return dstnodes_.size(); }
65  int srcnodes_size() const {return srcnodes_.size(); }
66 
67  private:
68  string name_;
69  vector<SNode> srcnodes_;
70  vector<SNode> dstnodes_;
71 
72  V val_;
73  // properties
74  string color_, weight_, shape_;
75 };
76 
77 
81 class Graph{
82  public:
83  Graph(){}
84  void Sort();
85  const SNode& AddNode(string name, V origin){
86  nodes_.push_back(make_shared<Node>(name, origin));
87  name2node_[name]=nodes_.back();
88  return nodes_.back();
89  }
90  const SNode& AddNode(string name){
91  nodes_.push_back(make_shared<Node>(name));
92  name2node_[name]=nodes_.back();
93  return nodes_.back();
94  }
95 
96  void AddEdge(SNode srcnode, SNode dstnode){
97  srcnode->AddDstNode(dstnode);
98  dstnode->AddSrcNode(srcnode);
99  }
100 
101  void AddEdge(const string& src, const string& dst){
102  CHECK(name2node_.find(src)!=name2node_.end())<<"can't find src node "<<src;
103  CHECK(name2node_.find(dst)!=name2node_.end())<<"can't find dst node "<<dst;
104 
105  SNode srcnode=name2node_[src], dstnode=name2node_[dst];
106  AddEdge(srcnode, dstnode);
107  }
108 
109  void RemoveEdge(const string &src, const string& dst){
110  CHECK(name2node_.find(src)!=name2node_.end())<<"can't find src node "<<src;
111  CHECK(name2node_.find(dst)!=name2node_.end())<<"can't find dst node "<<dst;
112 
113  SNode srcnode=name2node_[src], dstnode=name2node_[dst];
114  RemoveEdge(srcnode, dstnode);
115  }
116 
117  void RemoveEdge(SNode src, SNode dst){
118  src->RemoveDstNode(dst);
119  dst->RemoveSrcNode(src);
120  }
121 
122  const vector<SNode>& nodes() const{
123  return nodes_;
124  };
125 
126  const SNode& node(string name) const{
127  CHECK(name2node_.find(name)!= name2node_.end())
128  <<"can't find dst node "<<name;
129  return name2node_.at(name);
130  }
131 
132  const string ToString() const;
133  const string ToString(const map<string, string>& info) const ;
134 
135  bool Check() const;
136 
137  SNode InsertSliceNode(SNode srcnode, const vector<SNode>& dstnodes,
138  const V& info, bool connect_dst=true);
139  SNode InsertConcateNode(const vector<SNode>&srcnodes, SNode dstnode,
140  const V& info);
141  SNode InsertSplitNode(SNode srcnode, const vector<SNode>& dstnodes);
142  std::pair<SNode, SNode> InsertBridgeNode(SNode srcnode, SNode dstnode);
143  void topology_sort_inner(SNode node, map<string, bool> *visited,
144  std::stack<string> *stack);
145 
146  private:
147  vector<SNode> nodes_;
148  map<string, SNode> name2node_;
149 };
150 #endif // INCLUDE_UTILS_GRAPH_H_
Definition: graph.h:32
For partition neuralnet and displaying the neuralnet structure.
Definition: graph.h:81
Definition: graph.h:18