2008-10-24 Mark Probst <mark.probst@gmail.com>
[mono.git] / mono / metadata / socket-io.c
index 63a28ae16e07710a778710462339dba42615a3d3..225aa585e6b15df6a1062f67ab86c87aee5a3468 100644 (file)
@@ -14,7 +14,9 @@
 #include <glib.h>
 #include <string.h>
 #include <stdlib.h>
+#ifdef HAVE_UNISTD_H
 #include <unistd.h>
+#endif
 #include <errno.h>
 
 #include <mono/metadata/object.h>
@@ -31,7 +33,9 @@
 #include <mono/metadata/threadpool-internals.h>
 #include <mono/metadata/domain-internals.h>
 
-#include <sys/time.h> 
+#ifdef HAVE_SYS_TIME_H
+#include <sys/time.h>
+#endif
 #ifdef HAVE_SYS_IOCTL_H
 #include <sys/ioctl.h>
 #endif
@@ -284,11 +288,18 @@ static gint32 convert_socketflags (gint32 sflags)
                flags |= MSG_PEEK;
        if (sflags & SocketFlags_DontRoute)
                flags |= MSG_DONTROUTE;
+#if 0
+       /* Ignore Partial - see bug 349688.  Don't return -1, because
+        * according to the comment in that bug ms runtime doesn't for
+        * UDP sockets (this means we will silently ignore it for TCP
+        * too)
+        */
        if (sflags & SocketFlags_Partial)
 #ifdef MSG_MORE
                flags |= MSG_MORE;
 #else
                return -1;      
+#endif
 #endif
        if (sflags & SocketFlags_MaxIOVectorLength)
                /* FIXME: Don't know what to do for MaxIOVectorLength query */
@@ -297,6 +308,12 @@ static gint32 convert_socketflags (gint32 sflags)
        return (flags ? flags : -1);
 }
 
+/*
+ * Returns:
+ *    0 on success (mapped mono_level and mono_name to system_level and system_name
+ *   -1 on error
+ *   -2 on non-fatal error (ie, must ignore)
+ */
 static gint32 convert_sockopt_level_and_name(MonoSocketOptionLevel mono_level,
                                             MonoSocketOptionName mono_name,
                                             int *system_level,
@@ -440,7 +457,19 @@ static gint32 convert_sockopt_level_and_name(MonoSocketOptionLevel mono_level,
                        *system_name = IP_PKTINFO;
                        break;
 #endif /* HAVE_IP_PKTINFO */
+
                case SocketOptionName_DontFragment:
+#ifdef HAVE_IP_DONTFRAGMENT
+                       *system_name = IP_DONTFRAGMENT;
+                       break;
+#elif defined HAVE_IP_MTU_DISCOVER
+                       /* Not quite the same */
+                       *system_name = IP_MTU_DISCOVER;
+                       break;
+#else
+                       /* If the flag is not available on this system, we can ignore this error */
+                       return (-2);
+#endif /* HAVE_IP_DONTFRAGMENT */
                case SocketOptionName_AddSourceMembership:
                case SocketOptionName_DropSourceMembership:
                case SocketOptionName_BlockSource:
@@ -579,44 +608,82 @@ static gint32 convert_sockopt_level_and_name(MonoSocketOptionLevel mono_level,
        return(0);
 }
 
-#define STASH_SYS_ASS(this) \
-       if(system_assembly == NULL) { \
-               system_assembly=mono_image_loaded ("System"); \
-               if (!system_assembly) { \
-                       MonoAssembly *sa = mono_assembly_open ("System.dll", NULL);     \
-                       if (!sa) g_assert_not_reached ();       \
-                       else {system_assembly = mono_assembly_get_image (sa);}  \
-               }       \
+static MonoImage *get_socket_assembly (void)
+{
+       static const char *version = NULL;
+       static gboolean moonlight;
+       static MonoImage *socket_assembly = NULL;
+       
+       if (version == NULL) {
+               version = mono_get_runtime_info ()->framework_version;
+               moonlight = !strcmp (version, "2.1");
        }
-
-static MonoImage *system_assembly=NULL;
-
+       
+       if (socket_assembly == NULL) {
+               if (moonlight) {
+                       socket_assembly = mono_image_loaded ("System.Net");
+                       if (!socket_assembly) {
+                               MonoAssembly *sa = mono_assembly_open ("System.Net.dll", NULL);
+                       
+                               if (!sa) {
+                                       g_assert_not_reached ();
+                               } else {
+                                       socket_assembly = mono_assembly_get_image (sa);
+                               }
+                       }
+               } else {
+                       socket_assembly = mono_image_loaded ("System");
+                       if (!socket_assembly) {
+                               MonoAssembly *sa = mono_assembly_open ("System.dll", NULL);
+                       
+                               if (!sa) {
+                                       g_assert_not_reached ();
+                               } else {
+                                       socket_assembly = mono_assembly_get_image (sa);
+                               }
+                       }
+               }
+       }
+       
+       return(socket_assembly);
+}
 
 #ifdef AF_INET6
 static gint32 get_family_hint(void)
 {
-       MonoClass *socket_class;
-       MonoClassField *ipv6_field, *ipv4_field;
-       gint32 ipv6_enabled = -1, ipv4_enabled = -1;
-       MonoVTable *vtable;
-
-       socket_class = mono_class_from_name (system_assembly,
-                                            "System.Net.Sockets", "Socket");
-       ipv4_field = mono_class_get_field_from_name (socket_class,
-                                                    "ipv4Supported");
-       ipv6_field = mono_class_get_field_from_name (socket_class,
-                                                    "ipv6Supported");
-       vtable = mono_class_vtable (mono_domain_get (), socket_class);
-
-       mono_field_static_get_value(vtable, ipv4_field, &ipv4_enabled);
-       mono_field_static_get_value(vtable, ipv6_field, &ipv6_enabled);
-
-       if(ipv4_enabled == 1 && ipv6_enabled == 1) {
-               return(PF_UNSPEC);
-       } else if(ipv4_enabled == 1) {
-               return(PF_INET);
-       } else {
-               return(PF_INET6);
+       MonoDomain *domain = mono_domain_get ();
+
+       if (!domain->inet_family_hint) {
+               MonoClass *socket_class;
+               MonoClassField *ipv6_field, *ipv4_field;
+               gint32 ipv6_enabled = -1, ipv4_enabled = -1;
+               MonoVTable *vtable;
+
+               socket_class = mono_class_from_name (get_socket_assembly (), "System.Net.Sockets", "Socket");
+               ipv4_field = mono_class_get_field_from_name (socket_class, "ipv4Supported");
+               ipv6_field = mono_class_get_field_from_name (socket_class, "ipv6Supported");
+               vtable = mono_class_vtable (mono_domain_get (), socket_class);
+               mono_runtime_class_init (vtable);
+
+               mono_field_static_get_value (vtable, ipv4_field, &ipv4_enabled);
+               mono_field_static_get_value (vtable, ipv6_field, &ipv6_enabled);
+
+               mono_domain_lock (domain);
+               if (ipv4_enabled == 1 && ipv6_enabled == 1) {
+                       domain->inet_family_hint = 1;
+               } else if (ipv4_enabled == 1) {
+                       domain->inet_family_hint = 2;
+               } else {
+                       domain->inet_family_hint = 3;
+               }
+               mono_domain_unlock (domain);
+       }
+       switch (domain->inet_family_hint) {
+       case 1: return PF_UNSPEC;
+       case 2: return PF_INET;
+       case 3: return PF_INET6;
+       default:
+               return PF_UNSPEC;
        }
 }
 #endif
@@ -630,8 +697,6 @@ gpointer ves_icall_System_Net_Sockets_Socket_Socket_internal(MonoObject *this, g
        
        MONO_ARCH_SAVE_REGS;
 
-       STASH_SYS_ASS(this);
-       
        *error = 0;
        
        sock_family=convert_family(family);
@@ -670,33 +735,6 @@ gpointer ves_icall_System_Net_Sockets_Socket_Socket_internal(MonoObject *this, g
        }
 #endif
 
-#ifndef PLATFORM_WIN32
-       /* .net seems to set this by default for SOCK_STREAM,
-        * not for SOCK_DGRAM (see bug #36322)
-        *
-        * It seems winsock has a rather different idea of what
-        * SO_REUSEADDR means.  If it's set, then a new socket can be
-        * bound over an existing listening socket.  There's a new
-        * windows-specific option called SO_EXCLUSIVEADDRUSE but
-        * using that means the socket MUST be closed properly, or a
-        * denial of service can occur.  Luckily for us, winsock
-        * behaves as though any other system would when SO_REUSEADDR
-        * is true, so we don't need to do anything else here.  See
-        * bug 53992.
-        */
-       {
-       int ret, true = 1;
-       
-       ret = _wapi_setsockopt (sock, SOL_SOCKET, SO_REUSEADDR, &true, sizeof (true));
-       if(ret==SOCKET_ERROR) {
-               *error = WSAGetLastError ();
-               
-               closesocket(sock);
-               return(NULL);
-       }
-       }
-#endif
-       
        return(GUINT_TO_POINTER (sock));
 }
 
@@ -817,7 +855,7 @@ static MonoObject *create_object_from_sockaddr(struct sockaddr *saddr,
        MonoAddressFamily family;
 
        /* Build a System.Net.SocketAddress object instance */
-       sockaddr_class=mono_class_from_name(system_assembly, "System.Net", "SocketAddress");
+       sockaddr_class=mono_class_from_name(get_socket_assembly (), "System.Net", "SocketAddress");
        sockaddr_obj=mono_object_new(domain, sockaddr_class);
        
        /* Locate the SocketAddress data buffer in the object */
@@ -1161,10 +1199,7 @@ ves_icall_System_Net_Sockets_Socket_Poll_internal (SOCKET sock, gint mode,
                                thread = mono_thread_current ();
                        }
                        
-                       mono_monitor_enter (thread->synch_lock);
-                       leave = ((thread->state & ThreadState_AbortRequested) != 0 ||
-                                (thread->state & ThreadState_StopRequested) != 0);
-                       mono_monitor_exit (thread->synch_lock);
+                       leave = mono_thread_test_state (thread, ThreadState_AbortRequested | ThreadState_StopRequested);
                        
                        if (leave != 0) {
                                g_free (pfds);
@@ -1223,6 +1258,92 @@ extern void ves_icall_System_Net_Sockets_Socket_Connect_internal(SOCKET sock, Mo
        g_free(sa);
 }
 
+/* These #defines from mswsock.h from wine.  Defining them here allows
+ * us to build this file on a mingw box that doesn't know the magic
+ * numbers, but still run on a newer windows box that does.
+ */
+#ifndef WSAID_DISCONNECTEX
+#define WSAID_DISCONNECTEX {0x7fda2e11,0x8630,0x436f,{0xa0, 0x31, 0xf5, 0x36, 0xa6, 0xee, 0xc1, 0x57}}
+typedef BOOL (WINAPI *LPFN_DISCONNECTEX)(SOCKET, LPOVERLAPPED, DWORD, DWORD);
+#endif
+
+#ifndef WSAID_TRANSMITFILE
+#define WSAID_TRANSMITFILE {0xb5367df0,0xcbac,0x11cf,{0x95,0xca,0x00,0x80,0x5f,0x48,0xa1,0x92}}
+typedef BOOL (WINAPI *LPFN_TRANSMITFILE)(SOCKET, HANDLE, DWORD, DWORD, LPOVERLAPPED, LPTRANSMIT_FILE_BUFFERS, DWORD);
+#endif
+
+extern void ves_icall_System_Net_Sockets_Socket_Disconnect_internal(SOCKET sock, MonoBoolean reuse, gint32 *error)
+{
+       int ret;
+       glong output_bytes = 0;
+       GUID disco_guid = WSAID_DISCONNECTEX;
+       GUID trans_guid = WSAID_TRANSMITFILE;
+       LPFN_DISCONNECTEX _wapi_disconnectex = NULL;
+       LPFN_TRANSMITFILE _wapi_transmitfile = NULL;
+       gboolean bret;
+       
+       MONO_ARCH_SAVE_REGS;
+
+       *error = 0;
+       
+#ifdef DEBUG
+       g_message("%s: disconnecting from socket %p (reuse %d)", __func__,
+                 sock, reuse);
+#endif
+
+       /* I _think_ the extension function pointers need to be looked
+        * up for each socket.  FIXME: check the best way to store
+        * pointers to functions in managed objects that still works
+        * on 64bit platforms.
+        */
+       ret = WSAIoctl (sock, SIO_GET_EXTENSION_FUNCTION_POINTER,
+                       (void *)&disco_guid, sizeof(GUID),
+                       (void *)&_wapi_disconnectex, sizeof(void *),
+                       &output_bytes, NULL, NULL);
+       if (ret != 0) {
+               /* make sure that WSAIoctl didn't put crap in the
+                * output pointer
+                */
+               _wapi_disconnectex = NULL;
+
+               /* Look up the TransmitFile extension function pointer
+                * instead of calling TransmitFile() directly, because
+                * apparently "Several of the extension functions have
+                * been available since WinSock 1.1 and are exported
+                * from MSWsock.dll, however it's not advisable to
+                * link directly to this dll as this ties you to the
+                * Microsoft WinSock provider. A provider neutral way
+                * of accessing these extension functions is to load
+                * them dynamically via WSAIoctl using the
+                * SIO_GET_EXTENSION_FUNCTION_POINTER op code. This
+                * should, theoretically, allow you to access these
+                * functions from any provider that supports them..." 
+                * (http://www.codeproject.com/internet/jbsocketserver3.asp)
+                */
+               ret = WSAIoctl (sock, SIO_GET_EXTENSION_FUNCTION_POINTER,
+                               (void *)&trans_guid, sizeof(GUID),
+                               (void *)&_wapi_transmitfile, sizeof(void *),
+                               &output_bytes, NULL, NULL);
+               if (ret != 0) {
+                       _wapi_transmitfile = NULL;
+               }
+       }
+
+       if (_wapi_disconnectex != NULL) {
+               bret = _wapi_disconnectex (sock, NULL, TF_REUSE_SOCKET, 0);
+       } else if (_wapi_transmitfile != NULL) {
+               bret = _wapi_transmitfile (sock, NULL, 0, 0, NULL, NULL,
+                                          TF_DISCONNECT | TF_REUSE_SOCKET);
+       } else {
+               *error = ERROR_NOT_SUPPORTED;
+               return;
+       }
+
+       if (bret == FALSE) {
+               *error = WSAGetLastError ();
+       }
+}
+
 gint32 ves_icall_System_Net_Sockets_Socket_Receive_internal(SOCKET sock, MonoArray *buffer, gint32 offset, gint32 count, gint32 flags, gint32 *error)
 {
        int ret;
@@ -1256,6 +1377,35 @@ gint32 ves_icall_System_Net_Sockets_Socket_Receive_internal(SOCKET sock, MonoArr
        return(ret);
 }
 
+gint32 ves_icall_System_Net_Sockets_Socket_Receive_array_internal(SOCKET sock, MonoArray *buffers, gint32 flags, gint32 *error)
+{
+       int ret, count;
+       DWORD recv;
+       WSABUF *wsabufs;
+       DWORD recvflags = 0;
+       
+       MONO_ARCH_SAVE_REGS;
+
+       *error = 0;
+       
+       wsabufs = mono_array_addr (buffers, WSABUF, 0);
+       count = mono_array_length (buffers);
+       
+       recvflags = convert_socketflags (flags);
+       if (recvflags == -1) {
+               *error = WSAEOPNOTSUPP;
+               return(0);
+       }
+       
+       ret = WSARecv (sock, wsabufs, count, &recv, &recvflags, NULL, NULL);
+       if (ret == SOCKET_ERROR) {
+               *error = WSAGetLastError ();
+               return(0);
+       }
+       
+       return(recv);
+}
+
 gint32 ves_icall_System_Net_Sockets_Socket_RecvFrom_internal(SOCKET sock, MonoArray *buffer, gint32 offset, gint32 count, gint32 flags, MonoObject **sockaddr, gint32 *error)
 {
        int ret;
@@ -1349,6 +1499,35 @@ gint32 ves_icall_System_Net_Sockets_Socket_Send_internal(SOCKET sock, MonoArray
        return(ret);
 }
 
+gint32 ves_icall_System_Net_Sockets_Socket_Send_array_internal(SOCKET sock, MonoArray *buffers, gint32 flags, gint32 *error)
+{
+       int ret, count;
+       DWORD sent;
+       WSABUF *wsabufs;
+       DWORD sendflags = 0;
+       
+       MONO_ARCH_SAVE_REGS;
+
+       *error = 0;
+       
+       wsabufs = mono_array_addr (buffers, WSABUF, 0);
+       count = mono_array_length (buffers);
+       
+       sendflags = convert_socketflags (flags);
+       if (sendflags == -1) {
+               *error = WSAEOPNOTSUPP;
+               return(0);
+       }
+       
+       ret = WSASend (sock, wsabufs, count, &sent, sendflags, NULL, NULL);
+       if (ret == SOCKET_ERROR) {
+               *error = WSAGetLastError ();
+               return(0);
+       }
+       
+       return(sent);
+}
+
 gint32 ves_icall_System_Net_Sockets_Socket_SendTo_internal(SOCKET sock, MonoArray *buffer, gint32 offset, gint32 count, gint32 flags, MonoObject *sockaddr, gint32 *error)
 {
        int ret;
@@ -1422,6 +1601,7 @@ void ves_icall_System_Net_Sockets_Socket_Select_internal(MonoArray **sockets, gi
        MonoClass *sock_arr_class;
        MonoArray *socks;
        time_t start;
+       mono_array_size_t socks_size;
        
        MONO_ARCH_SAVE_REGS;
 
@@ -1469,10 +1649,8 @@ void ves_icall_System_Net_Sockets_Socket_Select_internal(MonoArray **sockets, gi
                        if (thread == NULL)
                                thread = mono_thread_current ();
 
-                       mono_monitor_enter (thread->synch_lock);
-                       leave = ((thread->state & ThreadState_AbortRequested) != 0 || 
-                                (thread->state & ThreadState_StopRequested) != 0);
-                       mono_monitor_exit (thread->synch_lock);
+                       leave = mono_thread_test_state (thread, ThreadState_AbortRequested | ThreadState_StopRequested);
+                       
                        if (leave != 0) {
                                g_free (pfds);
                                *sockets = NULL;
@@ -1502,9 +1680,9 @@ void ves_icall_System_Net_Sockets_Socket_Select_internal(MonoArray **sockets, gi
        }
 
        sock_arr_class= ((MonoObject *)*sockets)->vtable->klass;
-       ret += 3; /* space for the NULL delimiters */
-       socks = mono_array_new_full (mono_domain_get (), sock_arr_class, (guint32*)&ret, NULL);
-       ret -= 3;
+       socks_size = ((mono_array_size_t)ret) + 3; /* space for the NULL delimiters */
+       socks = mono_array_new_full (mono_domain_get (), sock_arr_class, &socks_size, NULL);
+
        mode = idx = 0;
        for (i = 0; i < count && ret > 0; i++) {
                mono_pollfd *pfd;
@@ -1570,6 +1748,10 @@ void ves_icall_System_Net_Sockets_Socket_GetSocketOption_obj_internal(SOCKET soc
                *error = WSAENOPROTOOPT;
                return;
        }
+       if (ret == -2) {
+               *obj_val = int_to_object (domain, 0);
+               return;
+       }
        
        /* No need to deal with MulticastOption names here, because
         * you cant getsockopt AddMembership or DropMembership (the
@@ -1607,7 +1789,7 @@ void ves_icall_System_Net_Sockets_Socket_GetSocketOption_obj_internal(SOCKET soc
        switch(name) {
        case SocketOptionName_Linger:
                /* build a System.Net.Sockets.LingerOption */
-               obj_class=mono_class_from_name(system_assembly,
+               obj_class=mono_class_from_name(get_socket_assembly (),
                                               "System.Net.Sockets",
                                               "LingerOption");
                obj=mono_object_new(domain, obj_class);
@@ -1692,6 +1874,8 @@ void ves_icall_System_Net_Sockets_Socket_GetSocketOption_arr_internal(SOCKET soc
                *error = WSAENOPROTOOPT;
                return;
        }
+       if(ret==-2)
+               return;
 
        valsize=mono_array_length(*byte_val);
        buf=mono_array_addr(*byte_val, guchar, 0);
@@ -1712,10 +1896,9 @@ static struct in_addr ipaddress_to_struct_in_addr(MonoObject *ipaddr)
 
        /* No idea why .net uses a 64bit type to hold a 32bit value...
         *
-        * Internal value of IPAddess is in Network Order, there is no need
-        * to call htonl here.
+        * Internal value of IPAddess is in little-endian order
         */
-       inaddr.s_addr=(guint32)*(guint64 *)(((char *)ipaddr)+field->offset);
+       inaddr.s_addr=GUINT_FROM_LE ((guint32)*(guint64 *)(((char *)ipaddr)+field->offset));
        
        return(inaddr);
 }
@@ -1787,6 +1970,9 @@ void ves_icall_System_Net_Sockets_Socket_SetSocketOption_internal(SOCKET sock, g
                *error = WSAENOPROTOOPT;
                return;
        }
+       if(ret==-2){
+               return;
+       }
 
        /* Only one of obj_val, byte_val or int_val has data */
        if(obj_val!=NULL) {
@@ -1900,7 +2086,21 @@ void ves_icall_System_Net_Sockets_Socket_SetSocketOption_internal(SOCKET sock, g
                }
        } else {
                /* ReceiveTimeout/SendTimeout get here */
-               ret = _wapi_setsockopt (sock, system_level, system_name, (char *) &int_val, sizeof (int_val));
+               switch(name) {
+               case SocketOptionName_DontFragment:
+#ifdef HAVE_IP_MTU_DISCOVER
+                       /* Fiddle with the value slightly if we're
+                        * turning DF on
+                        */
+                       if (int_val == 1) {
+                               int_val = IP_PMTUDISC_DO;
+                       }
+                       /* Fall through */
+#endif
+                       
+               default:
+                       ret = _wapi_setsockopt (sock, system_level, system_name, (char *) &int_val, sizeof (int_val));
+               }
        }
 
        if(ret==SOCKET_ERROR) {
@@ -2645,7 +2845,6 @@ inet_pton (int family, const char *address, void *inaddrp)
 extern MonoBoolean ves_icall_System_Net_Dns_GetHostByAddr_internal(MonoString *addr, MonoString **h_name, MonoArray **h_aliases, MonoArray **h_addr_list)
 {
        char *address;
-       const char *version;
        gboolean v1;
        
 #ifdef AF_INET6
@@ -2661,10 +2860,7 @@ extern MonoBoolean ves_icall_System_Net_Dns_GetHostByAddr_internal(MonoString *a
        gboolean ret;
 #endif
 
-       MONO_ARCH_SAVE_REGS;
-
-       version = mono_get_runtime_info ()->framework_version;
-       v1 = (version[0] == '1');
+       v1 = mono_framework_version () == 1;
 
        address = mono_string_to_utf8 (addr);
 
@@ -2691,12 +2887,18 @@ extern MonoBoolean ves_icall_System_Net_Dns_GetHostByAddr_internal(MonoString *a
        }
        
        if(family == AF_INET) {
+#if HAVE_SOCKADDR_IN_SIN_LEN
+               saddr.sin_len = sizeof (saddr);
+#endif
                if(getnameinfo ((struct sockaddr*)&saddr, sizeof(saddr),
                                hostname, sizeof(hostname), NULL, 0,
                                flags) != 0) {
                        return(FALSE);
                }
        } else if(family == AF_INET6) {
+#if HAVE_SOCKADDR_IN6_SIN_LEN
+               saddr6.sin6_len = sizeof (saddr6);
+#endif
                if(getnameinfo ((struct sockaddr*)&saddr6, sizeof(saddr6),
                                hostname, sizeof(hostname), NULL, 0,
                                flags) != 0) {