Apache Singa
A General Distributed Deep Learning Library
network.h
1 /************************************************************
2 *
3 * Licensed to the Apache Software Foundation (ASF) under one
4 * or more contributor license agreements. See the NOTICE file
5 * distributed with this work for additional information
6 * regarding copyright ownership. The ASF licenses this file
7 * to you under the Apache License, Version 2.0 (the
8 * "License"); you may not use this file except in compliance
9 * with the License. You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing,
14 * software distributed under the License is distributed on an
15 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 * KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 *
20 *************************************************************/
21 
22 #ifndef SINGA_COMM_NETWORK_H_
23 #define SINGA_COMM_NETWORK_H_
24 #include "singa/singa_config.h"
25 #ifdef ENABLE_DIST
26 #include <ev.h>
27 #include <thread>
28 #include <unordered_map>
29 #include <map>
30 #include <vector>
31 #include <condition_variable>
32 #include <mutex>
33 #include <atomic>
34 #include <string>
35 #include <netinet/in.h>
36 #include <queue>
37 
38 namespace singa {
39 
// ---------------------------------------------------------------------------
// Protocol and state constants.
//
// NOTE: these are preprocessor macros. Although they sit inside
// `namespace singa`, macros ignore namespaces, so every translation unit that
// includes this header sees these names globally.
// ---------------------------------------------------------------------------

// Lock-state flags (not referenced anywhere else in this header).
#define LOCKED   1
#define UNLOCKED 0

// Signal identifiers; presumably the values accepted by
// NetworkThread::notify(int signal) -- TODO confirm against network.cc.
#define SIG_EP  1
#define SIG_MSG 2

// Connection states held in EndPoint::conn_status_ (which starts at
// CONN_INIT).
#define CONN_INIT    0
#define CONN_PENDING 1
#define CONN_EST     2
#define CONN_ERROR   3

// Retry limit; presumably compared against EndPoint::retry_cnt_ --
// verify in network.cc.
#define MAX_RETRY_CNT 3

// Endpoint timeout, a double literal. NOTE(review): presumably seconds,
// matching libev's ev_tstamp convention -- confirm the unit.
#define EP_TIMEOUT 5.

// Message types; MSG_DATA is the default type of a newly constructed Message.
#define MSG_DATA 0
#define MSG_ACK  1

// Forward declarations: the classes below refer to one another before their
// definitions appear.
class EndPoint;
class EndPointFactory;
class NetworkThread;
61 
62 class Message {
63 private:
64  uint8_t type_;
65  uint32_t id_;
66  std::size_t msize_ = 0;
67  std::size_t psize_ = 0;
68  std::size_t processed_ = 0;
69  char *msg_ = nullptr;
70  static const int hsize_ =
71  sizeof(id_) + 2 * sizeof(std::size_t) + sizeof(type_);
72  char mdata_[hsize_];
73  friend class NetworkThread;
74  friend class EndPoint;
75 
76 public:
77  Message(int = MSG_DATA, uint32_t = 0);
78  Message(const Message &) = delete;
79  Message(Message &&);
80  ~Message();
81 
82  void setMetadata(const void *, int);
83  void setPayload(const void *, int);
84 
85  std::size_t getMetadata(void **);
86  std::size_t getPayload(void **);
87 
88  std::size_t getSize();
89  void setId(uint32_t);
90 };
91 
92 class EndPoint {
93 private:
94  std::queue<Message *> send_;
95  std::queue<Message *> recv_;
96  std::queue<Message *> to_ack_;
97  std::condition_variable cv_;
98  std::mutex mtx_;
99  struct sockaddr_in addr_;
100  ev_timer timer_;
101  ev_tstamp last_msg_time_;
102  int fd_[2] = { -1, -1 }; // two endpoints simultaneously connect to each other
103  int pfd_ = -1;
104  bool is_socket_loop_ = false;
105  int conn_status_ = CONN_INIT;
106  int pending_cnt_ = 0;
107  int retry_cnt_ = 0;
108  NetworkThread *thread_ = nullptr;
109  EndPoint(NetworkThread *t);
110  ~EndPoint();
111  friend class NetworkThread;
112  friend class EndPointFactory;
113 
114 public:
115  int send(Message *);
116  Message *recv();
117 };
118 
119 class EndPointFactory {
120 private:
121  std::unordered_map<uint32_t, EndPoint *> ip_ep_map_;
122  std::condition_variable map_cv_;
123  std::mutex map_mtx_;
124  NetworkThread *thread_;
125  EndPoint *getEp(uint32_t ip);
126  EndPoint *getOrCreateEp(uint32_t ip);
127  friend class NetworkThread;
128 
129 public:
130  EndPointFactory(NetworkThread *thread) : thread_(thread) {}
131  ~EndPointFactory();
132  EndPoint *getEp(const char *host);
133  void getNewEps(std::vector<EndPoint *> &neps);
134 };
135 
136 class NetworkThread {
137 private:
138  struct ev_loop *loop_;
139  ev_async ep_sig_;
140  ev_async msg_sig_;
141  ev_io socket_watcher_;
142  int port_;
143  int socket_fd_;
144  std::thread *thread_;
145  std::unordered_map<int, ev_io> fd_wwatcher_map_;
146  std::unordered_map<int, ev_io> fd_rwatcher_map_;
147  std::unordered_map<int, EndPoint *> fd_ep_map_;
148  std::map<int, Message> pending_msgs_;
149 
150  void handleConnLost(int, EndPoint *, bool = true);
151  void doWork();
152  int asyncSend(int);
153  void asyncSendPendingMsg(EndPoint *);
154  void afterConnEst(EndPoint *ep, int fd, bool active);
155 
156 public:
157  EndPointFactory *epf_;
158 
159  NetworkThread(int);
160  void notify(int signal);
161 
162  void onRecv(int fd);
163  void onSend(int fd = -1);
164  void onConnEst(int fd);
165  void onNewEp();
166  void onNewConn();
167  void onTimeout(struct ev_timer *timer);
168 };
169 }
170 #endif // ENABLE_DIST
171 #endif
Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements...
Definition: common.h:48