Merge in fixes for SASL.
[m6w6/libmemcached] / libmemcached / connect.cc
index 35053d60cb3c34d1931b2410d9eeebc294977615..3c12339c4a7f16315fc134e88aeff8d5840c63a8 100644 (file)
@@ -37,7 +37,6 @@
 
 
 #include <libmemcached/common.h>
-#include <cassert>
 #include <ctime>
 #include <sys/time.h>
 
@@ -47,13 +46,16 @@ static memcached_return_t connect_poll(memcached_server_st *ptr)
   fds[0].fd = ptr->fd;
   fds[0].events = POLLOUT;
 
-  int error;
   size_t loop_max= 5;
 
-  while (--loop_max) // Should only loop on cases of ERESTART or EINTR
+  if (ptr->root->poll_timeout == 0)
   {
-    error= poll(fds, 1, ptr->root->connect_timeout);
+    return memcached_set_error(*ptr, MEMCACHED_TIMEOUT, MEMCACHED_AT);
+  }
 
+  while (--loop_max) // Should only loop on cases of ERESTART or EINTR
+  {
+    int error= poll(fds, 1, ptr->root->connect_timeout);
     switch (error)
     {
     case 1:
@@ -67,17 +69,15 @@ static memcached_return_t connect_poll(memcached_server_st *ptr)
         {
           return MEMCACHED_SUCCESS;
         }
-        else
-        {
-          ptr->cached_errno= errno;
 
-          return MEMCACHED_ERRNO;
-        }
+        return memcached_set_errno(*ptr, err, MEMCACHED_AT);
       }
     case 0:
-      return MEMCACHED_TIMEOUT;
+      {
+        return memcached_set_error(*ptr, MEMCACHED_TIMEOUT, MEMCACHED_AT);
+      }
+
     default: // A real error occurred and we need to completely bail
-      WATCHPOINT_ERRNO(get_socket_errno());
       switch (get_socket_errno())
       {
 #ifdef TARGET_OS_LINUX
@@ -85,38 +85,43 @@ static memcached_return_t connect_poll(memcached_server_st *ptr)
 #endif
       case EINTR:
         continue;
-      default:
+
+      case EFAULT:
+      case ENOMEM:
+        return memcached_set_error(*ptr, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT);
+
+      case EINVAL:
+        return memcached_set_error(*ptr, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT, memcached_literal_param("RLIMIT_NOFILE exceeded, or if OSX the timeout value was invalid"));
+
+      default: // This should not happen
         if (fds[0].revents & POLLERR)
         {
           int err;
           socklen_t len= sizeof (err);
           (void)getsockopt(ptr->fd, SOL_SOCKET, SO_ERROR, &err, &len);
-          ptr->cached_errno= (err == 0) ? get_socket_errno() : err;
+          memcached_set_errno(*ptr, (err == 0) ? get_socket_errno() : err, MEMCACHED_AT);
         }
         else
         {
-          ptr->cached_errno= get_socket_errno();
+          memcached_set_errno(*ptr, get_socket_errno(), MEMCACHED_AT);
         }
 
+        assert_msg(ptr->fd != INVALID_SOCKET, "poll() was passed an invalid file descriptor");
         (void)closesocket(ptr->fd);
         ptr->fd= INVALID_SOCKET;
+        ptr->state= MEMCACHED_SERVER_STATE_NEW;
 
-        return MEMCACHED_ERRNO;
+        return memcached_set_errno(*ptr, get_socket_errno(), MEMCACHED_AT);
       }
     }
   }
 
   // This should only be possible from ERESTART or EINTR;
-  ptr->cached_errno= get_socket_errno();
-
-  return MEMCACHED_ERRNO;
+  return memcached_set_errno(*ptr, get_socket_errno(), MEMCACHED_AT);
 }
 
 static memcached_return_t set_hostinfo(memcached_server_st *server)
 {
-  char str_port[NI_MAXSERV];
-
-  assert(! server->address_info); // We cover the case where a programming mistake has been made.
   if (server->address_info)
   {
     freeaddrinfo(server->address_info);
@@ -124,9 +129,12 @@ static memcached_return_t set_hostinfo(memcached_server_st *server)
     server->address_info_next= NULL;
   }
 
+  char str_port[NI_MAXSERV];
   int length= snprintf(str_port, NI_MAXSERV, "%u", (uint32_t)server->port);
-  if (length >= NI_MAXSERV || length < 0)
+  if (length >= NI_MAXSERV or length < 0)
+  {
     return MEMCACHED_FAILURE;
+  }
 
   struct addrinfo hints;
   memset(&hints, 0, sizeof(struct addrinfo));
@@ -161,27 +169,26 @@ static memcached_return_t set_hostinfo(memcached_server_st *server)
     return memcached_set_error(*server, MEMCACHED_INVALID_ARGUMENTS, MEMCACHED_AT, memcached_literal_param("getaddrinfo(EAI_BADFLAGS)"));
 
   case EAI_MEMORY:
-    return memcached_set_error(*server, MEMCACHED_ERRNO, MEMCACHED_AT, memcached_literal_param("getaddrinfo(EAI_MEMORY)"));
+    return memcached_set_error(*server, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT, memcached_literal_param("getaddrinfo(EAI_MEMORY)"));
 
   default:
     {
-      WATCHPOINT_STRING(server->hostname);
-      WATCHPOINT_STRING(gai_strerror(e));
       return memcached_set_error(*server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT, memcached_string_make_from_cstr(gai_strerror(errcode)));
     }
   }
   server->address_info_next= server->address_info;
+  server->state= MEMCACHED_SERVER_STATE_ADDRINFO;
 
   return MEMCACHED_SUCCESS;
 }
 
-static inline memcached_return_t set_socket_nonblocking(memcached_server_st *ptr)
+static inline void set_socket_nonblocking(memcached_server_st *ptr)
 {
 #ifdef WIN32
   u_long arg = 1;
   if (ioctlsocket(ptr->fd, FIONBIO, &arg) == SOCKET_ERROR)
   {
-    return memcached_set_errno(*ptr, get_socket_errno(), NULL);
+    memcached_set_errno(*ptr, get_socket_errno(), NULL);
   }
 #else
   int flags;
@@ -189,12 +196,11 @@ static inline memcached_return_t set_socket_nonblocking(memcached_server_st *ptr
   do
   {
     flags= fcntl(ptr->fd, F_GETFL, 0);
-  }
-  while (flags == -1 && (errno == EINTR || errno == EAGAIN));
+  } while (flags == -1 && (errno == EINTR || errno == EAGAIN));
 
-  unlikely (flags == -1)
+  if (flags == -1)
   {
-    return memcached_set_errno(*ptr, errno, NULL);
+    memcached_set_errno(*ptr, errno, NULL);
   }
   else if ((flags & O_NONBLOCK) == 0)
   {
@@ -203,24 +209,24 @@ static inline memcached_return_t set_socket_nonblocking(memcached_server_st *ptr
     do
     {
       rval= fcntl(ptr->fd, F_SETFL, flags | O_NONBLOCK);
-    }
-    while (rval == -1 && (errno == EINTR || errno == EAGAIN));
+    } while (rval == -1 && (errno == EINTR || errno == EAGAIN));
 
     unlikely (rval == -1)
     {
-      return memcached_set_errno(*ptr, errno, NULL);
+      memcached_set_errno(*ptr, errno, NULL);
     }
   }
 #endif
-  return MEMCACHED_SUCCESS;
 }
 
-static memcached_return_t set_socket_options(memcached_server_st *ptr)
+static void set_socket_options(memcached_server_st *ptr)
 {
-  WATCHPOINT_ASSERT(ptr->fd != -1);
+  assert_msg(ptr->fd != -1, "invalid socket was passed to set_socket_options()");
 
   if (ptr->type == MEMCACHED_CONNECTION_UDP)
-    return MEMCACHED_SUCCESS;
+  {
+    return;
+  }
 
 #ifdef HAVE_SNDTIMEO
   if (ptr->root->snd_timeout)
@@ -234,8 +240,6 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, SOL_SOCKET, SO_SNDTIMEO,
                       &waittime, (socklen_t)sizeof(struct timeval));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 #endif
 
@@ -251,15 +255,13 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, SOL_SOCKET, SO_RCVTIMEO,
                       &waittime, (socklen_t)sizeof(struct timeval));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 #endif
 
 
 #if defined(__MACH__) && defined(__APPLE__) || defined(__FreeBSD__)
   {
-    int set = 1;
+    int set= 1;
     int error= setsockopt(ptr->fd, SOL_SOCKET, SO_NOSIGPIPE, (void *)&set, sizeof(int));
 
     // This is not considered a fatal error
@@ -281,8 +283,6 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, SOL_SOCKET, SO_LINGER,
                       &linger, (socklen_t)sizeof(struct linger));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 
   if (ptr->root->flags.tcp_nodelay)
@@ -293,8 +293,6 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, IPPROTO_TCP, TCP_NODELAY,
                       &flag, (socklen_t)sizeof(int));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 
   if (ptr->root->flags.tcp_keepalive)
@@ -305,8 +303,6 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, SOL_SOCKET, SO_KEEPALIVE,
                       &flag, (socklen_t)sizeof(int));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 
 #ifdef TCP_KEEPIDLE
@@ -317,8 +313,6 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, IPPROTO_TCP, TCP_KEEPIDLE,
                       &ptr->root->tcp_keepidle, (socklen_t)sizeof(int));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 #endif
 
@@ -329,8 +323,6 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, SOL_SOCKET, SO_SNDBUF,
                       &ptr->root->send_size, (socklen_t)sizeof(int));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 
   if (ptr->root->recv_size > 0)
@@ -340,13 +332,11 @@ static memcached_return_t set_socket_options(memcached_server_st *ptr)
     error= setsockopt(ptr->fd, SOL_SOCKET, SO_RCVBUF,
                       &ptr->root->recv_size, (socklen_t)sizeof(int));
     WATCHPOINT_ASSERT(error == 0);
-    if (error)
-      return MEMCACHED_FAILURE;
   }
 
 
   /* libmemcached will always use nonblocking IO to avoid write deadlocks */
-  return set_socket_nonblocking(ptr);
+  set_socket_nonblocking(ptr);
 }
 
 static memcached_return_t unix_socket_connect(memcached_server_st *ptr)
@@ -356,7 +346,8 @@ static memcached_return_t unix_socket_connect(memcached_server_st *ptr)
 
   if ((ptr->fd= socket(AF_UNIX, SOCK_STREAM, 0)) < 0)
   {
-    return memcached_set_errno(*ptr, errno, NULL);
+    memcached_set_errno(*ptr, errno, NULL);
+    return MEMCACHED_CONNECTION_FAILURE;
   }
 
   struct sockaddr_un servAddr;
@@ -365,25 +356,30 @@ static memcached_return_t unix_socket_connect(memcached_server_st *ptr)
   servAddr.sun_family= AF_UNIX;
   strncpy(servAddr.sun_path, ptr->hostname, sizeof(servAddr.sun_path)); /* Copy filename */
 
-test_connect:
-  if (connect(ptr->fd,
-              (struct sockaddr *)&servAddr,
-              sizeof(servAddr)) < 0)
-  {
-    switch (errno)
+  do {
+    if (connect(ptr->fd, (struct sockaddr *)&servAddr, sizeof(servAddr)) < 0)
     {
-    case EINPROGRESS:
-    case EALREADY:
-    case EINTR:
-      goto test_connect;
-    case EISCONN: /* We were spinning waiting on connect */
-      break;
-    default:
-      WATCHPOINT_ERRNO(errno);
-      ptr->cached_errno= errno;
-      return MEMCACHED_ERRNO;
+      switch (errno)
+      {
+      case EINPROGRESS:
+      case EALREADY:
+      case EINTR:
+        continue;
+
+      case EISCONN: /* We were spinning waiting on connect */
+        {
+          WATCHPOINT_ASSERT(0); // Programmer error
+          break;
+        }
+
+      default:
+        WATCHPOINT_ERRNO(errno);
+        memcached_set_errno(*ptr, errno, MEMCACHED_AT);
+        return MEMCACHED_CONNECTION_FAILURE;
+      }
     }
-  }
+  } while (0);
+  ptr->state= MEMCACHED_SERVER_STATE_CONNECTED;
 
   WATCHPOINT_ASSERT(ptr->fd != INVALID_SOCKET);
 
@@ -403,12 +399,15 @@ static memcached_return_t network_connect(memcached_server_st *ptr)
 
   if (not ptr->address_info)
   {
+    WATCHPOINT_ASSERT(ptr->state == MEMCACHED_SERVER_STATE_NEW);
     memcached_return_t rc;
     uint32_t counter= 5;
     while (--counter)
     {
       if ((rc= set_hostinfo(ptr)) != MEMCACHED_TIMEOUT)
+      {
         break;
+      }
 
 #ifndef WIN32
       struct timespec dream, rem;
@@ -441,83 +440,106 @@ static memcached_return_t network_connect(memcached_server_st *ptr)
       return memcached_set_errno(*ptr, get_socket_errno(), NULL);
     }
 
-    (void)set_socket_options(ptr);
+    set_socket_options(ptr);
 
     /* connect to server */
     if ((connect(ptr->fd, ptr->address_info_next->ai_addr, ptr->address_info_next->ai_addrlen) != SOCKET_ERROR))
     {
-      break; // Success
+      ptr->state= MEMCACHED_SERVER_STATE_CONNECTED;
+      return MEMCACHED_SUCCESS;
     }
 
     /* An error occurred */
-    ptr->cached_errno= get_socket_errno();
-    switch (ptr->cached_errno) 
+    switch (get_socket_errno())
     {
+    case ETIMEDOUT:
+      timeout_error_occured= true;
+      break;
+
     case EWOULDBLOCK:
     case EINPROGRESS: // nonblocking mode - first return
     case EALREADY: // nonblocking mode - subsequent returns
       {
-        memcached_return_t rc;
-        rc= connect_poll(ptr);
+        ptr->state= MEMCACHED_SERVER_STATE_IN_PROGRESS;
+        memcached_return_t rc= connect_poll(ptr);
+
+        if (memcached_success(rc))
+        {
+          ptr->state= MEMCACHED_SERVER_STATE_CONNECTED;
+          return MEMCACHED_SUCCESS;
+        }
 
+        // A timeout here is treated as an error, we will not retry
         if (rc == MEMCACHED_TIMEOUT)
+        {
           timeout_error_occured= true;
-
-        if (rc == MEMCACHED_SUCCESS)
-          break;
+        }
       }
+      break;
 
     case EISCONN: // we are connected :-)
+      WATCHPOINT_ASSERT(0); // This is a programmer's error
       break;
 
     case EINTR: // Special case, we retry ai_addr
+      WATCHPOINT_ASSERT(ptr->fd != INVALID_SOCKET);
       (void)closesocket(ptr->fd);
       ptr->fd= INVALID_SOCKET;
       continue;
 
     default:
+      break;
+    }
+
+    WATCHPOINT_ASSERT(ptr->fd != INVALID_SOCKET);
+    (void)closesocket(ptr->fd);
+    ptr->fd= INVALID_SOCKET;
+    ptr->address_info_next= ptr->address_info_next->ai_next;
+  }
+
+  WATCHPOINT_ASSERT(ptr->fd == INVALID_SOCKET);
+
+  if (timeout_error_occured)
+  {
+    if (ptr->fd != INVALID_SOCKET)
+    {
       (void)closesocket(ptr->fd);
       ptr->fd= INVALID_SOCKET;
-      ptr->address_info_next= ptr->address_info_next->ai_next;
-      break;
     }
   }
 
-  if (ptr->fd == INVALID_SOCKET)
+  WATCHPOINT_STRING("Never got a good file descriptor");
+  /* Failed to connect. schedule next retry */
+  if (ptr->root->retry_timeout)
   {
-    WATCHPOINT_STRING("Never got a good file descriptor");
+    struct timeval next_time;
 
-    /* Failed to connect. schedule next retry */
-    if (ptr->root->retry_timeout)
+    if (gettimeofday(&next_time, NULL) == 0)
     {
-      struct timeval next_time;
-
-      if (gettimeofday(&next_time, NULL) == 0)
-        ptr->next_retry= next_time.tv_sec + ptr->root->retry_timeout;
+      ptr->next_retry= next_time.tv_sec + ptr->root->retry_timeout;
     }
+  }
+  
+  if (memcached_has_current_error(*ptr))
+  {
+    return memcached_server_error_return(ptr);
+  }
 
-    if (timeout_error_occured)
-      return MEMCACHED_TIMEOUT;
-
-    return MEMCACHED_ERRNO; /* The last error should be from connect() */
+  if (timeout_error_occured and ptr->state < MEMCACHED_SERVER_STATE_IN_PROGRESS)
+  {
+    return memcached_set_error(*ptr, MEMCACHED_TIMEOUT, MEMCACHED_AT);
   }
 
-  return MEMCACHED_SUCCESS; /* The last error should be from connect() */
+  return memcached_set_error(*ptr, MEMCACHED_CONNECTION_FAILURE, MEMCACHED_AT); /* The last error should be from connect() */
 }
 
-void set_last_disconnected_host(memcached_server_write_instance_st ptr)
+void set_last_disconnected_host(memcached_server_write_instance_st self)
 {
   // const_cast
-  memcached_st *root= (memcached_st *)ptr->root;
+  memcached_st *root= (memcached_st *)self->root;
 
-#if 0
-  WATCHPOINT_STRING(ptr->hostname);
-  WATCHPOINT_NUMBER(ptr->port);
-  WATCHPOINT_ERRNO(ptr->cached_errno);
-#endif
-  if (root->last_disconnected_server)
-    memcached_server_free(root->last_disconnected_server);
-  root->last_disconnected_server= memcached_server_clone(NULL, ptr);
+  memcached_server_free(root->last_disconnected_server);
+  root->last_disconnected_server= memcached_server_clone(NULL, self);
 }
 
 memcached_return_t memcached_connect(memcached_server_write_instance_st ptr)
@@ -525,7 +547,9 @@ memcached_return_t memcached_connect(memcached_server_write_instance_st ptr)
   memcached_return_t rc= MEMCACHED_NO_SERVERS;
 
   if (ptr->fd != INVALID_SOCKET)
+  {
     return MEMCACHED_SUCCESS;
+  }
 
   LIBMEMCACHED_MEMCACHED_CONNECT_START();
 
@@ -543,7 +567,7 @@ memcached_return_t memcached_connect(memcached_server_write_instance_st ptr)
     {
       set_last_disconnected_host(ptr);
 
-      return MEMCACHED_SERVER_MARKED_DEAD;
+      return memcached_set_error(*ptr, MEMCACHED_SERVER_MARKED_DEAD, MEMCACHED_AT);
     }
   }
 
@@ -560,48 +584,49 @@ memcached_return_t memcached_connect(memcached_server_write_instance_st ptr)
       run_distribution((memcached_st *)ptr->root);
     }
 
-    return MEMCACHED_SERVER_MARKED_DEAD;
+    return memcached_set_error(*ptr, MEMCACHED_SERVER_MARKED_DEAD, MEMCACHED_AT);
   }
 
   /* We need to clean up the multi startup piece */
   switch (ptr->type)
   {
-  case MEMCACHED_CONNECTION_UNKNOWN:
-    WATCHPOINT_ASSERT(0);
-    rc= MEMCACHED_NOT_SUPPORTED;
-    break;
   case MEMCACHED_CONNECTION_UDP:
   case MEMCACHED_CONNECTION_TCP:
     rc= network_connect(ptr);
-#ifdef LIBMEMCACHED_WITH_SASL_SUPPORT
-    if (ptr->fd != INVALID_SOCKET && ptr->root->sasl.callbacks)
+    if (LIBMEMCACHED_WITH_SASL_SUPPORT)
     {
-      rc= memcached_sasl_authenticate_connection(ptr);
-      if (memcached_failed(rc))
+      if (ptr->fd != INVALID_SOCKET and ptr->root->sasl.callbacks)
       {
-        (void)closesocket(ptr->fd);
-        ptr->fd= INVALID_SOCKET;
+        rc= memcached_sasl_authenticate_connection(ptr);
+        if (memcached_failed(rc) and ptr->fd != INVALID_SOCKET)
+        {
+          WATCHPOINT_ASSERT(ptr->fd != INVALID_SOCKET);
+          (void)closesocket(ptr->fd);
+          ptr->fd= INVALID_SOCKET;
+        }
       }
     }
-#endif
     break;
+
   case MEMCACHED_CONNECTION_UNIX_SOCKET:
     rc= unix_socket_connect(ptr);
     break;
-  case MEMCACHED_CONNECTION_MAX:
-  default:
-    WATCHPOINT_ASSERT(0);
   }
 
-  if (rc == MEMCACHED_SUCCESS)
+  if (memcached_success(rc))
   {
     ptr->server_failure_counter= 0;
     ptr->next_retry= 0;
   }
+  else if (memcached_has_current_error(*ptr))
+  {
+    ptr->server_failure_counter++;
+    set_last_disconnected_host(ptr);
+  }
   else
   {
+    memcached_set_error(*ptr, rc, MEMCACHED_AT);
     ptr->server_failure_counter++;
-
     set_last_disconnected_host(ptr);
   }