X-Git-Url: https://git.m6w6.name/?a=blobdiff_plain;f=libmemcached%2Fsasl.cc;h=902ccd872cf77ee7ede0903730a31c6c9bd96f3d;hb=331dfbb4650c36cda0a773876251e3eba0766175;hp=4f18ca2a62984bb6e652bf7471144779ccd2413d;hpb=b77f874c7d7ff386d01eeedb44c14d3003354bae;p=m6w6%2Flibmemcached diff --git a/libmemcached/sasl.cc b/libmemcached/sasl.cc index 4f18ca2a..902ccd87 100644 --- a/libmemcached/sasl.cc +++ b/libmemcached/sasl.cc @@ -41,6 +41,7 @@ #if defined(LIBMEMCACHED_WITH_SASL_SUPPORT) && LIBMEMCACHED_WITH_SASL_SUPPORT #include +#include void memcached_set_sasl_callbacks(memcached_st *ptr, const sasl_callback_t *callbacks) @@ -97,6 +98,28 @@ static memcached_return_t resolve_names(memcached_server_st& server, char *laddr return MEMCACHED_SUCCESS; } +extern "C" { + +static void sasl_shutdown_function() +{ + sasl_done(); +} + +static volatile int sasl_startup_state= SASL_OK; +pthread_mutex_t sasl_startup_state_LOCK= PTHREAD_MUTEX_INITIALIZER; +static pthread_once_t sasl_startup_once= PTHREAD_ONCE_INIT; +static void sasl_startup_function(void) +{ + sasl_startup_state= sasl_client_init(NULL); + + if (sasl_startup_state == SASL_OK) + { + (void)atexit(sasl_shutdown_function); + } +} + +} // extern "C" + memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *server) { if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) @@ -110,9 +133,10 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s } /* SANITY CHECK: SASL can only be used with the binary protocol */ - if (server->root->flags.binary_protocol == false) + if (memcached_is_binary(server->root) == false) { - return MEMCACHED_PROTOCOL_ERROR; + return memcached_set_error(*server, MEMCACHED_INVALID_ARGUMENTS, MEMCACHED_AT, + memcached_literal_param("memcached_sasl_authenticate_connection() is not supported via the ASCII protocol")); } /* Try to get the supported mech from the server. Servers without SASL @@ -158,18 +182,29 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s return rc; } - int ret; - if ((ret= sasl_client_init(NULL)) != SASL_OK) + int pthread_error; + if ((pthread_error= pthread_once(&sasl_startup_once, sasl_startup_function)) != 0) { - const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL); + return memcached_set_errno(*server, pthread_error, MEMCACHED_AT); + } + + (void)pthread_mutex_lock(&sasl_startup_state_LOCK); + if (sasl_startup_state != SASL_OK) + { + const char *sasl_error_msg= sasl_errstring(sasl_startup_state, NULL, NULL); return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, memcached_string_make_from_cstr(sasl_error_msg)); } + (void)pthread_mutex_unlock(&sasl_startup_state_LOCK); sasl_conn_t *conn; + int ret; if ((ret= sasl_client_new("memcached", server->hostname, laddr, raddr, server->root->sasl.callbacks, 0, &conn) ) != SASL_OK) { const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL); + + sasl_dispose(&conn); + return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, memcached_string_make_from_cstr(sasl_error_msg)); } @@ -181,6 +216,9 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s if (ret != SASL_OK and ret != SASL_CONTINUE) { const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL); + + sasl_dispose(&conn); + return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, memcached_string_make_from_cstr(sasl_error_msg)); } @@ -192,11 +230,11 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s do { /* send the packet */ - struct libmemcached_io_vector_st vector[]= + libmemcached_io_vector_st vector[]= { - { sizeof(request.bytes), request.bytes }, - { keylen, chosenmech }, - { len, data } + { request.bytes, sizeof(request.bytes) }, + { chosenmech, keylen }, + { data, len } }; if (memcached_io_writev(server, vector, 3, true) == -1) @@ -276,9 +314,15 @@ memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr, return MEMCACHED_INVALID_ARGUMENTS; } + memcached_return_t ret; + if (memcached_failed(ret= memcached_behavior_set(ptr, MEMCACHED_BEHAVIOR_BINARY_PROTOCOL, 1))) + { + return memcached_set_error(*ptr, ret, MEMCACHED_AT, memcached_literal_param("Unable change to binary protocol which is required for SASL.")); + } + memcached_destroy_sasl_auth_data(ptr); - sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(ptr, 4, sizeof(sasl_callback_t)); + sasl_callback_t *callbacks= libmemcached_xcalloc(ptr, 4, sasl_callback_t); size_t password_length= strlen(password); size_t username_length= strlen(username); char *name= (char *)libmemcached_malloc(ptr, username_length +1); @@ -398,7 +442,7 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st ++total; } - sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(clone, total +1, sizeof(sasl_callback_t)); + sasl_callback_t *callbacks= libmemcached_xcalloc(clone, total +1, sasl_callback_t); if (callbacks == NULL) { return MEMCACHED_MEMORY_ALLOCATION_FAILURE;