fix #105 EINTR handled too defensively when polling
[awesomized/libmemcached] / src / libmemcached / connect.cc
index 3648a242ec87303fbcb54af056a236e33b3f8d67..90f1b21756a5930299ba405a446af4d79c0d7498 100644 (file)
 */
 
 #include "libmemcached/common.h"
+#include "p9y/poll.hpp"
 
 #include <cassert>
 
-#ifndef SOCK_CLOEXEC
-#  define SOCK_CLOEXEC 0
-#endif
-
-#ifndef SOCK_NONBLOCK
-#  define SOCK_NONBLOCK 0
-#endif
-
-#ifndef FD_CLOEXEC
-#  define FD_CLOEXEC 0
-#endif
-
-#ifndef SO_NOSIGPIPE
-#  define SO_NOSIGPIPE 0
-#endif
-
-#ifndef TCP_NODELAY
-#  define TCP_NODELAY 0
-#endif
-
-#ifndef TCP_KEEPIDLE
-#  define TCP_KEEPIDLE 0
-#endif
-
-static memcached_return_t connect_poll(memcached_instance_st *server, const int connection_error) {
-  struct pollfd fds[1];
-  fds[0].fd = server->fd;
-  fds[0].events = server->events();
-  fds[0].revents = 0;
-
-  size_t loop_max = 5;
-
-  if (server->root->poll_timeout == 0) {
-    return memcached_set_error(
-        *server, MEMCACHED_TIMEOUT, MEMCACHED_AT,
-        memcached_literal_param("The time to wait for a connection to be established was set to "
-                                "zero which produces a timeout to every call to poll()."));
-  }
-
-  while (--loop_max) // Should only loop on cases of ERESTART or EINTR
-  {
-    int number_of;
-    if ((number_of = poll(fds, 1, server->root->connect_timeout)) == -1) {
-      int local_errno = get_socket_errno(); // We cache in case closesocket() modifies errno
-      switch (local_errno) {
-#ifdef __linux__
-      case ERESTART:
-#endif
-      case EINTR:
-        continue;
-
-      case EFAULT:
-      case ENOMEM:
-        return memcached_set_error(*server, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT);
-
-      case EINVAL:
-        return memcached_set_error(
-            *server, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT,
-            memcached_literal_param(
-                "RLIMIT_NOFILE exceeded, or if OSX the timeout value was invalid"));
-
-      default: // This should not happen
-        break;
-      }
-
-      assert_msg(server->fd != INVALID_SOCKET, "poll() was passed an invalid file descriptor");
-      server->reset_socket();
-      server->state = MEMCACHED_SERVER_STATE_NEW;
-
-      return memcached_set_errno(*server, local_errno, MEMCACHED_AT);
-    }
-
-    if (number_of == 0) {
-      if (connection_error == EINPROGRESS) {
-        int err;
-        socklen_t len = sizeof(err);
-        if (getsockopt(server->fd, SOL_SOCKET, SO_ERROR, (char *) &err, &len) == -1) {
-          return memcached_set_errno(
-              *server, errno, MEMCACHED_AT,
-              memcached_literal_param(
-                  "getsockopt() error'ed while looking for error connect_poll(EINPROGRESS)"));
-        }
-
-        // If Zero, my hero, we just fail to a generic MEMCACHED_TIMEOUT error
-        if (err != 0) {
-          return memcached_set_errno(
-              *server, err, MEMCACHED_AT,
-              memcached_literal_param("getsockopt() found the error from poll() after connect() "
-                                      "returned EINPROGRESS."));
-        }
-      }
-
-      return memcached_set_error(*server, MEMCACHED_TIMEOUT, MEMCACHED_AT,
-                                 memcached_literal_param("(number_of == 0)"));
-    }
-
-    assert(number_of == 1);
-
-    if (fds[0].revents & POLLERR or fds[0].revents & POLLHUP or fds[0].revents & POLLNVAL) {
-      int err;
-      socklen_t len = sizeof(err);
-      if (getsockopt(fds[0].fd, SOL_SOCKET, SO_ERROR, (char *) &err, &len) == -1) {
-        return memcached_set_errno(
-            *server, errno, MEMCACHED_AT,
-            memcached_literal_param(
-                "getsockopt() errored while looking up error state from poll()"));
-      }
-
-      // We check the value to see what happened with the socket.
-      if (err == 0) // Should not happen
-      {
-        return MEMCACHED_SUCCESS;
-      }
-      errno = err;
-
-      return memcached_set_errno(
-          *server, err, MEMCACHED_AT,
-          memcached_literal_param("getsockopt() found the error from poll() during connect."));
-    }
-    assert(fds[0].revents & POLLOUT);
-
-    if (fds[0].revents & POLLOUT and connection_error == EINPROGRESS) {
-      int err;
-      socklen_t len = sizeof(err);
-      if (getsockopt(server->fd, SOL_SOCKET, SO_ERROR, (char *) &err, &len) == -1) {
-        return memcached_set_errno(*server, errno, MEMCACHED_AT);
-      }
-
-      if (err == 0) {
-        return MEMCACHED_SUCCESS;
-      }
-
-      return memcached_set_errno(
-          *server, err, MEMCACHED_AT,
-          memcached_literal_param(
-              "getsockopt() found the error from poll() after connect() returned EINPROGRESS."));
-    }
-
-    break; // We only have the loop setup for errno types that require restart
-  }
-
-  // This should only be possible from ERESTART or EINTR;
-  return memcached_set_errno(*server, connection_error, MEMCACHED_AT,
-                             memcached_literal_param("connect_poll() was exhausted"));
-}
 
 static memcached_return_t set_hostinfo(memcached_instance_st *server) {
   assert(server->type != MEMCACHED_CONNECTION_UNIX_SOCKET);
@@ -170,7 +26,7 @@ static memcached_return_t set_hostinfo(memcached_instance_st *server) {
   char str_port[MEMCACHED_NI_MAXSERV] = {0};
   errno = 0;
   int length = snprintf(str_port, MEMCACHED_NI_MAXSERV, "%u", uint32_t(server->port()));
-  if (length >= MEMCACHED_NI_MAXSERV or length <= 0 or errno != 0) {
+  if (length >= MEMCACHED_NI_MAXSERV or length <= 0 or errno) {
     return memcached_set_error(*server, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT,
                                memcached_literal_param("snprintf(NI_MAXSERV)"));
   }
@@ -263,7 +119,7 @@ static bool set_socket_options(memcached_instance_st *server) {
 #ifdef HAVE_FCNTL
   // If SOCK_CLOEXEC exists then we don't need to call the following
   if (SOCK_CLOEXEC == 0) {
-    if (FD_CLOEXEC != 0) {
+    if (FD_CLOEXEC) {
       int flags;
       do {
         flags = fcntl(server->fd, F_GETFD, 0);
@@ -397,11 +253,11 @@ static memcached_return_t unix_socket_connect(memcached_instance_st *server) {
 
   do {
     int type = SOCK_STREAM;
-    if (SOCK_CLOEXEC != 0) {
+    if (SOCK_CLOEXEC) {
       type |= SOCK_CLOEXEC;
     }
 
-    if (SOCK_NONBLOCK != 0) {
+    if (SOCK_NONBLOCK) {
       type |= SOCK_NONBLOCK;
     }
 
@@ -482,11 +338,11 @@ static memcached_return_t network_connect(memcached_instance_st *server) {
   /* Create the socket */
   while (server->address_info_next and server->fd == INVALID_SOCKET) {
     int type = server->address_info_next->ai_socktype;
-    if (SOCK_CLOEXEC != 0) {
+    if (SOCK_CLOEXEC) {
       type |= SOCK_CLOEXEC;
     }
 
-    if (SOCK_NONBLOCK != 0) {
+    if (SOCK_NONBLOCK) {
       type |= SOCK_NONBLOCK;
     }
 
@@ -521,12 +377,13 @@ static memcached_return_t network_connect(memcached_instance_st *server) {
 #if EWOULDBLOCK != EAGAIN
     case EWOULDBLOCK:
 #endif
+    case EAGAIN:
     case EINPROGRESS: // nonblocking mode - first return
     case EALREADY:    // nonblocking mode - subsequent returns
     {
       server->events(POLLOUT);
       server->state = MEMCACHED_SERVER_STATE_IN_PROGRESS;
-      memcached_return_t rc = connect_poll(server, local_error);
+      memcached_return_t rc = memcached_io_poll(server, IO_POLL_CONNECT, local_error);
 
       if (memcached_success(rc)) {
         server->state = MEMCACHED_SERVER_STATE_CONNECTED;