X-Git-Url: https://git.m6w6.name/?a=blobdiff_plain;f=libmemcached%2Fsasl.cc;h=5ec8bee2fd454cd09184c2aabcba920a5c6bf94c;hb=24d604ebd655a7e3afe584d3dc4a25f5eca372c3;hp=f1de41eaec4fe30f86416da7714f625e169c0ba6;hpb=27ee6d2aea6210eaca004475600aba78b7170883;p=awesomized%2Flibmemcached diff --git a/libmemcached/sasl.cc b/libmemcached/sasl.cc index f1de41ea..5ec8bee2 100644 --- a/libmemcached/sasl.cc +++ b/libmemcached/sasl.cc @@ -2,7 +2,7 @@ * * Libmemcached library * - * Copyright (C) 2011 Data Differential, http://datadifferential.com/ + * Copyright (C) 2011-2012 Data Differential, http://datadifferential.com/ * Copyright (C) 2006-2009 Brian Aker All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -35,24 +35,39 @@ * */ -#include +#include "libmemcached/common.h" #include #if defined(LIBMEMCACHED_WITH_SASL_SUPPORT) && LIBMEMCACHED_WITH_SASL_SUPPORT +#if defined(HAVE_LIBSASL) && HAVE_LIBSASL #include +#endif + +#define CAST_SASL_CB(cb) reinterpret_cast(reinterpret_cast(cb)) + #include -void memcached_set_sasl_callbacks(memcached_st *ptr, +void memcached_set_sasl_callbacks(memcached_st *shell, const sasl_callback_t *callbacks) { - ptr->sasl.callbacks= const_cast(callbacks); - ptr->sasl.is_allocated= false; + Memcached* self= memcached2Memcached(shell); + if (self) + { + self->sasl.callbacks= const_cast(callbacks); + self->sasl.is_allocated= false; + } } -sasl_callback_t *memcached_get_sasl_callbacks(memcached_st *ptr) +sasl_callback_t *memcached_get_sasl_callbacks(memcached_st *shell) { - return ptr->sasl.callbacks; + Memcached* self= memcached2Memcached(shell); + if (self) + { + return self->sasl.callbacks; + } + + return NULL; } /** @@ -62,21 +77,21 @@ sasl_callback_t *memcached_get_sasl_callbacks(memcached_st *ptr) * @param raddr remote address (out) * @return true on success false otherwise (errno contains more info) */ -static memcached_return_t resolve_names(memcached_server_st& server, char *laddr, size_t laddr_length, char *raddr, size_t raddr_length) +static memcached_return_t resolve_names(memcached_instance_st& server, char *laddr, size_t laddr_length, char *raddr, size_t raddr_length) { - char host[NI_MAXHOST]; - char port[NI_MAXSERV]; + char host[MEMCACHED_NI_MAXHOST]; + char port[MEMCACHED_NI_MAXSERV]; struct sockaddr_storage saddr; socklen_t salen= sizeof(saddr); if (getsockname(server.fd, (struct sockaddr *)&saddr, &salen) < 0) { - return memcached_set_errno(server, MEMCACHED_ERRNO, MEMCACHED_AT); + return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT); } if (getnameinfo((struct sockaddr *)&saddr, salen, host, sizeof(host), port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV) < 0) { - return MEMCACHED_HOST_LOOKUP_FAILURE; + return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT); } (void)snprintf(laddr, laddr_length, "%s;%s", host, port); @@ -84,7 +99,7 @@ static memcached_return_t resolve_names(memcached_server_st& server, char *laddr if (getpeername(server.fd, (struct sockaddr *)&saddr, &salen) < 0) { - return memcached_set_errno(server, MEMCACHED_ERRNO, MEMCACHED_AT); + return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT); } if (getnameinfo((struct sockaddr *)&saddr, salen, host, sizeof(host), @@ -98,12 +113,15 @@ static memcached_return_t resolve_names(memcached_server_st& server, char *laddr return MEMCACHED_SUCCESS; } +extern "C" { + static void sasl_shutdown_function() { sasl_done(); } -static int sasl_startup_state= SASL_OK; +static volatile int sasl_startup_state= SASL_OK; +pthread_mutex_t sasl_startup_state_LOCK= PTHREAD_MUTEX_INITIALIZER; static pthread_once_t sasl_startup_once= PTHREAD_ONCE_INIT; static void sasl_startup_function(void) { @@ -115,7 +133,9 @@ static void sasl_startup_function(void) } } -memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *server) +} // extern "C" + +memcached_return_t memcached_sasl_authenticate_connection(memcached_instance_st* server) { if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) { @@ -128,9 +148,10 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s } /* SANITY CHECK: SASL can only be used with the binary protocol */ - if (server->root->flags.binary_protocol == false) + if (memcached_is_binary(server->root) == false) { - return MEMCACHED_PROTOCOL_ERROR; + return memcached_set_error(*server, MEMCACHED_INVALID_ARGUMENTS, MEMCACHED_AT, + memcached_literal_param("memcached_sasl_authenticate_connection() is not supported via the ASCII protocol")); } /* Try to get the supported mech from the server. Servers without SASL @@ -138,14 +159,16 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s * as authenticated */ protocol_binary_request_no_extras request= { }; - request.message.header.request.magic= PROTOCOL_BINARY_REQ; + + initialize_binary_request(server, request.message.header); + request.message.header.request.opcode= PROTOCOL_BINARY_CMD_SASL_LIST_MECHS; - if (memcached_io_write(server, request.bytes, - sizeof(request.bytes), 1) != sizeof(request.bytes)) + if (memcached_io_write(server, request.bytes, sizeof(request.bytes), true) != sizeof(request.bytes)) { return MEMCACHED_WRITE_FAILURE; } + assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket"); memcached_server_response_increment(server); @@ -166,10 +189,11 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s return rc; } + assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket"); /* set ip addresses */ - char laddr[NI_MAXHOST + NI_MAXSERV]; - char raddr[NI_MAXHOST + NI_MAXSERV]; + char laddr[MEMCACHED_NI_MAXHOST + MEMCACHED_NI_MAXSERV]; + char raddr[MEMCACHED_NI_MAXHOST + MEMCACHED_NI_MAXSERV]; if (memcached_failed(rc= resolve_names(*server, laddr, sizeof(laddr), raddr, sizeof(raddr)))) { @@ -182,16 +206,18 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s return memcached_set_errno(*server, pthread_error, MEMCACHED_AT); } + (void)pthread_mutex_lock(&sasl_startup_state_LOCK); if (sasl_startup_state != SASL_OK) { const char *sasl_error_msg= sasl_errstring(sasl_startup_state, NULL, NULL); return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, memcached_string_make_from_cstr(sasl_error_msg)); } + (void)pthread_mutex_unlock(&sasl_startup_state_LOCK); sasl_conn_t *conn; int ret; - if ((ret= sasl_client_new("memcached", server->hostname, laddr, raddr, server->root->sasl.callbacks, 0, &conn) ) != SASL_OK) + if ((ret= sasl_client_new("memcached", server->_hostname, laddr, raddr, server->root->sasl.callbacks, 0, &conn) ) != SASL_OK) { const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL); @@ -222,26 +248,30 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s do { /* send the packet */ - struct libmemcached_io_vector_st vector[]= + libmemcached_io_vector_st vector[]= { - { sizeof(request.bytes), request.bytes }, - { keylen, chosenmech }, - { len, data } + { request.bytes, sizeof(request.bytes) }, + { chosenmech, keylen }, + { data, len } }; - if (memcached_io_writev(server, vector, 3, true) == -1) + assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket"); + if (memcached_io_writev(server, vector, 3, true) == false) { rc= MEMCACHED_WRITE_FAILURE; break; } + assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket"); memcached_server_response_increment(server); /* read the response */ + assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket"); rc= memcached_response(server, NULL, 0, NULL); if (rc != MEMCACHED_AUTH_CONTINUE) { break; } + assert_msg(server->fd != INVALID_SOCKET, "Programmer error, invalid socket"); ret= sasl_client_step(conn, memcached_result_value(&server->root->result), (unsigned int)memcached_result_length(&server->root->result), @@ -292,10 +322,11 @@ static int get_password(sasl_conn_t *conn, void *context, int id, return SASL_OK; } -memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr, +memcached_return_t memcached_set_sasl_auth_data(memcached_st *shell, const char *username, const char *password) { + Memcached* ptr= memcached2Memcached(shell); if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) { return MEMCACHED_NOT_SUPPORTED; @@ -314,7 +345,7 @@ memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr, memcached_destroy_sasl_auth_data(ptr); - sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(ptr, 4, sizeof(sasl_callback_t)); + sasl_callback_t *callbacks= libmemcached_xcalloc(ptr, 4, sasl_callback_t); size_t password_length= strlen(password); size_t username_length= strlen(username); char *name= (char *)libmemcached_malloc(ptr, username_length +1); @@ -333,13 +364,13 @@ memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr, secret->data[password_length]= 0; callbacks[0].id= SASL_CB_USER; - callbacks[0].proc= (int (*)())get_username; + callbacks[0].proc= CAST_SASL_CB(get_username); callbacks[0].context= strncpy(name, username, username_length +1); callbacks[1].id= SASL_CB_AUTHNAME; - callbacks[1].proc= (int (*)())get_username; + callbacks[1].proc= CAST_SASL_CB(get_username); callbacks[1].context= name; callbacks[2].id= SASL_CB_PASS; - callbacks[2].proc= (int (*)())get_password; + callbacks[2].proc= CAST_SASL_CB(get_password); callbacks[2].context= secret; callbacks[3].id= SASL_CB_LIST_END; @@ -349,13 +380,14 @@ memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr, return MEMCACHED_SUCCESS; } -memcached_return_t memcached_destroy_sasl_auth_data(memcached_st *ptr) +memcached_return_t memcached_destroy_sasl_auth_data(memcached_st *shell) { if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) { return MEMCACHED_NOT_SUPPORTED; } + Memcached* ptr= memcached2Memcached(shell); if (ptr == NULL) { return MEMCACHED_INVALID_ARGUMENTS; @@ -398,11 +430,11 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st /* Hopefully we are using our own callback mechanisms.. */ if (source->sasl.callbacks[0].id == SASL_CB_USER && - source->sasl.callbacks[0].proc == (int (*)())get_username && + source->sasl.callbacks[0].proc == CAST_SASL_CB(get_username) && source->sasl.callbacks[1].id == SASL_CB_AUTHNAME && - source->sasl.callbacks[1].proc == (int (*)())get_username && + source->sasl.callbacks[1].proc == CAST_SASL_CB(get_username) && source->sasl.callbacks[2].id == SASL_CB_PASS && - source->sasl.callbacks[2].proc == (int (*)())get_password && + source->sasl.callbacks[2].proc == CAST_SASL_CB(get_password) && source->sasl.callbacks[3].id == SASL_CB_LIST_END) { sasl_secret_t *secret= (sasl_secret_t *)source->sasl.callbacks[2].context; @@ -416,7 +448,7 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st * into the list, but if we don't know the ID we don't know how to handle * the context... */ - size_t total= 0; + ptrdiff_t total= 0; while (source->sasl.callbacks[total].id != SASL_CB_LIST_END) { @@ -434,7 +466,7 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st ++total; } - sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(clone, total +1, sizeof(sasl_callback_t)); + sasl_callback_t *callbacks= libmemcached_xcalloc(clone, total +1, sasl_callback_t); if (callbacks == NULL) { return MEMCACHED_MEMORY_ALLOCATION_FAILURE; @@ -442,7 +474,7 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st memcpy(callbacks, source->sasl.callbacks, (total + 1) * sizeof(sasl_callback_t)); /* Now update the context... */ - for (size_t x= 0; x < total; ++x) + for (ptrdiff_t x= 0; x < total; ++x) { if (callbacks[x].id == SASL_CB_USER || callbacks[x].id == SASL_CB_AUTHNAME) { @@ -451,7 +483,7 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st if (callbacks[x].context == NULL) { /* Failed to allocate memory, clean up previously allocated memory */ - for (size_t y= 0; y < x; ++y) + for (ptrdiff_t y= 0; y < x; ++y) { libmemcached_free(clone, clone->sasl.callbacks[y].context); } @@ -468,7 +500,7 @@ memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st if (n == NULL) { /* Failed to allocate memory, clean up previously allocated memory */ - for (size_t y= 0; y < x; ++y) + for (ptrdiff_t y= 0; y < x; ++y) { libmemcached_free(clone, clone->sasl.callbacks[y].context); }