Rollup from build trunk.
[m6w6/libmemcached] / libmemcached / sasl.cc
index 4f18ca2a62984bb6e652bf7471144779ccd2413d..902ccd872cf77ee7ede0903730a31c6c9bd96f3d 100644 (file)
@@ -41,6 +41,7 @@
 #if defined(LIBMEMCACHED_WITH_SASL_SUPPORT) && LIBMEMCACHED_WITH_SASL_SUPPORT
 
 #include <sasl/sasl.h>
+#include <pthread.h>
 
 void memcached_set_sasl_callbacks(memcached_st *ptr,
                                   const sasl_callback_t *callbacks)
@@ -97,6 +98,28 @@ static memcached_return_t resolve_names(memcached_server_st& server, char *laddr
   return MEMCACHED_SUCCESS;
 }
 
+extern "C" {
+
+static void sasl_shutdown_function()
+{
+  sasl_done();
+}
+
+static volatile int sasl_startup_state= SASL_OK;
+pthread_mutex_t sasl_startup_state_LOCK= PTHREAD_MUTEX_INITIALIZER;
+static pthread_once_t sasl_startup_once= PTHREAD_ONCE_INIT;
+static void sasl_startup_function(void)
+{
+  sasl_startup_state= sasl_client_init(NULL);
+
+  if (sasl_startup_state == SASL_OK)
+  {
+    (void)atexit(sasl_shutdown_function);
+  }
+}
+
+} // extern "C"
+
 memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *server)
 {
   if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0)
@@ -110,9 +133,10 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s
   }
 
   /* SANITY CHECK: SASL can only be used with the binary protocol */
-  if (server->root->flags.binary_protocol == false)
+  if (memcached_is_binary(server->root) == false)
   {
-    return MEMCACHED_PROTOCOL_ERROR;
+    return  memcached_set_error(*server, MEMCACHED_INVALID_ARGUMENTS, MEMCACHED_AT,
+                                memcached_literal_param("memcached_sasl_authenticate_connection() is not supported via the ASCII protocol"));
   }
 
   /* Try to get the supported mech from the server. Servers without SASL
@@ -158,18 +182,29 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s
     return rc;
   }
 
-  int ret;
-  if ((ret= sasl_client_init(NULL)) != SASL_OK)
+  int pthread_error;
+  if ((pthread_error= pthread_once(&sasl_startup_once, sasl_startup_function)) != 0)
   {
-    const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL);
+    return memcached_set_errno(*server, pthread_error, MEMCACHED_AT);
+  }
+
+  (void)pthread_mutex_lock(&sasl_startup_state_LOCK);
+  if (sasl_startup_state != SASL_OK)
+  {
+    const char *sasl_error_msg= sasl_errstring(sasl_startup_state, NULL, NULL);
     return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, 
                                memcached_string_make_from_cstr(sasl_error_msg));
   }
+  (void)pthread_mutex_unlock(&sasl_startup_state_LOCK);
 
   sasl_conn_t *conn;
+  int ret;
   if ((ret= sasl_client_new("memcached", server->hostname, laddr, raddr, server->root->sasl.callbacks, 0, &conn) ) != SASL_OK)
   {
     const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL);
+
+    sasl_dispose(&conn);
+
     return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, 
                                memcached_string_make_from_cstr(sasl_error_msg));
   }
@@ -181,6 +216,9 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s
   if (ret != SASL_OK and ret != SASL_CONTINUE)
   {
     const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL);
+
+    sasl_dispose(&conn);
+
     return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, 
                                memcached_string_make_from_cstr(sasl_error_msg));
   }
@@ -192,11 +230,11 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s
   do {
     /* send the packet */
 
-    struct libmemcached_io_vector_st vector[]=
+    libmemcached_io_vector_st vector[]=
     {
-      { sizeof(request.bytes), request.bytes },
-      { keylen, chosenmech },
-      { len, data }
+      { request.bytes, sizeof(request.bytes) },
+      { chosenmech, keylen },
+      { data, len }
     };
 
     if (memcached_io_writev(server, vector, 3, true) == -1)
@@ -276,9 +314,15 @@ memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr,
     return MEMCACHED_INVALID_ARGUMENTS;
   }
 
+  memcached_return_t ret;
+  if (memcached_failed(ret= memcached_behavior_set(ptr, MEMCACHED_BEHAVIOR_BINARY_PROTOCOL, 1)))
+  {
+    return memcached_set_error(*ptr, ret, MEMCACHED_AT, memcached_literal_param("Unable change to binary protocol which is required for SASL."));
+  }
+
   memcached_destroy_sasl_auth_data(ptr);
 
-  sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(ptr, 4, sizeof(sasl_callback_t));
+  sasl_callback_t *callbacks= libmemcached_xcalloc(ptr, 4, sasl_callback_t);
   size_t password_length= strlen(password);
   size_t username_length= strlen(username);
   char *name= (char *)libmemcached_malloc(ptr, username_length +1);
@@ -398,7 +442,7 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const  memcached_st
     ++total;
   }
 
-  sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(clone, total +1, sizeof(sasl_callback_t));
+  sasl_callback_t *callbacks= libmemcached_xcalloc(clone, total +1, sasl_callback_t);
   if (callbacks == NULL)
   {
     return MEMCACHED_MEMORY_ALLOCATION_FAILURE;