2007-09-02 Zoltan Varga <vargaz@gmail.com>
[mono.git] / mono / io-layer / sockets.c
index a668121143c94bf8c464b9a20ac50832cd9cdd92..33a59810354feacc03655eee88e394c44a6c4f7a 100644 (file)
@@ -15,6 +15,7 @@
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/ioctl.h>
+#include <sys/poll.h>
 #ifdef HAVE_SYS_FILIO_H
 #include <sys/filio.h>     /* defines FIONBIO and FIONREAD */
 #endif
@@ -37,8 +38,6 @@
 #undef DEBUG
 
 static guint32 startup_count=0;
-static pthread_key_t error_key;
-static mono_once_t error_key_once=MONO_ONCE_INIT;
 
 static void socket_close (gpointer handle, gpointer data);
 
@@ -71,6 +70,11 @@ static void socket_close (gpointer handle, gpointer data G_GNUC_UNUSED)
                return;
        }
 
+       /* Shutdown the socket for reading, to interrupt any potential
+        * receives that may be blocking for data.  See bug 75705.
+        */
+       shutdown (GPOINTER_TO_UINT (handle), SHUT_RD);
+       
        do {
                ret = close (GPOINTER_TO_UINT(handle));
        } while (ret == -1 && errno == EINTR &&
@@ -137,33 +141,14 @@ int WSACleanup(void)
        return(0);
 }
 
-static void error_init(void)
-{
-       int ret;
-       
-       ret = pthread_key_create (&error_key, NULL);
-       g_assert (ret == 0);
-}
-
 void WSASetLastError(int error)
 {
-       int ret;
-       
-       mono_once (&error_key_once, error_init);
-       ret = pthread_setspecific (error_key, GINT_TO_POINTER(error));
-       g_assert (ret == 0);
+       SetLastError (error);
 }
 
 int WSAGetLastError(void)
 {
-       int err;
-       void *errptr;
-       
-       mono_once (&error_key_once, error_init);
-       errptr = pthread_getspecific (error_key);
-       err = GPOINTER_TO_INT(errptr);
-       
-       return(err);
+       return(GetLastError ());
 }
 
 int closesocket(guint32 fd)
@@ -183,18 +168,35 @@ guint32 _wapi_accept(guint32 fd, struct sockaddr *addr, socklen_t *addrlen)
 {
        gpointer handle = GUINT_TO_POINTER (fd);
        gpointer new_handle;
+       struct _WapiHandle_socket *socket_handle;
+       struct _WapiHandle_socket new_socket_handle = {0};
+       gboolean ok;
        int new_fd;
        
        if (startup_count == 0) {
                WSASetLastError (WSANOTINITIALISED);
                return(INVALID_SOCKET);
        }
+
+       if (addr != NULL && *addrlen < sizeof(struct sockaddr)) {
+               WSASetLastError (WSAEFAULT);
+               return(INVALID_SOCKET);
+       }
        
        if (_wapi_handle_type (handle) != WAPI_HANDLE_SOCKET) {
                WSASetLastError (WSAENOTSOCK);
                return(INVALID_SOCKET);
        }
        
+       ok = _wapi_lookup_handle (handle, WAPI_HANDLE_SOCKET,
+                                 (gpointer *)&socket_handle);
+       if (ok == FALSE) {
+               g_warning ("%s: error looking up socket handle %p",
+                          __func__, handle);
+               WSASetLastError (WSAENOTSOCK);
+               return(INVALID_SOCKET);
+       }
+       
        do {
                new_fd = accept (fd, addr, addrlen);
        } while (new_fd == -1 && errno == EINTR &&
@@ -224,7 +226,13 @@ guint32 _wapi_accept(guint32 fd, struct sockaddr *addr, socklen_t *addrlen)
                return(INVALID_SOCKET);
        }
 
-       new_handle = _wapi_handle_new_fd (WAPI_HANDLE_SOCKET, new_fd, NULL);
+       new_socket_handle.domain = socket_handle->domain;
+       new_socket_handle.type = socket_handle->type;
+       new_socket_handle.protocol = socket_handle->protocol;
+       new_socket_handle.still_readable = 1;
+
+       new_handle = _wapi_handle_new_fd (WAPI_HANDLE_SOCKET, new_fd,
+                                         &new_socket_handle);
        if(new_handle == _WAPI_HANDLE_INVALID) {
                g_warning ("%s: error creating socket handle", __func__);
                WSASetLastError (ERROR_GEN_FAILURE);
@@ -272,7 +280,8 @@ int _wapi_connect(guint32 fd, const struct sockaddr *serv_addr,
                  socklen_t addrlen)
 {
        gpointer handle = GUINT_TO_POINTER (fd);
-       int ret;
+       struct _WapiHandle_socket *socket_handle;
+       gboolean ok;
        gint errnum;
        
        if (startup_count == 0) {
@@ -285,26 +294,82 @@ int _wapi_connect(guint32 fd, const struct sockaddr *serv_addr,
                return(SOCKET_ERROR);
        }
        
-       do {
-               ret = connect (fd, serv_addr, addrlen);
-       } while (ret==-1 && errno==EINTR && !_wapi_thread_cur_apc_pending());
-
-       if (ret == -1) {
+       if (connect (fd, serv_addr, addrlen) == -1) {
+               struct pollfd fds;
+               int so_error;
+               socklen_t len;
+               
                errnum = errno;
                
+               if (errno != EINTR) {
 #ifdef DEBUG
-               g_message ("%s: connect error: %s", __func__,
-                          strerror (errnum));
+                       g_message ("%s: connect error: %s", __func__,
+                                  strerror (errnum));
 #endif
-               errnum = errno_to_WSA (errnum, __func__);
-               if (errnum == WSAEINPROGRESS)
-                       errnum = WSAEWOULDBLOCK; /* see bug #73053 */
 
-               WSASetLastError (errnum);
+                       errnum = errno_to_WSA (errnum, __func__);
+                       if (errnum == WSAEINPROGRESS)
+                               errnum = WSAEWOULDBLOCK; /* see bug #73053 */
+
+                       WSASetLastError (errnum);
                
-               return(SOCKET_ERROR);
+                       return(SOCKET_ERROR);
+               }
+
+               fds.fd = fd;
+               fds.events = POLLOUT;
+               while (poll (&fds, 1, -1) == -1 &&
+                      !_wapi_thread_cur_apc_pending ()) {
+                       if (errno != EINTR) {
+                               errnum = errno_to_WSA (errno, __func__);
+
+#ifdef DEBUG
+                               g_message ("%s: connect poll error: %s",
+                                          __func__, strerror (errno));
+#endif
+
+                               WSASetLastError (errnum);
+                               return(SOCKET_ERROR);
+                       }
+               }
+
+               len = sizeof(so_error);
+               if (getsockopt (fd, SOL_SOCKET, SO_ERROR, &so_error,
+                               &len) == -1) {
+                       errnum = errno_to_WSA (errno, __func__);
+
+#ifdef DEBUG
+                       g_message ("%s: connect getsockopt error: %s",
+                                  __func__, strerror (errno));
+#endif
+
+                       WSASetLastError (errnum);
+                       return(SOCKET_ERROR);
+               }
+               
+               if (so_error != 0) {
+                       errnum = errno_to_WSA (so_error, __func__);
+
+                       /* Need to save this socket error */
+                       ok = _wapi_lookup_handle (handle, WAPI_HANDLE_SOCKET,
+                                                 (gpointer *)&socket_handle);
+                       if (ok == FALSE) {
+                               g_warning ("%s: error looking up socket handle %p", __func__, handle);
+                       } else {
+                               socket_handle->saved_error = errnum;
+                       }
+                       
+#ifdef DEBUG
+                       g_message ("%s: connect getsockopt returned error: %s",
+                                  __func__, strerror (so_error));
+#endif
+
+                       WSASetLastError (errnum);
+                       return(SOCKET_ERROR);
+               }
        }
-       return(ret);
+               
+       return(0);
 }
 
 int _wapi_getpeername(guint32 fd, struct sockaddr *name, socklen_t *namelen)
@@ -392,7 +457,8 @@ int _wapi_getsockopt(guint32 fd, int level, int optname, void *optval,
        }
 
        tmp_val = optval;
-       if (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO) {
+       if (level == SOL_SOCKET &&
+           (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO)) {
                tmp_val = &tv;
                *optlen = sizeof (tv);
        }
@@ -411,7 +477,8 @@ int _wapi_getsockopt(guint32 fd, int level, int optname, void *optval,
                return(SOCKET_ERROR);
        }
 
-       if (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO) {
+       if (level == SOL_SOCKET &&
+           (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO)) {
                *((int *) optval)  = tv.tv_sec * 1000 + (tv.tv_usec / 1000);    // milli from micro
                *optlen = sizeof (int);
        }
@@ -480,6 +547,8 @@ int _wapi_recvfrom(guint32 fd, void *buf, size_t len, int recv_flags,
                   struct sockaddr *from, socklen_t *fromlen)
 {
        gpointer handle = GUINT_TO_POINTER (fd);
+       struct _WapiHandle_socket *socket_handle;
+       gboolean ok;
        int ret;
        
        if (startup_count == 0) {
@@ -497,6 +566,33 @@ int _wapi_recvfrom(guint32 fd, void *buf, size_t len, int recv_flags,
        } while (ret == -1 && errno == EINTR &&
                 !_wapi_thread_cur_apc_pending ());
 
+       if (ret == 0 && len > 0) {
+               /* According to the Linux man page, recvfrom only
+                * returns 0 when the socket has been shut down
+                * cleanly.  Turn this into an EINTR to simulate win32
+                * behaviour of returning EINTR when a socket is
+                * closed while the recvfrom is blocking (we use a
+                * shutdown() in socket_close() to trigger this.) See
+                * bug 75705.
+                */
+               /* Distinguish between the socket being shut down at
+                * the local or remote ends, and reads that request 0
+                * bytes to be read
+                */
+
+               /* If this returns FALSE, it means the socket has been
+                * closed locally.  If it returns TRUE, but
+                * still_readable != 1 then shutdown
+                * (SHUT_RD|SHUT_RDWR) has been called locally.
+                */
+               ok = _wapi_lookup_handle (handle, WAPI_HANDLE_SOCKET,
+                                         (gpointer *)&socket_handle);
+               if (ok == FALSE || socket_handle->still_readable != 1) {
+                       ret = -1;
+                       errno = EINTR;
+               }
+       }
+       
        if (ret == -1) {
                gint errnum = errno;
 #ifdef DEBUG
@@ -599,14 +695,16 @@ int _wapi_setsockopt(guint32 fd, int level, int optname,
        }
 
        tmp_val = optval;
-       if (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO) {
+       if (level == SOL_SOCKET &&
+           (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO)) {
                int ms = *((int *) optval);
                tv.tv_sec = ms / 1000;
                tv.tv_usec = (ms % 1000) * 1000;        // micro from milli
                tmp_val = &tv;
                optlen = sizeof (tv);
 #if defined (__linux__)
-       } else if (optname == SO_SNDBUF || optname == SO_RCVBUF) {
+       } else if (level == SOL_SOCKET &&
+                  (optname == SO_SNDBUF || optname == SO_RCVBUF)) {
                /* According to socket(7) the Linux kernel doubles the
                 * buffer sizes "to allow space for bookkeeping
                 * overhead."
@@ -637,6 +735,8 @@ int _wapi_setsockopt(guint32 fd, int level, int optname,
 
 int _wapi_shutdown(guint32 fd, int how)
 {
+       struct _WapiHandle_socket *socket_handle;
+       gboolean ok;
        gpointer handle = GUINT_TO_POINTER (fd);
        int ret;
        
@@ -649,6 +749,20 @@ int _wapi_shutdown(guint32 fd, int how)
                WSASetLastError (WSAENOTSOCK);
                return(SOCKET_ERROR);
        }
+
+       if (how == SHUT_RD ||
+           how == SHUT_RDWR) {
+               ok = _wapi_lookup_handle (handle, WAPI_HANDLE_SOCKET,
+                                         (gpointer *)&socket_handle);
+               if (ok == FALSE) {
+                       g_warning ("%s: error looking up socket handle %p",
+                                  __func__, handle);
+                       WSASetLastError (WSAENOTSOCK);
+                       return(SOCKET_ERROR);
+               }
+               
+               socket_handle->still_readable = 0;
+       }
        
        ret = shutdown (fd, how);
        if (ret == -1) {
@@ -677,6 +791,7 @@ guint32 _wapi_socket(int domain, int type, int protocol, void *unused,
        socket_handle.domain = domain;
        socket_handle.type = type;
        socket_handle.protocol = protocol;
+       socket_handle.still_readable = 1;
        
        fd = socket (domain, type, protocol);
        if (fd == -1 && domain == AF_INET && type == SOCK_RAW &&
@@ -708,6 +823,40 @@ guint32 _wapi_socket(int domain, int type, int protocol, void *unused,
                
                return(INVALID_SOCKET);
        }
+
+       /* .net seems to set this by default for SOCK_STREAM, not for
+        * SOCK_DGRAM (see bug #36322)
+        *
+        * It seems winsock has a rather different idea of what
+        * SO_REUSEADDR means.  If it's set, then a new socket can be
+        * bound over an existing listening socket.  There's a new
+        * windows-specific option called SO_EXCLUSIVEADDRUSE but
+        * using that means the socket MUST be closed properly, or a
+        * denial of service can occur.  Luckily for us, winsock
+        * behaves as though any other system would when SO_REUSEADDR
+        * is true, so we don't need to do anything else here.  See
+        * bug 53992.
+        */
+       {
+               int ret, true = 1;
+       
+               ret = setsockopt (fd, SOL_SOCKET, SO_REUSEADDR, &true,
+                                 sizeof (true));
+               if (ret == -1) {
+                       int errnum = errno;
+
+#ifdef DEBUG
+                       g_message ("%s: Error setting SO_REUSEADDR", __func__);
+#endif
+                       
+                       errnum = errno_to_WSA (errnum, __func__);
+                       WSASetLastError (errnum);
+
+                       close (fd);
+
+                       return(INVALID_SOCKET);                 
+               }
+       }
        
        
        mono_once (&socket_ops_once, socket_ops_init);
@@ -715,6 +864,8 @@ guint32 _wapi_socket(int domain, int type, int protocol, void *unused,
        handle = _wapi_handle_new_fd (WAPI_HANDLE_SOCKET, fd, &socket_handle);
        if (handle == _WAPI_HANDLE_INVALID) {
                g_warning ("%s: error creating socket handle", __func__);
+               WSASetLastError (WSASYSCALLFAILURE);
+               close (fd);
                return(INVALID_SOCKET);
        }