virtio_transport_common.c

// SPDX-License-Identifier: GPL-2.0-only
/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

/* Threshold for detecting small packets to copy */
#define GOOD_COPY_LEN 128

uint virtio_transport_max_vsock_pkt_buf_size = 64 * 1024;
module_param(virtio_transport_max_vsock_pkt_buf_size, uint, 0444);
EXPORT_SYMBOL_GPL(virtio_transport_max_vsock_pkt_buf_size);
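
/* With mode 0444 the parameter above is typically exposed read-only
 * under /sys/module/<module>/parameters/, so the effective cap can be
 * inspected at runtime.
 */

/* Resolve the virtio transport ops bound to this socket. Returns NULL
 * (after a WARN) if no transport has been assigned yet, as on a
 * listener socket.
 */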
static const struct virtio_transport *
virtio_transport_get_ops(struct vsock_sock *vsk)
{
	const struct vsock_transport *t = vsock_core_get_transport(vsk);

	if (WARN_ON(!t))
		return NULL;

	return container_of(t, struct virtio_transport, transport);
}
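
/* Allocate a packet and fill in the virtio-vsock header. If @info
 * carries a msghdr and @len > 0, the payload is copied out of the
 * iterator into a freshly allocated buffer; on copy failure everything
 * is freed and NULL is returned.
 */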
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_pkt *pkt;
	int err;

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	pkt->hdr.type = cpu_to_le16(info->type);
	pkt->hdr.op = cpu_to_le16(info->op);
	pkt->hdr.src_cid = cpu_to_le64(src_cid);
	pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
	pkt->hdr.src_port = cpu_to_le32(src_port);
	pkt->hdr.dst_port = cpu_to_le32(dst_port);
	pkt->hdr.flags = cpu_to_le32(info->flags);
	pkt->len = len;
	pkt->hdr.len = cpu_to_le32(len);
	pkt->reply = info->reply;
	pkt->vsk = info->vsk;

	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;

		pkt->buf_len = len;

		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}

/* Packet capture */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_pkt *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so we can retrieve
	 * the payload length from the header and the buffer pointer taking
	 * care of the offset in the original packet.
	 */
	payload_len = le32_to_cpu(pkt->hdr.len);
	payload_buf = pkt->buf + pkt->off;

	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt->hdr.src_cid;
	hdr->src_port = pkt->hdr.src_port;
	hdr->dst_cid = pkt->hdr.dst_cid;
	hdr->dst_port = pkt->hdr.dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

	if (payload_len) {
		skb_put_data(skb, payload_buf, payload_len);
	}

	return skb;
}

void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
	if (pkt->tap_delivered)
		return;

	vsock_deliver_tap(virtio_transport_build_skb, pkt);
	pkt->tap_delivered = true;
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);

/* This function can only be used on connecting/connected sockets,
 * since a socket assigned to a transport is required.
 *
 * Do not use on listener sockets!
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	const struct virtio_transport *t_ops;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	t_ops = virtio_transport_get_ops(vsk);
	if (unlikely(!t_ops))
		return -EFAULT;

	src_cid = t_ops->transport.get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return t_ops->send_pkt(pkt);
}
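
/* rx_bytes tracks how much payload sits in the receive queue; a packet
 * is only accepted while rx_bytes stays within buf_alloc, the buffer
 * size advertised to the peer. fwd_cnt accumulates the bytes actually
 * consumed by the application and is echoed back to the peer in the
 * header of every transmitted packet.
 */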
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
		return false;

	vvs->rx_bytes += pkt->len;
	return true;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->rx_lock);
	vvs->last_fwd_cnt = vvs->fwd_cnt;
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->rx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
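
/* Reserve up to @credit bytes of the peer's receive buffer for
 * transmission. The available credit is
 *
 *	peer_buf_alloc - (tx_cnt - peer_fwd_cnt)
 *
 * i.e. the peer's advertised buffer size minus the bytes in flight
 * (sent but not yet consumed by the peer). As a worked example, with
 * peer_buf_alloc = 262144, tx_cnt = 300000 and peer_fwd_cnt = 200000,
 * 100000 bytes are in flight and 162144 bytes of credit remain.
 * Unused credit must be returned with virtio_transport_put_credit().
 */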
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
					       int type,
					       struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
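
/* MSG_PEEK path: copy queued payload into the user iterator without
 * consuming it. Packets stay on rx_queue and pkt->off is left
 * untouched (a local offset is used instead), so a later read sees the
 * same data again.
 */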
static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0, off;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);

	list_for_each_entry(pkt, &vvs->rx_queue, list) {
		off = pkt->off;

		if (total == len)
			break;

		while (total < len && off < pkt->len) {
			bytes = len - total;
			if (bytes > pkt->len - off)
				bytes = pkt->len - off;

			/* sk_lock is held by caller so no one else can dequeue.
			 * Unlock rx_lock since memcpy_to_msg() may sleep.
			 */
			spin_unlock_bh(&vvs->rx_lock);

			err = memcpy_to_msg(msg, pkt->buf + off, bytes);
			if (err)
				goto out;

			spin_lock_bh(&vvs->rx_lock);

			total += bytes;
			off += bytes;
		}
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;

out:
	if (total)
		err = total;
	return err;
}
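
/* Consuming read: copy payload to the user iterator, advance pkt->off,
 * and free each packet once it is fully drained. After the copy loop,
 * send a credit update if free receive space has fallen below
 * VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, so the peer learns it may transmit
 * again before the buffer runs completely dry.
 */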
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	u32 free_space;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}

	free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);

	spin_unlock_bh(&vvs->rx_lock);

	/* To reduce the number of credit update messages,
	 * don't update credits as long as lots of space is available.
	 * Note: the limit chosen here is arbitrary. Setting the limit
	 * too high causes extra messages. Too low causes transmitter
	 * stalls. As stalls are in theory more expensive than extra
	 * messages, we set the limit to a high value. TODO: experiment
	 * with different values.
	 */
	if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		virtio_transport_send_credit_update(vsk,
						    VIRTIO_VSOCK_TYPE_STREAM,
						    NULL);
	}

	return total;

out:
	if (total)
		err = total;
	return err;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_stream_do_peek(vsk, msg, len);
	else
		return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
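
/* Allocate the per-socket transport state (struct virtio_vsock_sock).
 * A child socket created from a pending connection (@psk) inherits the
 * parent's view of the peer's buffer size, and buffer_size is clamped
 * to VIRTIO_VSOCK_MAX_BUF_SIZE before being advertised as buf_alloc.
 */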
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk && psk->trans) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	}

	if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
		vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = vsk->buffer_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	INIT_LIST_HEAD(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

/* sk_lock held by the caller */
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		*val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = *val;

	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	if (vsock_stream_has_data(vsk))
		*data_ready_now = true;
	else
		*data_ready_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return vsk->buffer_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);
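
/* Abort the connection with an OP_RST. When @pkt is set, we are
 * replying to an incoming packet and the reply flag is raised; an RST
 * is never sent in response to another RST, to avoid the two ends
 * resetting each other forever.
 */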
static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket. There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
					  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt *reply;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	reply = virtio_transport_alloc_pkt(&info, 0,
					   le64_to_cpu(pkt->hdr.dst_cid),
					   le32_to_cpu(pkt->hdr.dst_port),
					   le64_to_cpu(pkt->hdr.src_cid),
					   le32_to_cpu(pkt->hdr.src_port));
	if (!reply)
		return -ENOMEM;

	if (!t) {
		virtio_transport_free_pkt(reply);
		return -ENOTCONN;
	}

	return t->send_pkt(reply);
}

/* This function should be called with sk_lock held and SOCK_DONE set */
static void virtio_transport_remove_sock(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt, *tmp;

	/* We don't need to take rx_lock, as the socket is closing and we are
	 * removing it.
	 */
	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}

	vsock_remove_sock(vsk);
}
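
/* SO_LINGER support: block the closing task until the peer acknowledges
 * the shutdown (SOCK_DONE becomes set), a signal arrives, or the linger
 * timeout expires.
 */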
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		virtio_transport_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}
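
/* Delayed-work handler armed by virtio_transport_close(). If the peer
 * has not acknowledged the shutdown within VSOCK_CLOSE_TIMEOUT, force
 * the close with an RST.
 */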
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);

	if (remove_sock) {
		sock_set_flag(sk, SOCK_DONE);
		virtio_transport_remove_sock(vsk);
	}
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}

	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
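
/* Queue a received OP_RW packet on the socket's rx_queue, charging it
 * against the advertised receive buffer. Packets that would overflow
 * buf_alloc are dropped, which should only happen if the peer sends
 * beyond the credit it was granted.
 */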
static void
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
			      struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool can_enqueue, free_pkt = false;

	pkt->len = le32_to_cpu(pkt->hdr.len);
	pkt->off = 0;

	spin_lock_bh(&vvs->rx_lock);

	can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt);
	if (!can_enqueue) {
		free_pkt = true;
		goto out;
	}

	/* Try to copy small packets into the buffer of last packet queued,
	 * to avoid wasting memory queueing the entire buffer with a small
	 * payload.
	 */
	if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
		struct virtio_vsock_pkt *last_pkt;

		last_pkt = list_last_entry(&vvs->rx_queue,
					   struct virtio_vsock_pkt, list);

		/* If there is space in the last packet queued, we copy the
		 * new packet in its buffer.
		 */
		if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
			memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
			       pkt->len);
			last_pkt->len += pkt->len;
			free_pkt = true;
			goto out;
		}
	}

	list_add_tail(&pkt->list, &vvs->rx_queue);

out:
	spin_unlock_bh(&vvs->rx_lock);
	if (free_pkt)
		virtio_transport_free_pkt(pkt);
}
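
/* Established-state handler. OP_RW packets are queued and the reader
 * woken; note that this case returns early because the packet now
 * belongs to the rx queue and must not be freed here, unlike the other
 * opcodes, which fall through to virtio_transport_free_pkt().
 */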
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		virtio_transport_recv_enqueue(vsk, pkt);
		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0 &&
		    !sock_flag(sk, SOCK_DONE)) {
			(void)virtio_transport_reset(vsk, NULL);

			virtio_transport_do_close(vsk, true);
		}
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
		.remote_port = le32_to_cpu(pkt->hdr.src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* Listener sockets are not associated with any transport, so we are
	 * not able to take the state to see if there is space available in
	 * the remote peer, but since they are only used to receive requests,
	 * we can assume that there is always space available in the other
	 * peer.
	 */
	if (!vvs)
		return true;

	/* buf_alloc and fwd_cnt are always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
			     struct virtio_transport *t)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;
	int ret;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset_no_sock(t, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset_no_sock(t, pkt);
		return -ENOMEM;
	}

	child = vsock_create_connected(sk);
	if (!child) {
		virtio_transport_reset_no_sock(t, pkt);
		return -ENOMEM;
	}

	sk_acceptq_added(sk);

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	ret = vsock_assign_transport(vchild, vsk);
	/* Transport assigned (looking at remote_addr) must be the same
	 * where we received the request.
	 */
	if (ret || vchild->transport != &t->transport) {
		release_sock(child);
		virtio_transport_reset_no_sock(t, pkt);
		sock_put(child);
		return ret;
	}

	if (virtio_transport_space_update(child, pkt))
		child->sk_write_space(child);

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_transport *t,
			       struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(t, pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(t, pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	lock_sock(sk);

	/* Check if sk has been closed before lock_sock */
	if (sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset_no_sock(t, pkt);
		release_sock(sk);
		sock_put(sk);
		goto free_pkt;
	}

	space_available = virtio_transport_space_update(sk, pkt);

	/* Update CID in case it has changed after a transport reset event */
	if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
		vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case TCP_LISTEN:
		virtio_transport_recv_listen(sk, pkt, t);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_SYN_SENT:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case TCP_ESTABLISHED:
		virtio_transport_recv_connected(sk, pkt);
		break;
	case TCP_CLOSING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		(void)virtio_transport_reset_no_sock(t, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	}

	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");