WIP
[m6w6/libmemcached] / test / lib / Connection.cpp
1 #include "Connection.hpp"
2
3 #include <cerrno>
4 #include <sys/poll.h>
5 #if HAVE_UNISTD_H
6 # include <unistd.h>
7 #endif
8
9 #if !(HAVE_SOCK_NONBLOCK && HAVE_SOCK_CLOEXEC)
10 # include <fcntl.h>
11 # define SOCK_NONBLOCK O_NONBLOCK
12 # define SOCK_CLOEXEC O_CLOEXEC
13 #endif
14
/**
 * Create a socket and apply the SOCK_NONBLOCK / SOCK_CLOEXEC bits in `fl`.
 *
 * When socket() accepts those bits in its type argument they are passed
 * straight through. Otherwise the socket is created plain and the bits are
 * applied with fcntl():
 *  - close-on-exec is a per-descriptor flag and must be set with
 *    F_SETFD / FD_CLOEXEC (F_SETFL silently ignores O_CLOEXEC);
 *  - every other bit is a file status flag and is set with F_SETFL.
 *
 * @return the new descriptor, or -1 with errno describing the failure.
 */
static inline int socket_ex(int af, int so, int pf, int fl) {
#if HAVE_SOCK_NONBLOCK && HAVE_SOCK_CLOEXEC
  return socket(af, so | fl, pf);
#else
  auto sock = socket(af, so, pf);
  if (0 > sock) {
    return sock;
  }

  bool ok = true;

  // File status flags (e.g. O_NONBLOCK): merge into the current set, but only
  // if reading the current set succeeded -- OR-ing a failed -1 in would set
  // every flag.
  const int status_bits = fl & ~SOCK_CLOEXEC;
  if (status_bits) {
    const int current = fcntl(sock, F_GETFL);
    ok = (0 <= current) && (0 <= fcntl(sock, F_SETFL, current | status_bits));
  }

  // Descriptor flag: close-on-exec.
  if (ok && (fl & SOCK_CLOEXEC)) {
    const int current = fcntl(sock, F_GETFD);
    ok = (0 <= current) && (0 <= fcntl(sock, F_SETFD, current | FD_CLOEXEC));
  }

  if (!ok) {
    // Keep the fcntl() error for the caller; close() must not overwrite it.
    const int saved_errno = errno;
    close(sock);
    errno = saved_errno;
    sock = -1;
  }
  return sock;
#endif
}
29
30 Connection::Connection(socket_or_port_t socket_or_port) {
31 if (holds_alternative<string>(socket_or_port)) {
32 const auto path = get<string>(socket_or_port);
33 const auto safe = path.c_str();
34 const auto zlen = path.length() + 1;
35 const auto ulen = sizeof(sockaddr_un) - sizeof(sa_family_t);
36
37 if (zlen >= ulen) {
38 throw invalid_argument(error({"socket(): path too long '", path, "'"}));
39 }
40
41 if (0 > (sock = socket_ex(AF_UNIX, SOCK_STREAM, 0, SOCK_NONBLOCK|SOCK_CLOEXEC))) {
42 throw runtime_error(error({"socket(): ", strerror(errno)}));
43 }
44
45 auto sa = reinterpret_cast<sockaddr_un *>(&addr);
46 sa->sun_family = AF_UNIX;
47 copy(safe, safe + zlen, sa->sun_path);
48
49 size = UNIX;
50
51 } else {
52 if (0 > (sock = socket_ex(AF_INET6, SOCK_STREAM, 0, SOCK_NONBLOCK|SOCK_CLOEXEC))) {
53 throw runtime_error(error({"socket(): ", strerror(errno)}));
54 }
55
56 const auto port = get<int>(socket_or_port);
57 auto sa = reinterpret_cast<struct sockaddr_in6 *>(&addr);
58 sa->sin6_family = AF_INET6;
59 sa->sin6_port = htons(static_cast<unsigned short>(port));
60 sa->sin6_addr = IN6ADDR_LOOPBACK_INIT;
61
62 size = INET6;
63 }
64 }
65
66 Connection::~Connection() {
67 close();
68 }
69
70 void swap(Connection &a, Connection &b) {
71 a.swap(b);
72 }
73
74 void Connection::swap(Connection &conn) {
75 Connection copy(conn);
76 conn.sock = sock;
77 conn.addr = addr;
78 conn.size = size;
79 conn.last_err = last_err;
80 sock = exchange(copy.sock, -1);
81 addr = copy.addr;
82 size = copy.size;
83 last_err = copy.last_err;
84 }
85
86 Connection::Connection(const Connection &conn) {
87 if (conn.sock > -1) {
88 sock = dup(conn.sock);
89 }
90 addr = conn.addr;
91 size = conn.size;
92 last_err = conn.last_err;
93 }
94
95 Connection &Connection::operator=(const Connection &conn) {
96 Connection copy(conn);
97 copy.swap(*this);
98 return *this;
99 }
100
101 Connection::Connection(Connection &&conn) noexcept {
102 close();
103 swap(conn);
104 }
105
106 Connection &Connection::operator=(Connection &&conn) noexcept {
107 Connection copy(move(conn));
108 copy.swap(*this);
109 return *this;
110 }
111
112 void Connection::close() {
113 if (sock > -1) {
114 ::close(sock);
115 sock = -1;
116 last_err = -1;
117 }
118 }
119
120 int Connection::getError() {
121 int err = -1;
122 socklen_t len = sizeof(int);
123 if (sock > -1) {
124 errno = 0;
125 if (0 > getsockopt(sock, SOL_SOCKET, SO_ERROR, &err, &len)) {
126 err = errno;
127 }
128 }
129 last_err = err;
130 return err;
131 }
132
133 int Connection::getLastError() {
134 if (last_err == -1) {
135 return getError();
136 }
137 return last_err;
138 }
139
140 bool Connection::isWritable() {
141 pollfd fd{sock, POLLOUT, 0};
142 if (1 > poll(&fd, 1, 0)) {
143 return false;
144 }
145 if (fd.revents & (POLLNVAL|POLLERR|POLLHUP)) {
146 return false;
147 }
148 return fd.revents & POLLOUT;
149 }
150
151 bool Connection::isOpen() {
152 if (sock > -1){
153 if (isWritable()) {
154 return getError() == 0;
155 } else if (open()) {
156 if (isWritable()) {
157 return getError() == 0;
158 }
159 }
160 }
161 return false;
162 }
163
164 bool Connection::open() {
165 if (connected) {
166 return true;
167 }
168 connect_again:
169 errno = 0;
170 if (0 == ::connect(sock, reinterpret_cast<sockaddr *>(&addr), size)) {
171 connected = true;
172 return true;
173 }
174
175 switch (errno) {
176 case EINTR:
177 goto connect_again;
178 case EISCONN:
179 connected = true;
180 [[fallthrough]];
181 case EAGAIN:
182 case EALREADY:
183 case EINPROGRESS:
184 return true;
185
186 default:
187 return false;
188 }
189 }
190
191 string Connection::error(const initializer_list<string> &args) {
192 stringstream ss;
193
194 for (const auto &arg : args) {
195 ss << arg;
196 }
197
198 return ss.str();
199 }