Check the return code from fgets in memcapable
[m6w6/libmemcached] / clients / memcapable.c
index 08524c495647de4aa300624b40931445d9ab64b8..6ef5bc69285028a3787ae14a98caf98418c54110 100644 (file)
 #include "config.h"
 #include <pthread.h>
 #include <sys/types.h>
-#include <sys/socket.h>
-#include <netdb.h>
-#include <arpa/inet.h>
-#include <netinet/in.h>
-#include <netinet/tcp.h>
 #include <fcntl.h>
 #include <signal.h>
 #include <stdio.h>
 #include <inttypes.h>
 #include <stdbool.h>
 #include <unistd.h>
-#include <poll.h>
 #include <ctype.h>
 
+#include <libmemcached/memcached.h>
 #include <libmemcached/memcached/protocol_binary.h>
 #include <libmemcached/byteorder.h>
+#include "utilities.h"
 
 #ifdef linux
 /* /usr/include/netinet/in.h defines macros from ntohs() to _bswap_nn to
@@ -48,7 +44,7 @@
 /* Should we generate coredumps when we enounter an error (-c) */
 static bool do_core= false;
 /* connection to the server */
-static int sock;
+static memcached_socket_t sock;
 /* Should the output from test failures be verbose or quiet? */
 static bool verbose= false;
 
@@ -112,14 +108,23 @@ static struct addrinfo *lookuphost(const char *hostname, const char *port)
  * Set the socket in nonblocking mode
  * @return -1 if failure, the socket otherwise
  */
-static int set_noblock(void)
+static memcached_socket_t set_noblock(void)
 {
+#ifdef WIN32
+  u_long arg = 1;
+  if (ioctlsocket(sock, FIONBIO, &arg) == SOCKET_ERROR)
+  {
+    perror("Failed to set nonblocking io");
+    closesocket(sock);
+    return INVALID_SOCKET;
+  }
+#else
   int flags= fcntl(sock, F_GETFL, 0);
   if (flags == -1)
   {
     perror("Failed to get socket flags");
-    close(sock);
-    return -1;
+    closesocket(sock);
+    return INVALID_SOCKET;
   }
 
   if ((flags & O_NONBLOCK) != O_NONBLOCK)
@@ -127,11 +132,11 @@ static int set_noblock(void)
     if (fcntl(sock, F_SETFL, flags | O_NONBLOCK) == -1)
     {
       perror("Failed to set socket to nonblocking mode");
-      close(sock);
-      return -1;
+      closesocket(sock);
+      return INVALID_SOCKET;
     }
   }
-
+#endif
   return sock;
 }
 
@@ -141,28 +146,30 @@ static int set_noblock(void)
  * @param port the port number (or service) to connect to
  * @return positive integer if success, -1 otherwise
  */
-static int connect_server(const char *hostname, const char *port)
+static memcached_socket_t connect_server(const char *hostname, const char *port)
 {
   struct addrinfo *ai= lookuphost(hostname, port);
-  sock= -1;
+  sock= INVALID_SOCKET;
   if (ai != NULL)
   {
-    if ((sock=socket(ai->ai_family, ai->ai_socktype,
-                     ai->ai_protocol)) != -1)
+    if ((sock= socket(ai->ai_family, ai->ai_socktype,
+                      ai->ai_protocol)) != INVALID_SOCKET)
     {
-      if (connect(sock, ai->ai_addr, ai->ai_addrlen) == -1)
+      if (connect(sock, ai->ai_addr, ai->ai_addrlen) == SOCKET_ERROR)
       {
         fprintf(stderr, "Failed to connect socket: %s\n",
-                strerror(errno));
-        close(sock);
-        sock= -1;
+                strerror(get_socket_errno()));
+        closesocket(sock);
+        sock= INVALID_SOCKET;
       }
       else
       {
         sock= set_noblock();
       }
-    } else
-      fprintf(stderr, "Failed to create socket: %s\n", strerror(errno));
+    }
+    else
+      fprintf(stderr, "Failed to create socket: %s\n",
+              strerror(get_socket_errno()));
 
     freeaddrinfo(ai);
   }
@@ -170,28 +177,29 @@ static int connect_server(const char *hostname, const char *port)
   return sock;
 }
 
-static ssize_t timeout_io_op(int fd, short direction, void *buf, size_t len)
+static ssize_t timeout_io_op(memcached_socket_t fd, short direction, void *buf, size_t len)
 {
   ssize_t ret;
 
   if (direction == POLLOUT)
-    ret= write(fd, buf, len);
+     ret= send(fd, buf, len, 0);
   else
-    ret= read(fd, buf, len);
+     ret= recv(fd, buf, len, 0);
 
-  if (ret == -1 && errno == EWOULDBLOCK) {
+  if (ret == SOCKET_ERROR && get_socket_errno() == EWOULDBLOCK) {
     struct pollfd fds= {
       .events= direction,
       .fd= fd
     };
+
     int err= poll(&fds, 1, timeout * 1000);
 
     if (err == 1)
     {
       if (direction == POLLOUT)
-        ret= write(fd, buf, len);
+         ret= send(fd, buf, len, 0);
       else
-        ret= read(fd, buf, len);
+         ret= recv(fd, buf, len, 0);
     }
     else if (err == 0)
     {
@@ -245,7 +253,7 @@ static enum test_return retry_write(const void* buf, size_t len)
     size_t num_bytes= len - offset;
     ssize_t nw= timeout_io_op(sock, POLLOUT, (void*)(ptr + offset), num_bytes);
     if (nw == -1)
-      verify(errno == EINTR || errno == EAGAIN);
+      verify(get_socket_errno() == EINTR || get_socket_errno() == EAGAIN);
     else
       offset+= (size_t)nw;
   } while (offset < len);
@@ -295,7 +303,8 @@ static enum test_return retry_read(void *buf, size_t len)
     ssize_t nr= timeout_io_op(sock, POLLIN, ((char*) buf) + offset, len - offset);
     switch (nr) {
     case -1 :
-      verify(errno == EINTR || errno == EAGAIN);
+       fprintf(stderr, "Errno: %d %s\n", get_socket_errno(), strerror(errno));
+      verify(get_socket_errno() == EINTR || get_socket_errno() == EAGAIN);
       break;
     case 0:
       return TEST_FAIL;
@@ -313,7 +322,7 @@ static enum test_return retry_read(void *buf, size_t len)
  */
 static enum test_return recv_packet(response *rsp)
 {
-  execute(retry_read(rsp, sizeof (protocol_binary_response_no_extras)));
+  execute(retry_read(rsp, sizeof(protocol_binary_response_no_extras)));
 
   /* Fix the byte order in the packet header */
   rsp->plain.message.header.response.keylen=
@@ -685,10 +694,12 @@ static enum test_return test_binary_set_impl(const char* key, uint8_t cc)
   cmd.plain.message.header.request.cas=
           htonll(rsp.plain.message.header.response.cas - 1);
   execute(resend_packet(&cmd));
+  execute(send_binary_noop());
   execute(recv_packet(&rsp));
   verify(validate_response_header(&rsp, cc, PROTOCOL_BINARY_RESPONSE_KEY_EEXISTS));
+  execute(receive_binary_noop());
 
-  return test_binary_noop();
+  return TEST_PASS;
 }
 
 static enum test_return test_binary_set(void)
@@ -725,7 +736,9 @@ static enum test_return test_binary_add_impl(const char* key, uint8_t cc)
       else
         expected_result= PROTOCOL_BINARY_RESPONSE_KEY_EEXISTS;
 
+      execute(send_binary_noop());
       execute(recv_packet(&rsp));
+      execute(receive_binary_noop());
       verify(validate_response_header(&rsp, cc, expected_result));
     }
     else
@@ -782,7 +795,9 @@ static enum test_return test_binary_replace_impl(const char* key, uint8_t cc)
       else
         expected_result=PROTOCOL_BINARY_RESPONSE_SUCCESS;
 
+      execute(send_binary_noop());
       execute(recv_packet(&rsp));
+      execute(receive_binary_noop());
       verify(validate_response_header(&rsp, cc, expected_result));
 
       if (ii == 0)
@@ -809,7 +824,9 @@ static enum test_return test_binary_replace_impl(const char* key, uint8_t cc)
   cmd.plain.message.header.request.cas=
           htonll(rsp.plain.message.header.response.cas - 1);
   execute(resend_packet(&cmd));
+  execute(send_binary_noop());
   execute(recv_packet(&rsp));
+  execute(receive_binary_noop());
   verify(validate_response_header(&rsp, cc, PROTOCOL_BINARY_RESPONSE_KEY_EEXISTS));
 
   return TEST_PASS;
@@ -833,8 +850,10 @@ static enum test_return test_binary_delete_impl(const char *key, uint8_t cc)
 
   /* The delete shouldn't work the first time, because the item isn't there */
   execute(send_packet(&cmd));
+  execute(send_binary_noop());
   execute(recv_packet(&rsp));
   verify(validate_response_header(&rsp, cc, PROTOCOL_BINARY_RESPONSE_KEY_ENOENT));
+  execute(receive_binary_noop());
   execute(binary_set_item(key, key));
 
   /* The item should be present now, resend*/
@@ -1861,8 +1880,10 @@ int main(int argc, char **argv)
   const char *hostname= "localhost";
   const char *port= "11211";
   int cmd;
+  bool prompt= false;
+  const char *testname= NULL;
 
-  while ((cmd= getopt(argc, argv, "t:vch:p:?")) != EOF)
+  while ((cmd= getopt(argc, argv, "t:vch:p:PT:?")) != EOF)
   {
     switch (cmd) {
     case 't':
@@ -1881,30 +1902,64 @@ int main(int argc, char **argv)
       break;
     case 'p': port= optarg;
       break;
+    case 'P': prompt= true;
+      break;
+    case 'T': testname= optarg;
+       break;
     default:
-      fprintf(stderr, "Usage: %s [-h hostname] [-p port] [-c] [-v] [-t n]\n"
+      fprintf(stderr, "Usage: %s [-h hostname] [-p port] [-c] [-v] [-t n]"
+              " [-P] [-T testname]'\n"
               "\t-c\tGenerate coredump if a test fails\n"
               "\t-v\tVerbose test output (print out the assertion)\n"
-              "\t-t n\tSet the timeout for io-operations to n seconds\n",
+              "\t-t n\tSet the timeout for io-operations to n seconds\n"
+              "\t-P\tPrompt the user before starting a test.\n"
+              "\t\t\t\"skip\" will skip the test\n"
+              "\t\t\t\"quit\" will terminate memcapable\n"
+              "\t\t\tEverything else will start the test\n"
+              "\t-T n\tJust run the test named n\n",
               argv[0]);
       return 1;
     }
   }
 
+  initialize_sockets();
   sock= connect_server(hostname, port);
-  if (sock == -1)
+  if (sock == INVALID_SOCKET)
   {
     fprintf(stderr, "Failed to connect to <%s:%s>: %s\n",
-            hostname, port, strerror(errno));
+            hostname, port, strerror(get_socket_errno()));
     return 1;
   }
 
   for (int ii= 0; testcases[ii].description != NULL; ++ii)
   {
+    if (testname != NULL && strcmp(testcases[ii].description, testname) != 0)
+       continue;
+
     ++total;
     fprintf(stdout, "%-40s", testcases[ii].description);
     fflush(stdout);
 
+    if (prompt)
+    {
+      fprintf(stdout, "\nPress <return> when you are ready? ");
+      char buffer[80] = {0};
+      if (fgets(buffer, sizeof(buffer), stdin) != NULL) {
+        if (strncmp(buffer, "skip", 4) == 0)
+        {
+          fprintf(stdout, "%-40s%s\n", testcases[ii].description,
+                  status_msg[TEST_SKIP]);
+          fflush(stdout);
+          continue;
+        }
+        if (strncmp(buffer, "quit", 4) == 0)
+          exit(0);
+      }
+
+      fprintf(stdout, "%-40s", testcases[ii].description);
+      fflush(stdout);
+    }
+
     bool reconnect= false;
     enum test_return ret= testcases[ii].function();
     if (ret == TEST_FAIL)
@@ -1920,18 +1975,18 @@ int main(int argc, char **argv)
     fprintf(stderr, "%s\n", status_msg[ret]);
     if (reconnect)
     {
-      (void) close(sock);
-      if ((sock=connect_server(hostname, port)) == -1)
+      closesocket(sock);
+      if ((sock= connect_server(hostname, port)) == INVALID_SOCKET)
       {
         fprintf(stderr, "Failed to connect to <%s:%s>: %s\n",
-                hostname, port, strerror(errno));
+                hostname, port, strerror(get_socket_errno()));
         fprintf(stderr, "%d of %d tests failed\n", failed, total);
         return 1;
       }
     }
   }
 
-  (void) close(sock);
+  closesocket(sock);
   if (failed == 0)
     fprintf(stdout, "All tests passed\n");
   else