Merge pull request #408 from dtolnay/ctype

Pass improper ctype references by void pointer
diff --git a/macro/src/expand.rs b/macro/src/expand.rs
index c86dc91..b5d3067 100644
--- a/macro/src/expand.rs
+++ b/macro/src/expand.rs
@@ -208,7 +208,7 @@
     });
     let args = efn.args.iter().map(|arg| {
         let ident = &arg.ident;
-        let ty = expand_extern_type(&arg.ty, types);
+        let ty = expand_extern_type(&arg.ty, types, true); // true: lower improper-ctype refs to c_void pointers
         if arg.ty == RustString {
             quote!(#ident: *const #ty)
         } else if let Type::RustVec(_) = arg.ty {
@@ -225,11 +225,11 @@
     let ret = if efn.throws {
         quote!(-> ::cxx::private::Result)
     } else {
-        expand_extern_return_type(&efn.ret, types)
+        expand_extern_return_type(&efn.ret, types, true)
     };
     let mut outparam = None;
     if indirect_return(efn, types) {
-        let ret = expand_extern_type(efn.ret.as_ref().unwrap(), types);
+        let ret = expand_extern_type(efn.ret.as_ref().unwrap(), types, true);
         outparam = Some(quote!(__return: *mut #ret));
     }
     let link_name = mangle::extern_fn(efn, types);
@@ -287,6 +287,10 @@
                     None => quote!(::cxx::private::RustVec::from_ref(#var)),
                     Some(_) => quote!(::cxx::private::RustVec::from_mut(#var)),
                 },
+                inner if types.is_considered_improper_ctype(inner) => match ty.mutability { // pairs with the c_void param type from expand_extern_type(.., true)
+                    None => quote!(#var as *const #inner as *const ::std::ffi::c_void),
+                    Some(_) => quote!(#var as *mut #inner as *mut ::std::ffi::c_void),
+                },
                 _ => quote!(#var),
             },
             Type::Str(_) => quote!(::cxx::private::RustStr::from(#var)),
@@ -323,7 +327,7 @@
         .collect::<TokenStream>();
     let local_name = format_ident!("__{}", efn.name.rust);
     let call = if indirect_return {
-        let ret = expand_extern_type(efn.ret.as_ref().unwrap(), types);
+        let ret = expand_extern_type(efn.ret.as_ref().unwrap(), types, true);
         setup.extend(quote! {
             let mut __return = ::std::mem::MaybeUninit::<#ret>::uninit();
         });
@@ -376,6 +380,10 @@
                         None => quote!(#call.as_vec()),
                         Some(_) => quote!(#call.as_mut_vec()),
                     },
+                    inner if types.is_considered_improper_ctype(inner) => { // returned as a c_void pointer; cast back to &T
+                        let mutability = ty.mutability;
+                        quote!(&#mutability *#call.cast())
+                    }
                     _ => call,
                 },
                 Type::Str(_) => quote!(#call.as_str()),
@@ -508,7 +516,7 @@
     });
     let args = sig.args.iter().map(|arg| {
         let ident = &arg.ident;
-        let ty = expand_extern_type(&arg.ty, types);
+        let ty = expand_extern_type(&arg.ty, types, false); // false: keep reference types as written
         if types.needs_indirect_abi(&arg.ty) {
             quote!(#ident: *mut #ty)
         } else {
@@ -609,7 +617,7 @@
     let mut outparam = None;
     let indirect_return = indirect_return(sig, types);
     if indirect_return {
-        let ret = expand_extern_type(sig.ret.as_ref().unwrap(), types);
+        let ret = expand_extern_type(sig.ret.as_ref().unwrap(), types, false);
         outparam = Some(quote!(__return: *mut #ret,));
     }
     if sig.throws {
@@ -627,7 +635,7 @@
     let ret = if sig.throws {
         quote!(-> ::cxx::private::Result)
     } else {
-        expand_extern_return_type(&sig.ret, types)
+        expand_extern_return_type(&sig.ret, types, false)
     };
 
     let pointer = match invoke {
@@ -939,15 +947,15 @@
         .map_or(false, |ret| sig.throws || types.needs_indirect_abi(ret))
 }
 
-fn expand_extern_type(ty: &Type, types: &Types) -> TokenStream {
+fn expand_extern_type(ty: &Type, types: &Types, proper: bool) -> TokenStream { // proper: emit improper-ctype refs as c_void pointers
     match ty {
         Type::Ident(ident) if ident.rust == RustString => quote!(::cxx::private::RustString),
         Type::RustBox(ty) | Type::UniquePtr(ty) => {
-            let inner = expand_extern_type(&ty.inner, types);
+            let inner = expand_extern_type(&ty.inner, types, proper);
             quote!(*mut #inner)
         }
         Type::RustVec(ty) => {
-            let elem = expand_extern_type(&ty.inner, types);
+            let elem = expand_extern_type(&ty.inner, types, proper);
             quote!(::cxx::private::RustVec<#elem>)
         }
         Type::Ref(ty) => {
@@ -957,9 +965,13 @@
                     quote!(&#mutability ::cxx::private::RustString)
                 }
                 Type::RustVec(ty) => {
-                    let inner = expand_extern_type(&ty.inner, types);
+                    let inner = expand_extern_type(&ty.inner, types, proper);
                     quote!(&#mutability ::cxx::private::RustVec<#inner>)
                 }
+                inner if proper && types.is_considered_improper_ctype(inner) => match mutability {
+                    None => quote!(*const ::std::ffi::c_void),
+                    Some(_) => quote!(*#mutability ::std::ffi::c_void),
+                },
                 _ => quote!(#ty),
             }
         }
@@ -969,11 +981,11 @@
     }
 }
 
-fn expand_extern_return_type(ret: &Option<Type>, types: &Types) -> TokenStream {
+fn expand_extern_return_type(ret: &Option<Type>, types: &Types, proper: bool) -> TokenStream {
     let ret = match ret {
         Some(ret) if !types.needs_indirect_abi(ret) => ret,
         _ => return TokenStream::new(),
     };
-    let ty = expand_extern_type(ret, types);
+    let ty = expand_extern_type(ret, types, proper);
     quote!(-> #ty)
 }
diff --git a/syntax/improper.rs b/syntax/improper.rs
new file mode 100644
index 0000000..6fd3162
--- /dev/null
+++ b/syntax/improper.rs
@@ -0,0 +1,36 @@
+use self::ImproperCtype::*;
+use crate::syntax::atom::Atom::{self, *};
+use crate::syntax::{Type, Types};
+use proc_macro2::Ident;
+
+pub enum ImproperCtype<'a> { // classification of a Type w.r.t. rustc's improper_ctypes lint
+    Definite(bool), // answer known now: true means improper
+    Depends(&'a Ident), // improper iff the named shared struct is improper
+}
+
+impl<'a> Types<'a> {
+    // Classify ty: Definite(true) = improper, Definite(false) = proper, Depends = as the named struct.
+    pub fn determine_improper_ctype(&self, ty: &Type) -> ImproperCtype<'a> {
+        match ty {
+            Type::Ident(ident) => {
+                let ident = &ident.rust;
+                if let Some(atom) = Atom::from(ident) {
+                    Definite(atom == RustString) // among builtin atoms only String is improper
+                } else if let Some(strct) = self.structs.get(ident) {
+                    Depends(&strct.name.rust) // struct fields are resolved by iterating to a fixed point
+                } else {
+                    Definite(self.rust.contains(ident)) // presumably the extern Rust types; TODO confirm
+                }
+            }
+            Type::RustBox(_)
+            | Type::RustVec(_)
+            | Type::Str(_)
+            | Type::Fn(_)
+            | Type::Void(_)
+            | Type::Slice(_)
+            | Type::SliceRefU8(_) => Definite(true),
+            Type::UniquePtr(_) | Type::CxxVector(_) => Definite(false),
+            Type::Ref(ty) => self.determine_improper_ctype(&ty.inner), // a reference is classified by its referent
+        }
+    }
+}
diff --git a/syntax/mod.rs b/syntax/mod.rs
index ac6350a..b742e37 100644
--- a/syntax/mod.rs
+++ b/syntax/mod.rs
@@ -10,6 +10,7 @@
 pub mod file;
 pub mod ident;
 mod impls;
+mod improper;
 pub mod mangle;
 mod names;
 pub mod namespace;
diff --git a/syntax/types.rs b/syntax/types.rs
index 67b041b..5d6d582 100644
--- a/syntax/types.rs
+++ b/syntax/types.rs
@@ -1,4 +1,5 @@
 use crate::syntax::atom::Atom::{self, *};
+use crate::syntax::improper::ImproperCtype;
 use crate::syntax::report::Errors;
 use crate::syntax::set::OrderedSet as Set;
 use crate::syntax::{
@@ -19,6 +20,7 @@
     pub required_trivial: Map<&'a Ident, TrivialReason<'a>>,
     pub explicit_impls: Set<&'a Impl>,
     pub resolutions: Map<&'a Ident, &'a Pair>,
+    pub struct_improper_ctypes: UnorderedSet<&'a Ident>, // structs with an improper field, transitively
 }
 
 impl<'a> Types<'a> {
@@ -32,6 +34,7 @@
         let mut untrusted = Map::new();
         let mut explicit_impls = Set::new();
         let mut resolutions = Map::new();
+        let struct_improper_ctypes = UnorderedSet::new(); // filled by the fixed-point pass below
 
         fn visit<'a>(all: &mut Set<&'a Type>, ty: &'a Type) {
             all.insert(ty);
@@ -190,7 +193,7 @@
             }
         }
 
-        Types {
+        let mut types = Types {
             all,
             structs,
             enums,
@@ -201,7 +204,34 @@
             required_trivial,
             explicit_impls,
             resolutions,
+            struct_improper_ctypes,
+        };
+
+        let mut unresolved_structs: Vec<&Ident> = types.structs.keys().copied().collect();
+        let mut new_information = true; // repeat passes until one marks no new struct improper
+        while new_information {
+            new_information = false;
+            unresolved_structs.retain(|ident| {
+                let mut retain = false; // set when some field depends on another struct
+                for var in &types.structs[ident].fields {
+                    if match types.determine_improper_ctype(&var.ty) {
+                        ImproperCtype::Depends(inner) => {
+                            retain = true;
+                            types.struct_improper_ctypes.contains(inner)
+                        }
+                        ImproperCtype::Definite(improper) => improper,
+                    } {
+                        types.struct_improper_ctypes.insert(ident);
+                        new_information = true;
+                        return false; // resolved as improper; drop from unresolved_structs
+                    }
+                }
+                // No improper field yet: keep for another pass only if a field depends on a struct.
+                retain
+            });
         }
+
+        types
     }
 
     pub fn needs_indirect_abi(&self, ty: &Type) -> bool {
@@ -227,6 +257,19 @@
         false
     }
 
+    // Types that trigger rustc's default #[warn(improper_ctypes)] lint, even if
+    // they may be otherwise unproblematic to mention in an extern signature.
+    // For example in a signature like `extern "C" fn(*const String)`, rustc
+    // refuses to believe that C could know how to supply us with a pointer to a
+    // Rust String, even though C could easily have obtained that pointer
+    // legitimately from a Rust call.
+    pub fn is_considered_improper_ctype(&self, ty: &Type) -> bool {
+        match self.determine_improper_ctype(ty) {
+            ImproperCtype::Definite(improper) => improper,
+            ImproperCtype::Depends(ident) => self.struct_improper_ctypes.contains(ident),
+        }
+    }
+
     pub fn resolve(&self, ident: &ResolvableName) -> &Pair {
         self.resolutions
             .get(&ident.rust)
diff --git a/tests/ffi/lib.rs b/tests/ffi/lib.rs
index f40ec6c..af8a9f5 100644
--- a/tests/ffi/lib.rs
+++ b/tests/ffi/lib.rs
@@ -69,6 +69,10 @@
         z: usize,
     }
 
+    struct SharedString { // the String field makes this struct an improper ctype
+        msg: String,
+    }
+
     enum Enum {
         AVal,
         BVal = 2020,
@@ -169,6 +173,7 @@
         fn c_take_ref_rust_vec_string(v: &Vec<String>);
         fn c_take_ref_rust_vec_index(v: &Vec<u8>);
         fn c_take_ref_rust_vec_copy(v: &Vec<u8>);
+        fn c_take_ref_shared_string(s: &SharedString) -> &SharedString; // improper-ctype ref as arg and return
         fn c_take_callback(callback: fn(String) -> usize);
         fn c_take_enum(e: Enum);
         fn c_take_ns_enum(e: AEnum);
diff --git a/tests/ffi/tests.cc b/tests/ffi/tests.cc
index a03ec2b..747788f 100644
--- a/tests/ffi/tests.cc
+++ b/tests/ffi/tests.cc
@@ -385,6 +385,13 @@
   }
 }
 
+const SharedString &c_take_ref_shared_string(const SharedString &s) {
+  if (std::string(s.msg) == "2020") {
+    cxx_test_suite_set_correct();
+  }
+  return s; // hand the same reference back to exercise the return path
+}
+
 void c_take_callback(rust::Fn<size_t(rust::String)> callback) {
   callback("2020");
 }
diff --git a/tests/ffi/tests.h b/tests/ffi/tests.h
index e551048..7fc5633 100644
--- a/tests/ffi/tests.h
+++ b/tests/ffi/tests.h
@@ -36,6 +36,7 @@
 
 struct R;
 struct Shared;
+struct SharedString;
 enum class Enum : uint16_t;
 
 class C {
@@ -134,6 +135,7 @@
 void c_take_ref_rust_vec_string(const rust::Vec<rust::String> &v);
 void c_take_ref_rust_vec_index(const rust::Vec<uint8_t> &v);
 void c_take_ref_rust_vec_copy(const rust::Vec<uint8_t> &v);
+const SharedString &c_take_ref_shared_string(const SharedString &s);
 void c_take_callback(rust::Fn<size_t(rust::String)> callback);
 void c_take_enum(Enum e);
 void c_take_ns_enum(::A::AEnum e);
diff --git a/tests/test.rs b/tests/test.rs
index a90b8fc..2c6f062 100644
--- a/tests/test.rs
+++ b/tests/test.rs
@@ -137,6 +137,9 @@
     check!(ffi::c_take_ref_rust_vec(&test_vec));
     check!(ffi::c_take_ref_rust_vec_index(&test_vec));
     check!(ffi::c_take_ref_rust_vec_copy(&test_vec));
+    check!(ffi::c_take_ref_shared_string(&ffi::SharedString {
+        msg: "2020".to_owned()
+    }));
     let ns_shared_test_vec = vec![ffi::AShared { z: 1010 }, ffi::AShared { z: 1011 }];
     check!(ffi::c_take_rust_vec_ns_shared(ns_shared_test_vec));
     let nested_ns_shared_test_vec = vec![ffi::ABShared { z: 1010 }, ffi::ABShared { z: 1011 }];