Add support for CUDA9 half semantics
diff --git a/Makefile b/Makefile
index 8f34fcb..c37b7f7 100644
--- a/Makefile
+++ b/Makefile
@@ -54,7 +54,7 @@
 
 NCCL_MAJOR   := 1
 NCCL_MINOR   := 3
-NCCL_PATCH   := 4
+NCCL_PATCH   := 5
 CXXFLAGS  += -DNCCL_MAJOR=$(NCCL_MAJOR) -DNCCL_MINOR=$(NCCL_MINOR) -DNCCL_PATCH=$(NCCL_PATCH)
 
 CUDA_VERSION ?= $(shell ls $(CUDA_LIB)/libcudart.so.* | head -1 | rev | cut -d "." -f -2 | rev)
diff --git a/src/common_kernel.h b/src/common_kernel.h
index cc71f8a..b96519f 100644
--- a/src/common_kernel.h
+++ b/src/common_kernel.h
@@ -35,25 +35,33 @@
   return *ptr;
 }
 
-#ifdef CUDA_HAS_HALF
-template<> inline __device__
-half vFetch<half>(const volatile half* ptr) {
-  half r;
-  r.x = ptr->x;
-  return r;
-}
-#endif
-
 template<typename T> inline __device__
 void vStore(volatile T* ptr, const T val) {
   *ptr = val;
 }
 
 #ifdef CUDA_HAS_HALF
+#if CUDART_VERSION < 9000
+template<> inline __device__
+half vFetch<half>(const volatile half* ptr) {
+  half r;
+  r.x = ptr->x;
+  return r;
+}
 template<> inline __device__
 void vStore<half>(volatile half* ptr, const half val) {
   ptr->x = val.x;
 }
+#else
+template<> inline __device__
+half vFetch<half>(const volatile half* ptr) {
+  return *((half*)ptr);
+}
+template<> inline __device__
+void vStore<half>(volatile half* ptr, const half val) {
+  *((half*)ptr) = val;
+}
+#endif
 #endif
 
 __device__ unsigned int spinct;
@@ -125,24 +133,22 @@
 #ifdef CUDA_HAS_HALF
 template<class FUNC>
 struct MULTI<FUNC, half> {
-  static_assert(sizeof(PackType) == 2 * sizeof(float),
-      "PackType must be twice the size of float.");
-  union converter {
-    PackType storage;
-    struct {
-      half2 a, b;
-    };
+  static_assert(sizeof(PackType) == 4 * sizeof(half),
+      "PackType must be four times the size of half.");
+
+  struct PackHalf2 {
+    half2 a, b;
   };
 
   __device__ PackType operator()(const PackType x, const PackType y) const {
-    converter cx, cy, cr;
-    cx.storage = x;
-    cy.storage = y;
+    struct PackHalf2 cx, cy, cr;
+    cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
+    cy = *(reinterpret_cast<const struct PackHalf2*>(&y));
 
     cr.a = FUNC()(cx.a, cy.a);
     cr.b = FUNC()(cx.b, cy.b);
 
-    return cr.storage;
+    return *(reinterpret_cast<PackType*>(&cr));
   }
 };
 #endif
diff --git a/src/copy_kernel.h b/src/copy_kernel.h
index 8464699..0f69748 100644
--- a/src/copy_kernel.h
+++ b/src/copy_kernel.h
@@ -24,9 +24,7 @@
     return x;
   }
   __device__ half operator()(const half x, const half y) const {
-    half r;
-    r.x = x.x;
-    return r;
+    return x;
   }
 };
 #endif