dlm: fix race while closing connections
[cascardo/linux.git] / fs / dlm / lowcomms.c
index 754fd6c..749deb3 100644 (file)
@@ -514,17 +514,24 @@ static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
 }
 
 /* Close a remote connection and tidy up */
-static void close_connection(struct connection *con, bool and_other)
+static void close_connection(struct connection *con, bool and_other,
+                            bool tx, bool rx)
 {
-       mutex_lock(&con->sock_mutex);
+       clear_bit(CF_CONNECT_PENDING, &con->flags);
+       clear_bit(CF_WRITE_PENDING, &con->flags);
+       if (tx && cancel_work_sync(&con->swork))
+               log_print("canceled swork for node %d", con->nodeid);
+       if (rx && cancel_work_sync(&con->rwork))
+               log_print("canceled rwork for node %d", con->nodeid);
 
+       mutex_lock(&con->sock_mutex);
        if (con->sock) {
                sock_release(con->sock);
                con->sock = NULL;
        }
        if (con->othercon && and_other) {
                /* Will only re-enter once. */
-               close_connection(con->othercon, false);
+               close_connection(con->othercon, false, true, true);
        }
        if (con->rx_page) {
                __free_page(con->rx_page);
@@ -535,7 +542,9 @@ static void close_connection(struct connection *con, bool and_other)
        mutex_unlock(&con->sock_mutex);
 }
 
-/* We only send shutdown messages to nodes that are not part of the cluster */
+/* We only send shutdown messages to nodes that are not part of the cluster
+ * or if we get multiple connections from a node.
+ */
 static void sctp_send_shutdown(sctp_assoc_t associd)
 {
        static char outcmsg[CMSG_SPACE(sizeof(struct sctp_sndrcvinfo))];
@@ -718,6 +727,14 @@ static void process_sctp_notification(struct connection *con,
                        if (!new_con)
                                return;
 
+                       if (new_con->sock) {
+                               log_print("reject connect from node %d: "
+                                         "already has a connection.",
+                                         nodeid);
+                               sctp_send_shutdown(prim.ssp_assoc_id);
+                               return;
+                       }
+
                        /* Peel off a new sock */
                        lock_sock(con->sock->sk);
                        ret = sctp_do_peeloff(con->sock->sk,
@@ -892,7 +909,7 @@ out_resched:
 out_close:
        mutex_unlock(&con->sock_mutex);
        if (ret != -EAGAIN) {
-               close_connection(con, false);
+               close_connection(con, false, true, false);
                /* Reconnect when there is something to send */
        }
        /* Don't return success if we really got EOF */
@@ -1612,7 +1629,7 @@ out:
 
 send_error:
        mutex_unlock(&con->sock_mutex);
-       close_connection(con, false);
+       close_connection(con, false, false, true);
        lowcomms_connect_sock(con);
        return;
 
@@ -1644,15 +1661,9 @@ int dlm_lowcomms_close(int nodeid)
        log_print("closing connection to node %d", nodeid);
        con = nodeid2con(nodeid, 0);
        if (con) {
-               clear_bit(CF_CONNECT_PENDING, &con->flags);
-               clear_bit(CF_WRITE_PENDING, &con->flags);
                set_bit(CF_CLOSE, &con->flags);
-               if (cancel_work_sync(&con->swork))
-                       log_print("canceled swork for node %d", nodeid);
-               if (cancel_work_sync(&con->rwork))
-                       log_print("canceled rwork for node %d", nodeid);
+               close_connection(con, true, true, true);
                clean_one_writequeue(con);
-               close_connection(con, true);
        }
 
        spin_lock(&dlm_node_addrs_spin);
@@ -1735,7 +1746,7 @@ static void stop_conn(struct connection *con)
 
 static void free_conn(struct connection *con)
 {
-       close_connection(con, true);
+       close_connection(con, true, true, true);
        if (con->othercon)
                kmem_cache_free(con_cache, con->othercon);
        hlist_del(&con->list);
@@ -1806,7 +1817,7 @@ fail_unlisten:
        dlm_allow_conn = 0;
        con = nodeid2con(0,0);
        if (con) {
-               close_connection(con, false);
+               close_connection(con, false, true, true);
                kmem_cache_free(con_cache, con);
        }
 fail_destroy: