summaryrefslogtreecommitdiffstats
path: root/net/vmw_vsock/hyperv_transport.c
blob: b3bdae74c2435e2445a8badd7f6d3d3561be8908 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Hyper-V transport for vsock
 *
 * Hyper-V Sockets supplies a byte-stream based communication mechanism
 * between the host and the VM. This driver implements the necessary
 * support in the VM by introducing the new vsock transport.
 *
 * Copyright (c) 2017, Microsoft Corporation.
 */
#include <linux/module.h>
#include <linux/vmalloc.h>
#include <linux/hyperv.h>
#include <net/sock.h>
#include <net/af_vsock.h>
#include <asm/hyperv-tlfs.h>

/* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some
 * stricter requirements on the hv_sock ring buffer size of six 4K pages.
 * hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K. Newer hosts don't have this
 * limitation; but, keep the defaults the same for compat.
 */
#define RINGBUFFER_HVS_RCV_SIZE (HV_HYP_PAGE_SIZE * 6)
#define RINGBUFFER_HVS_SND_SIZE (HV_HYP_PAGE_SIZE * 6)
#define RINGBUFFER_HVS_MAX_SIZE (HV_HYP_PAGE_SIZE * 64)

/* The MTU is 16KB per the host side's design */
#define HVS_MTU_SIZE		(1024 * 16)

/* How long to wait for graceful shutdown of a connection */
#define HVS_CLOSE_TIMEOUT (8 * HZ)

struct vmpipe_proto_header {
	u32 pkt_type;
	u32 data_size;
};

/* For recv, we use the VMBus in-place packet iterator APIs to directly copy
 * data from the ringbuffer into the userspace buffer.
 */
struct hvs_recv_buf {
	/* The header before the payload data */
	struct vmpipe_proto_header hdr;

	/* The payload */
	u8 data[HVS_MTU_SIZE];
};

/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use
 * a smaller size, i.e. HVS_SEND_BUF_SIZE, to maximize concurrency between the
 * guest and the host processing as one VMBUS packet is the smallest processing
 * unit.
 *
 * Note: the buffer can be eliminated in the future when we add new VMBus
 * ringbuffer APIs that allow us to directly copy data from userspace buffer
 * to VMBus ringbuffer.
 */
#define HVS_SEND_BUF_SIZE \
		(HV_HYP_PAGE_SIZE - sizeof(struct vmpipe_proto_header))

struct hvs_send_buf {
	/* The header before the payload data */
	struct vmpipe_proto_header hdr;

	/* The payload */
	u8 data[HVS_SEND_BUF_SIZE];
};

#define HVS_HEADER_LEN	(sizeof(struct vmpacket_descriptor) + \
			 sizeof(struct vmpipe_proto_header))

/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and
 * __hv_pkt_iter_next().
 */
#define VMBUS_PKT_TRAILER_SIZE	(sizeof(u64))

#define HVS_PKT_LEN(payload_len)	(HVS_HEADER_LEN + \
					 ALIGN((payload_len), 8) + \
					 VMBUS_PKT_TRAILER_SIZE)

union hvs_service_id {
	guid_t	srv_id;

	struct {
		unsigned int svm_port;
		unsigned char b[sizeof(guid_t) - sizeof(unsigned int)];
	};
};

/* Per-socket state (accessed via vsk->trans) */
struct hvsock {
	struct vsock_sock *vsk;

	guid_t vm_srv_id;
	guid_t host_srv_id;

	struct vmbus_channel *chan;
	struct vmpacket_descriptor *recv_desc;

	/* The length of the payload not delivered to userland yet */
	u32 recv_data_len;
	/* The offset of the payload */
	u32 recv_data_off;

	/* Have we sent the zero-length packet (FIN)? */
	bool fin_sent;
};

/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is
 * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here:
 * when we write apps to connect to the host, we can only use VMADDR_CID_ANY
 * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we
 * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY
 * as the local cid.
 *
 * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
 * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-
 * guide/make-integration-service, and the endpoint is <VmID, ServiceId> with
 * the below sockaddr:
 *
 * struct SOCKADDR_HV
 * {
 *    ADDRESS_FAMILY Family;
 *    USHORT Reserved;
 *    GUID VmId;
 *    GUID ServiceId;
 * };
 * Note: VmID is not used by Linux VM and actually it isn't transmitted via
 * VMBus, because here it's obvious the host and the VM can easily identify
 * each other. Though the VmID is useful on the host, especially in the case
 * of Windows container, Linux VM doesn't need it at all.
 *
 * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit
 * the available GUID space of SOCKADDR_HV so that we can create a mapping
 * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing
 * Hyper-V Sockets apps on the host and in Linux VM is:
 *
 ****************************************************************************
 * The only valid Service GUIDs, from the perspectives of both the host and *
 * Linux VM, that can be connected by the other end, must conform to this   *
 * format: <port>-facb-11e6-bd58-64006a7986d3, and the "port" must be in    *
 * this range [0, 0x7FFFFFFF].                                              *
 ****************************************************************************
 *
 * When we write apps on the host to connect(), the GUID ServiceID is used.
 * When we write apps in Linux VM to connect(), we only need to specify the
 * port and the driver will form the GUID and use that to request the host.
 *
 * From the perspective of Linux VM:
 * 1. the local ephemeral port (i.e. the local auto-bound port when we call
 * connect() without explicit bind()) is generated by __vsock_bind_stream(),
 * and the range is [1024, 0xFFFFFFFF).
 * 2. the remote ephemeral port (i.e. the auto-generated remote port for
 * a connect request initiated by the host's connect()) is generated by
 * hvs_remote_addr_init() and the range is [0x80000000, 0xFFFFFFFF).
 */

#define MAX_LISTEN_PORT			((u32)0x7FFFFFFF)
#define MAX_VM_LISTEN_PORT		MAX_LISTEN_PORT
#define MAX_HOST_LISTEN_PORT		MAX_LISTEN_PORT
#define MIN_HOST_EPHEMERAL_PORT		(MAX_HOST_LISTEN_PORT + 1)

/* 00000000-facb-11e6-bd58-64006a7986d3 */
static const guid_t srv_id_template =
	GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
		  0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);

static bool hvs_check_transport(struct vsock_sock *vsk);

static bool is_valid_srv_id(const guid_t *id)
{
	return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
}

static unsigned int get_port_by_srv_id(const guid_t *svr_id)
{
	return *((unsigned int *)svr_id);
}

static void hvs_addr_init(struct sockaddr_vm *addr, const guid_t *svr_id)
{
	unsigned int port = get_port_by_srv_id(svr_id);

	vsock_addr_init(addr, VMADDR_CID_ANY, port);
}

static void hvs_remote_addr_init(struct sockaddr_vm *remote,
				 struct sockaddr_vm *local)
{
	static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
	struct sock *sk;

	/* Remote peer is always the host */
	vsock_addr_init(remote, VMADDR_CID_HOST, VMADDR_PORT_ANY);

	while (1) {
		/* Wrap around ? */
		if (host_ephemeral_port < MIN_HOST_EPHEMERAL_PORT ||
		    host_ephemeral_port == VMADDR_PORT_ANY)
			host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;

		remote->svm_port = host_ephemeral_port++;

		sk = vsock_find_connected_socket(remote, local);
		if (!sk) {
			/* Found an available ephemeral port */
			return;
		}

		/* Release refcnt got in vsock_find_connected_socket */
		sock_put(sk);
	}
}

static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan)
{
	set_channel_pending_send_size(chan,
				      HVS_PKT_LEN(HVS_SEND_BUF_SIZE));

	virt_mb();
}

static bool hvs_channel_readable(struct vmbus_channel *chan)
{
	u32 readable = hv_get_bytes_to_read(&chan->inbound);

	/* 0-size payload means FIN */
	return readable >= HVS_PKT_LEN(0);
}

static int hvs_channel_readable_payload(struct vmbus_channel *chan)
{
	u32 readable = hv_get_bytes_to_read(&chan->inbound);

	if (readable > HVS_PKT_LEN(0)) {
		/* At least we have 1 byte to read. We don't need to return
		 * the exact readable bytes: see vsock_stream_recvmsg() ->
		 * vsock_stream_has_data().
		 */
		return 1;
	}

	if (readable == HVS_PKT_LEN(0)) {
		/* 0-size payload means FIN */
		return 0;
	}

	/* No payload or FIN */
	return -1;
}

static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan)
{
	u32 writeable = hv_get_bytes_to_write(&chan->outbound);
	size_t ret;

	/* The ringbuffer mustn't be 100% full, and we should reserve a
	 * zero-length-payload packet for the FIN: see hv_ringbuffer_write()
	 * and hvs_shutdown().
	 */
	if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0))
		return 0;

	ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0);

	return round_down(ret, 8);
}

static int hvs_send_data(struct vmbus_channel *chan,
			 struct hvs_send_buf *send_buf, size_t to_write)
{
	send_buf->hdr.pkt_type = 1;
	send_buf->hdr.data_size = to_write;
	return vmbus_sendpacket(chan, &send_buf->hdr,
				sizeof(send_buf->hdr) + to_write,
				0, VM_PKT_DATA_INBAND, 0);
}

static void hvs_channel_cb(void *ctx)
{
	struct sock *sk = (struct sock *)ctx;
	struct vsock_sock *vsk = vsock_sk(sk);
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;

	if (hvs_channel_readable(chan))
		sk->sk_data_ready(sk);

	if (hv_get_bytes_to_write(&chan->outbound) > 0)
		sk->sk_write_space(sk);
}

static void hvs_do_close_lock_held(struct vsock_sock *vsk,
				   bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);
	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;
		vsock_remove_sock(vsk);

		/* Release the reference taken while scheduling the timeout */
		sock_put(sk);
	}
}

static void hvs_close_connection(struct vmbus_channel *chan)
{
	struct sock *sk = get_per_channel_state(chan);

	lock_sock(sk);
	hvs_do_close_lock_held(vsock_sk(sk), true);
	release_sock(sk);

	/* Release the refcnt for the channel that's opened in
	 * hvs_open_connection().
	 */
	sock_put(sk);
}

static void hvs_open_connection(struct vmbus_channel *chan)
{
	guid_t *if_instance, *if_type;
	unsigned char conn_from_host;

	struct sockaddr_vm addr;
	struct sock *sk, *new = NULL;
	struct vsock_sock *vnew = NULL;
	struct hvsock *hvs = NULL;
	struct hvsock *hvs_new = NULL;
	int rcvbuf;
	int ret;
	int sndbuf;

	if_type = &chan->offermsg.offer.if_type;
	if_instance = &chan->offermsg.offer.if_instance;
	conn_from_host = chan->offermsg.offer.u.pipe.user_def[0];

	/* The host or the VM should only listen on a port in
	 * [0, MAX_LISTEN_PORT]
	 */
	if (!is_valid_srv_id(if_type) ||
	    get_port_by_srv_id(if_type) > MAX_LISTEN_PORT)
		return;

	hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
	sk = vsock_find_bound_socket(&addr);
	if (!sk)
		return;

	lock_sock(sk);
	if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
	    (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
		goto out;

	if (conn_from_host) {
		if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
			goto out;

		new = vsock_create_connected(sk);
		if (!new)
			goto out;

		new->sk_state = TCP_SYN_SENT;
		vnew = vsock_sk(new);

		hvs_addr_init(&vnew->local_addr, if_type);
		hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);

		ret = vsock_assign_transport(vnew, vsock_sk(sk));
		/* Transport assigned (looking at remote_addr) must be the
		 * same where we received the request.
		 */
		if (ret || !hvs_check_transport(vnew)) {
			sock_put(new);
			goto out;
		}
		hvs_new = vnew->trans;
		hvs_new->chan = chan;
	} else {
		hvs = vsock_sk(sk)->trans;
		hvs->chan = chan;
	}

	set_channel_read_mode(chan, HV_CALL_DIRECT);

	/* Use the socket buffer sizes as hints for the VMBUS ring size. For
	 * server side sockets, 'sk' is the parent socket and thus, this will
	 * allow the child sockets to inherit the size from the parent. Keep
	 * the mins to the default value and align to page size as per VMBUS
	 * requirements.
	 * For the max, the socket core library will limit the socket buffer
	 * size that can be set by the user, but, since currently, the hv_sock
	 * VMBUS ring buffer is physically contiguous allocation, restrict it
	 * further.
	 * Older versions of hv_sock host side code cannot handle bigger VMBUS
	 * ring buffer size. Use the version number to limit the change to newer
	 * versions.
	 */
	if (vmbus_proto_version < VERSION_WIN10_V5) {
		sndbuf = RINGBUFFER_HVS_SND_SIZE;
		rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
	} else {
		sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
		sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
		sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE);
		rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
		rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
		rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE);
	}

	ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb,
			 conn_from_host ? new : sk);
	if (ret != 0) {
		if (conn_from_host) {
			hvs_new->chan = NULL;
			sock_put(new);
		} else {
			hvs->chan = NULL;
		}
		goto out;
	}

	set_per_channel_state(chan, conn_from_host ? new : sk);

	/* This reference will be dropped by hvs_close_connection(). */
	sock_hold(conn_from_host ? new : sk);
	vmbus_set_chn_rescind_callback(chan, hvs_close_connection);

	/* Set the pending send size to max packet size to always get
	 * notifications from the host when there is enough writable space.
	 * The host is optimized to send notifications only when the pending
	 * size boundary is crossed, and not always.
	 */
	hvs_set_channel_pending_send_size(chan);

	if (conn_from_host) {
		new->sk_state = TCP_ESTABLISHED;
		sk_acceptq_added(sk);

		hvs_new->vm_srv_id = *if_type;
		hvs_new->host_srv_id = *if_instance;

		vsock_insert_connected(vnew);

		vsock_enqueue_accept(sk, new);
	} else {
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;

		vsock_insert_connected(vsock_sk(sk));
	}

	sk->sk_state_change(sk);

out:
	/* Release refcnt obtained when we called vsock_find_bound_socket() */
	sock_put(sk);

	release_sock(sk);
}

static u32 hvs_get_local_cid(void)
{
	return VMADDR_CID_ANY;
}

static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk)
{
	struct hvsock *hvs;
	struct sock *sk = sk_vsock(vsk);

	hvs = kzalloc(sizeof(*hvs), GFP_KERNEL);
	if (!hvs)
		return -ENOMEM;

	vsk->trans = hvs;
	hvs->vsk = vsk;
	sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE;
	sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
	return 0;
}

static int hvs_connect(struct vsock_sock *vsk)
{
	union hvs_service_id vm, host;
	struct hvsock *h = vsk->trans;

	vm.srv_id = srv_id_template;
	vm.svm_port = vsk->local_addr.svm_port;
	h->vm_srv_id = vm.srv_id;

	host.srv_id = srv_id_template;
	host.svm_port = vsk->remote_addr.svm_port;
	h->host_srv_id = host.srv_id;

	return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
}

static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode)
{
	struct vmpipe_proto_header hdr;

	if (hvs->fin_sent || !hvs->chan)
		return;

	/* It can't fail: see hvs_channel_writable_bytes(). */
	(void)hvs_send_data(hvs->chan, (struct hvs_send_buf *)&hdr, 0);
	hvs->fin_sent = true;
}

static int hvs_shutdown(struct vsock_sock *vsk, int mode)
{
	struct sock *sk = sk_vsock(vsk);

	if (!(mode & SEND_SHUTDOWN))
		return 0;

	lock_sock(sk);
	hvs_shutdown_lock_held(vsk->trans, mode);
	release_sock(sk);
	return 0;
}

static void hvs_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);
	if (!sock_flag(sk, SOCK_DONE))
		hvs_do_close_lock_held(vsk, false);

	vsk->close_work_scheduled = false;
	release_sock(sk);
	sock_put(sk);
}

/* Returns true, if it is safe to remove socket; false otherwise */
static bool hvs_close_lock_held(struct vsock_sock *vsk)
{
	struct sock *sk = sk_vsock(vsk);

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_DONE))
		return true;

	/* This reference will be dropped by the delayed close routine */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
	return false;
}

static void hvs_release(struct vsock_sock *vsk)
{
	struct sock *sk = sk_vsock(vsk);
	bool remove_sock;

	lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
	remove_sock = hvs_close_lock_held(vsk);
	release_sock(sk);
	if (remove_sock)
		vsock_remove_sock(vsk);
}

static void hvs_destruct(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;

	if (chan)
		vmbus_hvsock_device_unregister(chan);

	kfree(hvs);
}

static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}

static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
			     size_t len, int flags)
{
	return -EOPNOTSUPP;
}

static int hvs_dgram_enqueue(struct vsock_sock *vsk,
			     struct sockaddr_vm *remote, struct msghdr *msg,
			     size_t dgram_len)
{
	return -EOPNOTSUPP;
}

static bool hvs_dgram_allow(u32 cid, u32 port)
{
	return false;
}

static int hvs_update_recv_data(struct hvsock *hvs)
{
	struct hvs_recv_buf *recv_buf;
	u32 payload_len;

	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
	payload_len = recv_buf->hdr.data_size;

	if (payload_len > HVS_MTU_SIZE)
		return -EIO;

	if (payload_len == 0)
		hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;

	hvs->recv_data_len = payload_len;
	hvs->recv_data_off = 0;

	return 0;
}

static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
				  size_t len, int flags)
{
	struct hvsock *hvs = vsk->trans;
	bool need_refill = !hvs->recv_desc;
	struct hvs_recv_buf *recv_buf;
	u32 to_read;
	int ret;

	if (flags & MSG_PEEK)
		return -EOPNOTSUPP;

	if (need_refill) {
		hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
		ret = hvs_update_recv_data(hvs);
		if (ret)
			return ret;
	}

	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
	to_read = min_t(u32, len, hvs->recv_data_len);
	ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read);
	if (ret != 0)
		return ret;

	hvs->recv_data_len -= to_read;
	if (hvs->recv_data_len == 0) {
		hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc);
		if (hvs->recv_desc) {
			ret = hvs_update_recv_data(hvs);
			if (ret)
				return ret;
		}
	} else {
		hvs->recv_data_off += to_read;
	}

	return to_read;
}

static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
				  size_t len)
{
	struct hvsock *hvs = vsk->trans;
	struct vmbus_channel *chan = hvs->chan;
	struct hvs_send_buf *send_buf;
	ssize_t to_write, max_writable;
	ssize_t ret = 0;
	ssize_t bytes_written = 0;

	BUILD_BUG_ON(sizeof(*send_buf) != HV_HYP_PAGE_SIZE);

	send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
	if (!send_buf)
		return -ENOMEM;

	/* Reader(s) could be draining data from the channel as we write.
	 * Maximize bandwidth, by iterating until the channel is found to be
	 * full.
	 */
	while (len) {
		max_writable = hvs_channel_writable_bytes(chan);
		if (!max_writable)
			break;
		to_write = min_t(ssize_t, len, max_writable);
		to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE);
		/* memcpy_from_msg is safe for loop as it advances the offsets
		 * within the message iterator.
		 */
		ret = memcpy_from_msg(send_buf->data, msg, to_write);
		if (ret < 0)
			goto out;

		ret = hvs_send_data(hvs->chan, send_buf, to_write);
		if (ret < 0)
			goto out;

		bytes_written += to_write;
		len -= to_write;
	}
out:
	/* If any data has been sent, return that */
	if (bytes_written)
		ret = bytes_written;
	kfree(send_buf);
	return ret;
}

static s64 hvs_stream_has_data(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;
	s64 ret;

	if (hvs->recv_data_len > 0)
		return 1;

	switch (hvs_channel_readable_payload(hvs->chan)) {
	case 1:
		ret = 1;
		break;
	case 0:
		vsk->peer_shutdown |= SEND_SHUTDOWN;
		ret = 0;
		break;
	default: /* -1 */
		ret = 0;
		break;
	}

	return ret;
}

static s64 hvs_stream_has_space(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;

	return hvs_channel_writable_bytes(hvs->chan);
}

static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return HVS_MTU_SIZE + 1;
}

static bool hvs_stream_is_active(struct vsock_sock *vsk)
{
	struct hvsock *hvs = vsk->trans;

	return hvs->chan != NULL;
}

static bool hvs_stream_allow(u32 cid, u32 port)
{
	/* The host's port range [MIN_HOST_EPHEMERAL_PORT, 0xFFFFFFFF) is
	 * reserved as ephemeral ports, which are used as the host's ports
	 * when the host initiates connections.
	 *
	 * Perform this check in the guest so an immediate error is produced
	 * instead of a timeout.
	 */
	if (port > MAX_HOST_LISTEN_PORT)
		return false;

	if (cid == VMADDR_CID_HOST)
		return true;

	return false;
}

static
int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable)
{
	struct hvsock *hvs = vsk->trans;

	*readable = hvs_channel_readable(hvs->chan);
	return 0;
}

static
int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable)
{
	*writable = hvs_stream_has_space(vsk) > 0;

	return 0;
}

static
int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target,
			 struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static
int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
			      struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static
int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
				struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static
int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
				 ssize_t copied, bool data_read,
				 struct vsock_transport_recv_notify_data *d)
{
	return 0;
}

static
int hvs_notify_send_init(struct vsock_sock *vsk,
			 struct vsock_transport_send_notify_data *d)
{
	return 0;
}

static
int hvs_notify_send_pre_block(struct vsock_sock *vsk,
			      struct vsock_transport_send_notify_data *d)
{
	return 0;
}

static
int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk,
				struct vsock_transport_send_notify_data *d)
{
	return 0;
}

static
int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
				 struct vsock_transport_send_notify_data *d)
{
	return 0;
}

static struct vsock_transport hvs_transport = {
	.module                   = THIS_MODULE,

	.get_local_cid            = hvs_get_local_cid,

	.init                     = hvs_sock_init,
	.destruct                 = hvs_destruct,
	.release                  = hvs_release,
	.connect                  = hvs_connect,
	.shutdown                 = hvs_shutdown,

	.dgram_bind               = hvs_dgram_bind,
	.dgram_dequeue            = hvs_dgram_dequeue,
	.dgram_enqueue            = hvs_dgram_enqueue,
	.dgram_allow              = hvs_dgram_allow,

	.stream_dequeue           = hvs_stream_dequeue,
	.stream_enqueue           = hvs_stream_enqueue,
	.stream_has_data          = hvs_stream_has_data,
	.stream_has_space         = hvs_stream_has_space,
	.stream_rcvhiwat          = hvs_stream_rcvhiwat,
	.stream_is_active         = hvs_stream_is_active,
	.stream_allow             = hvs_stream_allow,

	.notify_poll_in           = hvs_notify_poll_in,
	.notify_poll_out          = hvs_notify_poll_out,
	.notify_recv_init         = hvs_notify_recv_init,
	.notify_recv_pre_block    = hvs_notify_recv_pre_block,
	.notify_recv_pre_dequeue  = hvs_notify_recv_pre_dequeue,
	.notify_recv_post_dequeue = hvs_notify_recv_post_dequeue,
	.notify_send_init         = hvs_notify_send_init,
	.notify_send_pre_block    = hvs_notify_send_pre_block,
	.notify_send_pre_enqueue  = hvs_notify_send_pre_enqueue,
	.notify_send_post_enqueue = hvs_notify_send_post_enqueue,

};

static bool hvs_check_transport(struct vsock_sock *vsk)
{
	return vsk->transport == &hvs_transport;
}

static int hvs_probe(struct hv_device *hdev,
		     const struct hv_vmbus_device_id *dev_id)
{
	struct vmbus_channel *chan = hdev->channel;

	hvs_open_connection(chan);

	/* Always return success to suppress the unnecessary error message
	 * in vmbus_probe(): on error the host will rescind the device in
	 * 30 seconds and we can do cleanup at that time in
	 * vmbus_onoffer_rescind().
	 */
	return 0;
}

static int hvs_remove(struct hv_device *hdev)
{
	struct vmbus_channel *chan = hdev->channel;

	vmbus_close(chan);

	return 0;
}

/* hv_sock connections can not persist across hibernation, and all the hv_sock
 * channels are forced to be rescinded before hibernation: see
 * vmbus_bus_suspend(). Here the dummy hvs_suspend() and hvs_resume()
 * are only needed because hibernation requires that every vmbus device's
 * driver should have a .suspend and .resume callback: see vmbus_suspend().
 */
static int hvs_suspend(struct hv_device *hv_dev)
{
	/* Dummy */
	return 0;
}

static int hvs_resume(struct hv_device *dev)
{
	/* Dummy */
	return 0;
}

/* This isn't really used. See vmbus_match() and vmbus_probe() */
static const struct hv_vmbus_device_id id_table[] = {
	{},
};

static struct hv_driver hvs_drv = {
	.name		= "hv_sock",
	.hvsock		= true,
	.id_table	= id_table,
	.probe		= hvs_probe,
	.remove		= hvs_remove,
	.suspend	= hvs_suspend,
	.resume		= hvs_resume,
};

static int __init hvs_init(void)
{
	int ret;

	if (vmbus_proto_version < VERSION_WIN10)
		return -ENODEV;

	ret = vmbus_driver_register(&hvs_drv);
	if (ret != 0)
		return ret;

	ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H);
	if (ret) {
		vmbus_driver_unregister(&hvs_drv);
		return ret;
	}

	return 0;
}

static void __exit hvs_exit(void)
{
	vsock_core_unregister(&hvs_transport);
	vmbus_driver_unregister(&hvs_drv);
}

module_init(hvs_init);
module_exit(hvs_exit);

MODULE_DESCRIPTION("Hyper-V Sockets");
MODULE_VERSION("1.0.0");
MODULE_LICENSE("GPL");
MODULE_ALIAS_NETPROTO(PF_VSOCK);