+/* vim:expandtab:shiftwidth=2:tabstop=2:smarttab:
+ *
+ * Libmemcached library
+ *
+ * Copyright (C) 2011 Data Differential, http://datadifferential.com/
+ * Copyright (C) 2006-2009 Brian Aker All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *
+ * * The names of its contributors may not be used to endorse or
+ * promote products derived from this software without specific prior
+ * written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+#include <libmemcached/common.h>
+#include <cassert>
+
+#if defined(LIBMEMCACHED_WITH_SASL_SUPPORT) && LIBMEMCACHED_WITH_SASL_SUPPORT
+
+#include <sasl/sasl.h>
+
+void memcached_set_sasl_callbacks(memcached_st *ptr,
+ const sasl_callback_t *callbacks)
+{
+ ptr->sasl.callbacks= const_cast<sasl_callback_t *>(callbacks);
+ ptr->sasl.is_allocated= false;
+}
+
+sasl_callback_t *memcached_get_sasl_callbacks(memcached_st *ptr)
+{
+ return ptr->sasl.callbacks;
+}
+
+/**
+ * Resolve the names for both ends of a connection
+ * @param fd socket to check
+ * @param laddr local address (out)
+ * @param raddr remote address (out)
+ * @return true on success false otherwise (errno contains more info)
+ */
+static memcached_return_t resolve_names(memcached_server_st& server, char *laddr, size_t laddr_length, char *raddr, size_t raddr_length)
+{
+ char host[NI_MAXHOST];
+ char port[NI_MAXSERV];
+ struct sockaddr_storage saddr;
+ socklen_t salen= sizeof(saddr);
+
+ if (getsockname(server.fd, (struct sockaddr *)&saddr, &salen) < 0)
+ {
+ return memcached_set_errno(server, MEMCACHED_ERRNO, MEMCACHED_AT);
+ }
+
+ if (getnameinfo((struct sockaddr *)&saddr, salen, host, sizeof(host), port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV) < 0)
+ {
+ return MEMCACHED_HOST_LOOKUP_FAILURE;
+ }
+
+ (void)snprintf(laddr, laddr_length, "%s;%s", host, port);
+ salen= sizeof(saddr);
+
+ if (getpeername(server.fd, (struct sockaddr *)&saddr, &salen) < 0)
+ {
+ return memcached_set_errno(server, MEMCACHED_ERRNO, MEMCACHED_AT);
+ }
+
+ if (getnameinfo((struct sockaddr *)&saddr, salen, host, sizeof(host),
+ port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV) < 0)
+ {
+ return memcached_set_error(server, MEMCACHED_HOST_LOOKUP_FAILURE, MEMCACHED_AT);
+ }
+
+ (void)snprintf(raddr, raddr_length, "%s;%s", host, port);
+
+ return MEMCACHED_SUCCESS;
+}
+
+memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *server)
+{
+ if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0)
+ {
+ return MEMCACHED_NOT_SUPPORTED;
+ }
+
+ if (server == NULL)
+ {
+ return MEMCACHED_INVALID_ARGUMENTS;
+ }
+
+ /* SANITY CHECK: SASL can only be used with the binary protocol */
+ if (server->root->flags.binary_protocol == false)
+ {
+ return MEMCACHED_PROTOCOL_ERROR;
+ }
+
+ /* Try to get the supported mech from the server. Servers without SASL
+ * support will return UNKNOWN COMMAND, so we can just treat that
+ * as authenticated
+ */
+ protocol_binary_request_no_extras request= { };
+ request.message.header.request.magic= PROTOCOL_BINARY_REQ;
+ request.message.header.request.opcode= PROTOCOL_BINARY_CMD_SASL_LIST_MECHS;
+
+ if (memcached_io_write(server, request.bytes,
+ sizeof(request.bytes), 1) != sizeof(request.bytes))
+ {
+ return MEMCACHED_WRITE_FAILURE;
+ }
+
+ memcached_server_response_increment(server);
+
+ char mech[MEMCACHED_MAX_BUFFER];
+ memcached_return_t rc= memcached_response(server, mech, sizeof(mech), NULL);
+ if (memcached_failed(rc))
+ {
+ if (rc == MEMCACHED_PROTOCOL_ERROR)
+ {
+ /* If the server doesn't support SASL it will return PROTOCOL_ERROR.
+ * This error may also be returned for other errors, but let's assume
+ * that the server don't support SASL and treat it as success and
+ * let the client fail with the next operation if the error was
+ * caused by another problem....
+ */
+ rc= MEMCACHED_SUCCESS;
+ }
+
+ return rc;
+ }
+
+ /* set ip addresses */
+ char laddr[NI_MAXHOST + NI_MAXSERV];
+ char raddr[NI_MAXHOST + NI_MAXSERV];
+
+ if (memcached_failed(rc= resolve_names(*server, laddr, sizeof(laddr), raddr, sizeof(raddr))))
+ {
+ return rc;
+ }
+
+ int ret;
+ if ((ret= sasl_client_init(NULL)) != SASL_OK)
+ {
+ const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL);
+ return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT,
+ memcached_string_make_from_cstr(sasl_error_msg));
+ }
+
+ sasl_conn_t *conn;
+ if ((ret= sasl_client_new("memcached", server->hostname, laddr, raddr, server->root->sasl.callbacks, 0, &conn) ) != SASL_OK)
+ {
+ const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL);
+ return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT,
+ memcached_string_make_from_cstr(sasl_error_msg));
+ }
+
+ const char *data;
+ const char *chosenmech;
+ unsigned int len;
+ ret= sasl_client_start(conn, mech, NULL, &data, &len, &chosenmech);
+ if (ret != SASL_OK and ret != SASL_CONTINUE)
+ {
+ const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL);
+ return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT,
+ memcached_string_make_from_cstr(sasl_error_msg));
+ }
+ uint16_t keylen= (uint16_t)strlen(chosenmech);
+ request.message.header.request.opcode= PROTOCOL_BINARY_CMD_SASL_AUTH;
+ request.message.header.request.keylen= htons(keylen);
+ request.message.header.request.bodylen= htonl(len + keylen);
+
+ do {
+ /* send the packet */
+
+ struct libmemcached_io_vector_st vector[]=
+ {
+ { sizeof(request.bytes), request.bytes },
+ { keylen, chosenmech },
+ { len, data }
+ };
+
+ if (memcached_io_writev(server, vector, 3, true) == -1)
+ {
+ rc= MEMCACHED_WRITE_FAILURE;
+ break;
+ }
+ memcached_server_response_increment(server);
+
+ /* read the response */
+ rc= memcached_response(server, NULL, 0, NULL);
+ if (rc != MEMCACHED_AUTH_CONTINUE)
+ {
+ break;
+ }
+
+ ret= sasl_client_step(conn, memcached_result_value(&server->root->result),
+ (unsigned int)memcached_result_length(&server->root->result),
+ NULL, &data, &len);
+
+ if (ret != SASL_OK && ret != SASL_CONTINUE)
+ {
+ rc= MEMCACHED_AUTH_PROBLEM;
+ break;
+ }
+
+ request.message.header.request.opcode= PROTOCOL_BINARY_CMD_SASL_STEP;
+ request.message.header.request.bodylen= htonl(len + keylen);
+ } while (true);
+
+ /* Release resources */
+ sasl_dispose(&conn);
+
+ return memcached_set_error(*server, rc, MEMCACHED_AT);
+}
+
+static int get_username(void *context, int id, const char **result, unsigned int *len)
+{
+ if (!context || !result || (id != SASL_CB_USER && id != SASL_CB_AUTHNAME))
+ {
+ return SASL_BADPARAM;
+ }
+
+ *result= (char *)context;
+ if (len)
+ {
+ *len= (unsigned int)strlen(*result);
+ }
+
+ return SASL_OK;
+}
+
+static int get_password(sasl_conn_t *conn, void *context, int id,
+ sasl_secret_t **psecret)
+{
+ if (!conn || ! psecret || id != SASL_CB_PASS)
+ {
+ return SASL_BADPARAM;
+ }
+
+ *psecret= (sasl_secret_t *)context;
+
+ return SASL_OK;
+}
+
+memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr,
+ const char *username,
+ const char *password)
+{
+ if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0)
+ {
+ return MEMCACHED_NOT_SUPPORTED;
+ }
+
+ if (ptr == NULL or username == NULL or password == NULL)
+ {
+ return MEMCACHED_INVALID_ARGUMENTS;
+ }
+
+ memcached_destroy_sasl_auth_data(ptr);
+
+ sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(ptr, 4, sizeof(sasl_callback_t));
+ size_t password_length= strlen(password);
+ size_t username_length= strlen(username);
+ char *name= (char *)libmemcached_malloc(ptr, username_length +1);
+ sasl_secret_t *secret= (sasl_secret_t*)libmemcached_malloc(ptr, password_length +1 + sizeof(sasl_secret_t));
+
+ if (callbacks == NULL or name == NULL or secret == NULL)
+ {
+ libmemcached_free(ptr, callbacks);
+ libmemcached_free(ptr, name);
+ libmemcached_free(ptr, secret);
+ return memcached_set_error(*ptr, MEMCACHED_MEMORY_ALLOCATION_FAILURE, MEMCACHED_AT);
+ }
+
+ secret->len= password_length;
+ memcpy(secret->data, password, password_length);
+ secret->data[password_length]= 0;
+
+ callbacks[0].id= SASL_CB_USER;
+ callbacks[0].proc= (int (*)())get_username;
+ callbacks[0].context= strncpy(name, username, username_length +1);
+ callbacks[1].id= SASL_CB_AUTHNAME;
+ callbacks[1].proc= (int (*)())get_username;
+ callbacks[1].context= name;
+ callbacks[2].id= SASL_CB_PASS;
+ callbacks[2].proc= (int (*)())get_password;
+ callbacks[2].context= secret;
+ callbacks[3].id= SASL_CB_LIST_END;
+
+ ptr->sasl.callbacks= callbacks;
+ ptr->sasl.is_allocated= true;
+
+ return MEMCACHED_SUCCESS;
+}
+
+memcached_return_t memcached_destroy_sasl_auth_data(memcached_st *ptr)
+{
+ if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0)
+ {
+ return MEMCACHED_NOT_SUPPORTED;
+ }
+
+ if (ptr == NULL)
+ {
+ return MEMCACHED_INVALID_ARGUMENTS;
+ }
+
+ if (ptr->sasl.callbacks == NULL)
+ {
+ return MEMCACHED_SUCCESS;
+ }
+
+ if (ptr->sasl.is_allocated)
+ {
+ libmemcached_free(ptr, ptr->sasl.callbacks[0].context);
+ libmemcached_free(ptr, ptr->sasl.callbacks[2].context);
+ libmemcached_free(ptr, (void*)ptr->sasl.callbacks);
+ ptr->sasl.is_allocated= false;
+ }
+
+ ptr->sasl.callbacks= NULL;
+
+ return MEMCACHED_SUCCESS;
+}
+
+memcached_return_t memcached_clone_sasl(memcached_st *clone, const memcached_st *source)
+{
+ if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0)
+ {
+ return MEMCACHED_NOT_SUPPORTED;
+ }
+
+ if (clone == NULL or source == NULL)
+ {
+ return MEMCACHED_INVALID_ARGUMENTS;
+ }
+
+ if (source->sasl.callbacks == NULL)
+ {
+ return MEMCACHED_SUCCESS;
+ }
+
+ /* Hopefully we are using our own callback mechanisms.. */
+ if (source->sasl.callbacks[0].id == SASL_CB_USER &&
+ source->sasl.callbacks[0].proc == (int (*)())get_username &&
+ source->sasl.callbacks[1].id == SASL_CB_AUTHNAME &&
+ source->sasl.callbacks[1].proc == (int (*)())get_username &&
+ source->sasl.callbacks[2].id == SASL_CB_PASS &&
+ source->sasl.callbacks[2].proc == (int (*)())get_password &&
+ source->sasl.callbacks[3].id == SASL_CB_LIST_END)
+ {
+ sasl_secret_t *secret= (sasl_secret_t *)source->sasl.callbacks[2].context;
+ return memcached_set_sasl_auth_data(clone,
+ (const char*)source->sasl.callbacks[0].context,
+ (const char*)secret->data);
+ }
+
+ /*
+ * But we're not. It may work if we know what the user tries to pass
+ * into the list, but if we don't know the ID we don't know how to handle
+ * the context...
+ */
+ size_t total= 0;
+
+ while (source->sasl.callbacks[total].id != SASL_CB_LIST_END)
+ {
+ switch (source->sasl.callbacks[total].id)
+ {
+ case SASL_CB_USER:
+ case SASL_CB_AUTHNAME:
+ case SASL_CB_PASS:
+ break;
+ default:
+ /* I don't know how to deal with this... */
+ return MEMCACHED_NOT_SUPPORTED;
+ }
+
+ ++total;
+ }
+
+ sasl_callback_t *callbacks= (sasl_callback_t*)libmemcached_calloc(clone, total +1, sizeof(sasl_callback_t));
+ if (callbacks == NULL)
+ {
+ return MEMCACHED_MEMORY_ALLOCATION_FAILURE;
+ }
+ memcpy(callbacks, source->sasl.callbacks, (total + 1) * sizeof(sasl_callback_t));
+
+ /* Now update the context... */
+ for (size_t x= 0; x < total; ++x)
+ {
+ if (callbacks[x].id == SASL_CB_USER || callbacks[x].id == SASL_CB_AUTHNAME)
+ {
+ callbacks[x].context= (sasl_callback_t*)libmemcached_malloc(clone, strlen((const char*)source->sasl.callbacks[x].context));
+
+ if (callbacks[x].context == NULL)
+ {
+ /* Failed to allocate memory, clean up previously allocated memory */
+ for (size_t y= 0; y < x; ++y)
+ {
+ libmemcached_free(clone, clone->sasl.callbacks[y].context);
+ }
+
+ libmemcached_free(clone, callbacks);
+ return MEMCACHED_MEMORY_ALLOCATION_FAILURE;
+ }
+ strncpy((char*)callbacks[x].context, (const char*)source->sasl.callbacks[x].context, sizeof(callbacks[x].context));
+ }
+ else
+ {
+ sasl_secret_t *src= (sasl_secret_t *)source->sasl.callbacks[x].context;
+ sasl_secret_t *n= (sasl_secret_t*)libmemcached_malloc(clone, src->len + 1 + sizeof(*n));
+ if (n == NULL)
+ {
+ /* Failed to allocate memory, clean up previously allocated memory */
+ for (size_t y= 0; y < x; ++y)
+ {
+ libmemcached_free(clone, clone->sasl.callbacks[y].context);
+ }
+
+ libmemcached_free(clone, callbacks);
+ return MEMCACHED_MEMORY_ALLOCATION_FAILURE;
+ }
+ memcpy(n, src, src->len + 1 + sizeof(*n));
+ callbacks[x].context= n;
+ }
+ }
+
+ clone->sasl.callbacks= callbacks;
+ clone->sasl.is_allocated= true;
+
+ return MEMCACHED_SUCCESS;
+}
+
+#else
+
+void memcached_set_sasl_callbacks(memcached_st *, const sasl_callback_t *)
+{
+}
+
+sasl_callback_t *memcached_get_sasl_callbacks(memcached_st *)
+{
+ return NULL;
+}
+
+memcached_return_t memcached_set_sasl_auth_data(memcached_st *, const char *, const char *)
+{
+ return MEMCACHED_NOT_SUPPORTED;
+}
+
+memcached_return_t memcached_clone_sasl(memcached_st *, const memcached_st *)
+{
+ return MEMCACHED_NOT_SUPPORTED;
+}
+
+#endif