rxrpc: Use rxrpc_extract_addr_from_skb() rather than doing this manually
[cascardo/linux.git] / net / rxrpc / output.c
1 /* RxRPC packet transmission
2  *
3  * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved.
4  * Written by David Howells (dhowells@redhat.com)
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
13
14 #include <linux/net.h>
15 #include <linux/gfp.h>
16 #include <linux/skbuff.h>
17 #include <linux/export.h>
18 #include <net/sock.h>
19 #include <net/af_rxrpc.h>
20 #include "ar-internal.h"
21
/*
 * Scratch buffer used by rxrpc_send_call_packet() to build an ACK or ABORT
 * packet.
 *
 * The wire header is followed by a type-specific payload held in a union:
 * either an ACK (fixed ack header, soft-ACK bytes and padding) or a single
 * abort code.  The ackinfo trailer lives outside the union because the number
 * of soft-ACK bytes actually used varies; the sender transmits ackinfo from a
 * second kvec placed straight after the used part of acks[].
 */
struct rxrpc_pkt_buffer {
        struct rxrpc_wire_header whdr;
        union {
                struct {
                        struct rxrpc_ackpacket ack;
                        u8 acks[255];   /* one ACK/NACK byte per sequence number */
                        u8 pad[3];      /* room for the 3 trailing zero bytes
                                         * rxrpc_fill_out_ack() writes, even
                                         * when acks[] is completely full */
                };
                __be32 abort_code;      /* ABORT payload, network byte order */
        };
        struct rxrpc_ackinfo ackinfo;   /* ACK trailer, sent from its own kvec */
};
34
/*
 * Fill out an ACK packet.
 *
 * Populates pkt->ack, the soft-ACK list in pkt->acks[] and pkt->ackinfo from
 * the call's receive state.  Within this file it is called from
 * rxrpc_send_call_packet() with call->lock held.
 *
 * Returns the number of bytes written into pkt->acks[]: one ACK/NACK byte for
 * each sequence number after hard_ack up to and including top, plus 3
 * trailing zero bytes.  The caller adds sizeof(pkt->ack) to this to obtain
 * the length of the first kvec.
 */
static size_t rxrpc_fill_out_ack(struct rxrpc_call *call,
                                 struct rxrpc_pkt_buffer *pkt)
{
        rxrpc_seq_t hard_ack, top, seq;
        int ix;
        u32 mtu, jmax;
        u8 *ackp = pkt->acks;

        /* Barrier against rxrpc_input_data().  The acquire load of rx_top
         * orders the rxtx_buffer[] reads below after it.
         */
        hard_ack = READ_ONCE(call->rx_hard_ack);
        top = smp_load_acquire(&call->rx_top);

        /* NOTE(review): bufferSpace is a fixed 8; its meaning isn't evident
         * from this file.
         */
        pkt->ack.bufferSpace    = htons(8);
        pkt->ack.maxSkew        = htons(call->ackr_skew);
        pkt->ack.firstPacket    = htonl(hard_ack + 1);
        pkt->ack.previousPacket = htonl(call->ackr_prev_seq);
        pkt->ack.serial         = htonl(call->ackr_serial);
        pkt->ack.reason         = call->ackr_reason;
        /* NOTE(review): nAcks and acks[255] assume top - hard_ack never
         * exceeds 255 (presumably bounded by the receive window) - confirm.
         */
        pkt->ack.nAcks          = top - hard_ack;

        /* For each sequence number in (hard_ack, top], emit ACK if the
         * rxtx_buffer slot it maps to (seq masked by RXRPC_RXTX_BUFF_MASK)
         * is occupied, NACK otherwise.
         */
        if (after(top, hard_ack)) {
                seq = hard_ack + 1;
                do {
                        ix = seq & RXRPC_RXTX_BUFF_MASK;
                        if (call->rxtx_buffer[ix])
                                *ackp++ = RXRPC_ACK_TYPE_ACK;
                        else
                                *ackp++ = RXRPC_ACK_TYPE_NACK;
                        seq++;
                } while (before_eq(seq, top));
        }

        /* Trailer: our receive MTU, the peer's interface MTU less its header
         * size, the receive window, and the jumbo limit (dropped to 1 once
         * call->nr_jumbo_bad exceeds 3).
         */
        mtu = call->conn->params.peer->if_mtu;
        mtu -= call->conn->params.peer->hdrsize;
        jmax = (call->nr_jumbo_bad > 3) ? 1 : rxrpc_rx_jumbo_max;
        pkt->ackinfo.rxMTU      = htonl(rxrpc_rx_mtu);
        pkt->ackinfo.maxMTU     = htonl(mtu);
        pkt->ackinfo.rwind      = htonl(call->rx_winsize);
        pkt->ackinfo.jumbo_max  = htonl(jmax);

        /* Three zero bytes follow the soft-ACK list; pad[] guarantees room
         * for them when acks[] is full.  NOTE(review): their protocol purpose
         * is presumably padding before the ackinfo trailer - confirm.
         */
        *ackp++ = 0;
        *ackp++ = 0;
        *ackp++ = 0;
        return top - hard_ack + 3;
}
83
84 /*
85  * Send an ACK or ABORT call packet.
86  */
87 int rxrpc_send_call_packet(struct rxrpc_call *call, u8 type)
88 {
89         struct rxrpc_connection *conn = NULL;
90         struct rxrpc_pkt_buffer *pkt;
91         struct msghdr msg;
92         struct kvec iov[2];
93         rxrpc_serial_t serial;
94         size_t len, n;
95         int ioc, ret;
96         u32 abort_code;
97
98         _enter("%u,%s", call->debug_id, rxrpc_pkts[type]);
99
100         spin_lock_bh(&call->lock);
101         if (call->conn)
102                 conn = rxrpc_get_connection_maybe(call->conn);
103         spin_unlock_bh(&call->lock);
104         if (!conn)
105                 return -ECONNRESET;
106
107         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
108         if (!pkt) {
109                 rxrpc_put_connection(conn);
110                 return -ENOMEM;
111         }
112
113         serial = atomic_inc_return(&conn->serial);
114
115         msg.msg_name    = &call->peer->srx.transport;
116         msg.msg_namelen = call->peer->srx.transport_len;
117         msg.msg_control = NULL;
118         msg.msg_controllen = 0;
119         msg.msg_flags   = 0;
120
121         pkt->whdr.epoch         = htonl(conn->proto.epoch);
122         pkt->whdr.cid           = htonl(call->cid);
123         pkt->whdr.callNumber    = htonl(call->call_id);
124         pkt->whdr.seq           = 0;
125         pkt->whdr.serial        = htonl(serial);
126         pkt->whdr.type          = type;
127         pkt->whdr.flags         = conn->out_clientflag;
128         pkt->whdr.userStatus    = 0;
129         pkt->whdr.securityIndex = call->security_ix;
130         pkt->whdr._rsvd         = 0;
131         pkt->whdr.serviceId     = htons(call->service_id);
132
133         iov[0].iov_base = pkt;
134         iov[0].iov_len  = sizeof(pkt->whdr);
135         len = sizeof(pkt->whdr);
136
137         switch (type) {
138         case RXRPC_PACKET_TYPE_ACK:
139                 spin_lock_bh(&call->lock);
140                 n = rxrpc_fill_out_ack(call, pkt);
141                 call->ackr_reason = 0;
142
143                 spin_unlock_bh(&call->lock);
144
145                 _proto("Tx ACK %%%u { m=%hu f=#%u p=#%u s=%%%u r=%s n=%u }",
146                        serial,
147                        ntohs(pkt->ack.maxSkew),
148                        ntohl(pkt->ack.firstPacket),
149                        ntohl(pkt->ack.previousPacket),
150                        ntohl(pkt->ack.serial),
151                        rxrpc_acks(pkt->ack.reason),
152                        pkt->ack.nAcks);
153
154                 iov[0].iov_len += sizeof(pkt->ack) + n;
155                 iov[1].iov_base = &pkt->ackinfo;
156                 iov[1].iov_len  = sizeof(pkt->ackinfo);
157                 len += sizeof(pkt->ack) + n + sizeof(pkt->ackinfo);
158                 ioc = 2;
159                 break;
160
161         case RXRPC_PACKET_TYPE_ABORT:
162                 abort_code = call->abort_code;
163                 pkt->abort_code = htonl(abort_code);
164                 _proto("Tx ABORT %%%u { %d }", serial, abort_code);
165                 iov[0].iov_len += sizeof(pkt->abort_code);
166                 len += sizeof(pkt->abort_code);
167                 ioc = 1;
168                 break;
169
170         default:
171                 BUG();
172                 ret = -ENOANO;
173                 goto out;
174         }
175
176         ret = kernel_sendmsg(conn->params.local->socket,
177                              &msg, iov, ioc, len);
178
179         if (ret < 0 && call->state < RXRPC_CALL_COMPLETE) {
180                 switch (pkt->whdr.type) {
181                 case RXRPC_PACKET_TYPE_ACK:
182                         rxrpc_propose_ACK(call, pkt->ack.reason,
183                                           ntohs(pkt->ack.maxSkew),
184                                           ntohl(pkt->ack.serial),
185                                           true, true);
186                         break;
187                 case RXRPC_PACKET_TYPE_ABORT:
188                         break;
189                 }
190         }
191
192 out:
193         rxrpc_put_connection(conn);
194         kfree(pkt);
195         return ret;
196 }
197
198 /*
199  * send a packet through the transport endpoint
200  */
201 int rxrpc_send_data_packet(struct rxrpc_connection *conn, struct sk_buff *skb)
202 {
203         struct kvec iov[1];
204         struct msghdr msg;
205         int ret, opt;
206
207         _enter(",{%d}", skb->len);
208
209         iov[0].iov_base = skb->head;
210         iov[0].iov_len = skb->len;
211
212         msg.msg_name = &conn->params.peer->srx.transport;
213         msg.msg_namelen = conn->params.peer->srx.transport_len;
214         msg.msg_control = NULL;
215         msg.msg_controllen = 0;
216         msg.msg_flags = 0;
217
218         /* send the packet with the don't fragment bit set if we currently
219          * think it's small enough */
220         if (skb->len - sizeof(struct rxrpc_wire_header) < conn->params.peer->maxdata) {
221                 down_read(&conn->params.local->defrag_sem);
222                 /* send the packet by UDP
223                  * - returns -EMSGSIZE if UDP would have to fragment the packet
224                  *   to go out of the interface
225                  *   - in which case, we'll have processed the ICMP error
226                  *     message and update the peer record
227                  */
228                 ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 1,
229                                      iov[0].iov_len);
230
231                 up_read(&conn->params.local->defrag_sem);
232                 if (ret == -EMSGSIZE)
233                         goto send_fragmentable;
234
235                 _leave(" = %d [%u]", ret, conn->params.peer->maxdata);
236                 return ret;
237         }
238
239 send_fragmentable:
240         /* attempt to send this message with fragmentation enabled */
241         _debug("send fragment");
242
243         down_write(&conn->params.local->defrag_sem);
244
245         switch (conn->params.local->srx.transport.family) {
246         case AF_INET:
247                 opt = IP_PMTUDISC_DONT;
248                 ret = kernel_setsockopt(conn->params.local->socket,
249                                         SOL_IP, IP_MTU_DISCOVER,
250                                         (char *)&opt, sizeof(opt));
251                 if (ret == 0) {
252                         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 1,
253                                              iov[0].iov_len);
254
255                         opt = IP_PMTUDISC_DO;
256                         kernel_setsockopt(conn->params.local->socket, SOL_IP,
257                                           IP_MTU_DISCOVER,
258                                           (char *)&opt, sizeof(opt));
259                 }
260                 break;
261         }
262
263         up_write(&conn->params.local->defrag_sem);
264         _leave(" = %d [frag %u]", ret, conn->params.peer->maxdata);
265         return ret;
266 }
267
268 /*
269  * reject packets through the local endpoint
270  */
271 void rxrpc_reject_packets(struct rxrpc_local *local)
272 {
273         struct sockaddr_rxrpc srx;
274         struct rxrpc_skb_priv *sp;
275         struct rxrpc_wire_header whdr;
276         struct sk_buff *skb;
277         struct msghdr msg;
278         struct kvec iov[2];
279         size_t size;
280         __be32 code;
281
282         _enter("%d", local->debug_id);
283
284         iov[0].iov_base = &whdr;
285         iov[0].iov_len = sizeof(whdr);
286         iov[1].iov_base = &code;
287         iov[1].iov_len = sizeof(code);
288         size = sizeof(whdr) + sizeof(code);
289
290         msg.msg_name = &srx.transport;
291         msg.msg_control = NULL;
292         msg.msg_controllen = 0;
293         msg.msg_flags = 0;
294
295         memset(&whdr, 0, sizeof(whdr));
296         whdr.type = RXRPC_PACKET_TYPE_ABORT;
297
298         while ((skb = skb_dequeue(&local->reject_queue))) {
299                 rxrpc_see_skb(skb);
300                 sp = rxrpc_skb(skb);
301
302                 if (rxrpc_extract_addr_from_skb(&srx, skb) == 0) {
303                         msg.msg_namelen = srx.transport_len;
304
305                         code = htonl(skb->priority);
306
307                         whdr.epoch      = htonl(sp->hdr.epoch);
308                         whdr.cid        = htonl(sp->hdr.cid);
309                         whdr.callNumber = htonl(sp->hdr.callNumber);
310                         whdr.serviceId  = htons(sp->hdr.serviceId);
311                         whdr.flags      = sp->hdr.flags;
312                         whdr.flags      ^= RXRPC_CLIENT_INITIATED;
313                         whdr.flags      &= RXRPC_CLIENT_INITIATED;
314
315                         kernel_sendmsg(local->socket, &msg, iov, 2, size);
316                 }
317
318                 rxrpc_free_skb(skb);
319         }
320
321         _leave("");
322 }