Merge "Modify wmediumd crash when using api server"
diff --git a/Android.bp b/Android.bp
index fa862cf..e165e93 100644
--- a/Android.bp
+++ b/Android.bp
@@ -58,3 +58,19 @@
     stl: "none",
     static_executable: true,
 }
+
+// Host-side client reproducing the wmediumd crash related to ACK handling.
+cc_binary_host {
+    name: "wmediumd_ack_test_client",
+    srcs: [
+        "tests/wmediumd_ack_test_client.c",
+    ],
+    local_include_dirs: [
+        "wmediumd/inc",
+    ],
+    visibility: [
+        "//device/google/cuttlefish/build",
+    ],
+    stl: "none",
+    static_executable: true,
+}
diff --git a/tests/wmediumd_ack_test_client.c b/tests/wmediumd_ack_test_client.c
new file mode 100644
index 0000000..8eb172b
--- /dev/null
+++ b/tests/wmediumd_ack_test_client.c
@@ -0,0 +1,252 @@
+/*
+ * wmediumd_ack_test_client - host-side client that reproduces the wmediumd
+ * crash related to ACK handling: it sends a request to the wmediumd API
+ * server while the server is still waiting for an ACK from this client.
+ */
+#include <errno.h>
+#include <getopt.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include "wmediumd/api.h"
+
+/* On-wire header size: uint32_t type followed by uint32_t data_len. */
+#define WMEDIUMD_HEADER_LEN (sizeof(uint32_t) * 2)
+
+/*
+ * Print usage and terminate the process with exit_code. Usage goes to stdout
+ * for a successful exit (-h) and to stderr otherwise.
+ */
+static void print_help(int exit_code) {
+  FILE *out = exit_code == EXIT_SUCCESS ? stdout : stderr;
+
+  fprintf(out,
+          "wmediumd_ack_test_client - test client for wmediumd crash that is "
+          "related with ack\n\n");
+  fprintf(out, "Usage: wmediumd_ack_test_client -s PATH\n");
+  fprintf(out, "  Options:\n");
+  fprintf(out, "     - h : Print help\n");
+  fprintf(out, "     - s : Path for unix socket of wmediumd api server\n");
+
+  exit(exit_code);
+}
+
+/*
+ * Write exactly len bytes from data to sock, retrying on short writes and
+ * EINTR. Returns len on success and -1 on failure.
+ */
+static ssize_t write_fixed(int sock, const void *data, size_t len) {
+  const char *bytes = data;
+  size_t pos = 0;
+
+  while (pos < len) {
+    ssize_t written = write(sock, bytes + pos, len - pos);
+
+    if (written < 0 && errno == EINTR) {
+      continue;
+    }
+    if (written <= 0) {
+      return -1;
+    }
+    pos += (size_t)written;
+  }
+
+  return (ssize_t)pos;
+}
+
+/*
+ * Read exactly len bytes from sock into data, retrying on short reads and
+ * EINTR. Returns len on success, 0 if the peer closed the connection before
+ * len bytes arrived, and -1 on a read error.
+ */
+static ssize_t read_fixed(int sock, void *data, size_t len) {
+  char *bytes = data;
+  size_t pos = 0;
+
+  while (pos < len) {
+    ssize_t got = read(sock, bytes + pos, len - pos);
+
+    if (got < 0 && errno == EINTR) {
+      continue;
+    }
+    if (got <= 0) {
+      return got;
+    }
+    pos += (size_t)got;
+  }
+
+  return (ssize_t)pos;
+}
+
+/*
+ * Send one API message: a header carrying type and len, then len bytes of
+ * payload from data (data may be NULL when len is 0).
+ * Returns 0 on success and -1 if any part could not be written.
+ */
+static int wmediumd_send_packet(int sock, uint32_t type, const void *data,
+                                uint32_t len) {
+  struct wmediumd_message_header header = {.type = type, .data_len = len};
+
+  if (write_fixed(sock, &header, WMEDIUMD_HEADER_LEN) !=
+      (ssize_t)WMEDIUMD_HEADER_LEN) {
+    return -1;
+  }
+  if (len != 0 && write_fixed(sock, data, len) != (ssize_t)len) {
+    return -1;
+  }
+
+  return 0;
+}
+
+/*
+ * Read one API message from sock and discard its payload. Payloads larger
+ * than the local buffer are drained in chunks rather than overflowing it.
+ * Returns 0 on success and -1 on a read error or if the peer closed the
+ * connection.
+ */
+static int wmediumd_read_packet(int sock) {
+  struct wmediumd_message_header header;
+  char buf[4096];
+
+  if (read_fixed(sock, &header, WMEDIUMD_HEADER_LEN) !=
+      (ssize_t)WMEDIUMD_HEADER_LEN) {
+    return -1;
+  }
+
+  for (uint32_t remain = header.data_len; remain > 0;) {
+    size_t chunk = remain < sizeof(buf) ? remain : sizeof(buf);
+
+    if (read_fixed(sock, buf, chunk) != (ssize_t)chunk) {
+      return -1;
+    }
+    remain -= (uint32_t)chunk;
+  }
+
+  return 0;
+}
+
+/* Print which step of the test failed and return -1. */
+static int report_failure(const char *step) {
+  fprintf(stderr,
+          "error: %s failed (I/O error or connection closed by wmediumd)\n",
+          step);
+  return -1;
+}
+
+/*
+ * Run the message sequence that used to crash wmediumd: after wmediumd
+ * pushes a message and starts waiting for this client's ACK, send it a new
+ * request first and only ACK afterwards.
+ * Returns 0 if every step succeeded and -1 otherwise.
+ */
+static int run_ack_test(int sock) {
+  struct wmediumd_message_control control_message = {
+      .flags = WMEDIUMD_CTL_RX_ALL_FRAMES,
+  };
+
+  /* Register and request all frames; wmediumd acks each request. */
+  if (wmediumd_send_packet(sock, WMEDIUMD_MSG_REGISTER, NULL, 0) < 0 ||
+      wmediumd_read_packet(sock) < 0) {
+    return report_failure("register");
+  }
+  if (wmediumd_send_packet(sock, WMEDIUMD_MSG_SET_CONTROL, &control_message,
+                           sizeof(control_message)) < 0 ||
+      wmediumd_read_packet(sock) < 0) {
+    return report_failure("set control");
+  }
+
+  /*
+   * Wait for a message pushed by wmediumd (presumably a frame, since
+   * WMEDIUMD_CTL_RX_ALL_FRAMES was requested); wmediumd then expects an ACK.
+   */
+  if (wmediumd_read_packet(sock) < 0) {
+    return report_failure("receive from wmediumd");
+  }
+
+  /* Send packet while receiving packet from wmediumd (ACK still pending). */
+  if (wmediumd_send_packet(sock, WMEDIUMD_MSG_SET_CONTROL, &control_message,
+                           sizeof(control_message)) < 0 ||
+      wmediumd_read_packet(sock) < 0) {
+    return report_failure("set control while ACK is pending");
+  }
+
+  if (wmediumd_send_packet(sock, WMEDIUMD_MSG_ACK, NULL, 0) < 0) {
+    return report_failure("ack");
+  }
+
+  return 0;
+}
+
+int main(int argc, char **argv) {
+  const char *wmediumd_api_server_path = NULL;
+  int opt;
+
+  /*
+   * A leading ':' makes getopt return ':' for a missing value (instead of
+   * '?') and suppresses getopt's own diagnostics; both are reported below.
+   */
+  while ((opt = getopt(argc, argv, ":hs:")) != -1) {
+    switch (opt) {
+      case 'h':
+        print_help(EXIT_SUCCESS);
+        break;
+      case 's':
+        if (wmediumd_api_server_path != NULL) {
+          fprintf(stderr,
+                  "error: You must provide just one option for `%c`\n\n", opt);
+          print_help(EXIT_FAILURE);
+        }
+        wmediumd_api_server_path = optarg;
+        break;
+      case ':':
+        fprintf(stderr, "error: Option `%c' needs a value\n\n", optopt);
+        print_help(EXIT_FAILURE);
+        break;
+      default:
+        fprintf(stderr, "error: Unknown option `%c'\n\n", optopt);
+        print_help(EXIT_FAILURE);
+        break;
+    }
+  }
+
+  if (wmediumd_api_server_path == NULL) {
+    fprintf(stderr, "error: must specify wmediumd api server path\n\n");
+    print_help(EXIT_FAILURE);
+  }
+
+  /* Zero the whole address; the copy below includes the terminating NUL. */
+  struct sockaddr_un addr = {.sun_family = AF_UNIX};
+  size_t path_len = strlen(wmediumd_api_server_path);
+
+  if (path_len >= sizeof(addr.sun_path)) {
+    fprintf(stderr, "error: unix socket path is too long(maximum %zu)\n",
+            sizeof(addr.sun_path) - 1);
+    print_help(EXIT_FAILURE);
+  }
+  memcpy(addr.sun_path, wmediumd_api_server_path, path_len + 1);
+
+  int sock = socket(AF_UNIX, SOCK_STREAM, 0);
+
+  if (sock < 0) {
+    fprintf(stderr, "error: cannot create socket: %s\n", strerror(errno));
+    return EXIT_FAILURE;
+  }
+
+  if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
+    fprintf(stderr, "Cannot connect to %s: %s\n", wmediumd_api_server_path,
+            strerror(errno));
+    close(sock);
+    return EXIT_FAILURE;
+  }
+
+  int status = run_ack_test(sock) == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
+
+  close(sock);
+
+  return status;
+}
diff --git a/wmediumd/wmediumd.c b/wmediumd/wmediumd.c
index d888a54..8b5d350 100644
--- a/wmediumd/wmediumd.c
+++ b/wmediumd/wmediumd.c
@@ -1065,14 +1065,6 @@
 	if (len != hdr.data_len)
 		goto disconnect;
 
-	if (client->wait_for_ack) {
-		assert(hdr.type == WMEDIUMD_MSG_ACK);
-		assert(hdr.data_len == 0);
-		client->wait_for_ack = false;
-		/* don't send a response to a response, of course */
-		return;
-	}
-
 	switch (hdr.type) {
 	case WMEDIUMD_MSG_REGISTER:
 		if (!list_empty(&client->list)) {
@@ -1136,7 +1128,11 @@
                 }
 		break;
 	case WMEDIUMD_MSG_ACK:
-		abort();
+		assert(client->wait_for_ack == true);
+		assert(hdr.data_len == 0);
+		client->wait_for_ack = false;
+		/* don't send a response to a response, of course */
+		return;
 	default:
 		response = WMEDIUMD_MSG_INVALID;
 		break;