Staging: hv: vmbus: Properly deal with de-registering channel callback
authorK. Y. Srinivasan <kys@microsoft.com>
Sat, 27 Aug 2011 18:31:33 +0000 (11:31 -0700)
committerGreg Kroah-Hartman <gregkh@suse.de>
Mon, 29 Aug 2011 18:05:30 +0000 (11:05 -0700)
Ensure that we correctly handle racing invocations of the channel callback
when the channel is being closed. We do this using the channel's inbound_lock.
A side-effect of this strategy is that we avoid repeatedly picking up this lock
as we drain the inbound ring-buffer.

Signed-off-by: K. Y. Srinivasan <kys@microsoft.com>
Signed-off-by: Haiyang Zhang <haiyangz@microsoft.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@suse.de>
drivers/staging/hv/channel.c
drivers/staging/hv/connection.c
drivers/staging/hv/netvsc.c
drivers/staging/hv/storvsc_drv.c

index ac92c1f99261e7d1bdaf305478c9961c95d05af7..b6f3d38a6dbb5c8fbf6b424d185dd84e9d01328d 100644 (file)
@@ -513,9 +513,12 @@ void vmbus_close(struct vmbus_channel *channel)
 {
        struct vmbus_channel_close_channel *msg;
        int ret;
+       unsigned long flags;
 
        /* Stop callback and cancel the timer asap */
+       spin_lock_irqsave(&channel->inbound_lock, flags);
        channel->onchannel_callback = NULL;
+       spin_unlock_irqrestore(&channel->inbound_lock, flags);
 
        /* Send a closing message */
 
@@ -735,19 +738,15 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
        u32 packetlen;
        u32 userlen;
        int ret;
-       unsigned long flags;
 
        *buffer_actual_len = 0;
        *requestid = 0;
 
-       spin_lock_irqsave(&channel->inbound_lock, flags);
 
        ret = hv_ringbuffer_peek(&channel->inbound, &desc,
                             sizeof(struct vmpacket_descriptor));
-       if (ret != 0) {
-               spin_unlock_irqrestore(&channel->inbound_lock, flags);
+       if (ret != 0)
                return 0;
-       }
 
        packetlen = desc.len8 << 3;
        userlen = packetlen - (desc.offset8 << 3);
@@ -755,7 +754,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
        *buffer_actual_len = userlen;
 
        if (userlen > bufferlen) {
-               spin_unlock_irqrestore(&channel->inbound_lock, flags);
 
                pr_err("Buffer too small - got %d needs %d\n",
                           bufferlen, userlen);
@@ -768,7 +766,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
        ret = hv_ringbuffer_read(&channel->inbound, buffer, userlen,
                             (desc.offset8 << 3));
 
-       spin_unlock_irqrestore(&channel->inbound_lock, flags);
 
        return 0;
 }
@@ -785,19 +782,15 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
        u32 packetlen;
        u32 userlen;
        int ret;
-       unsigned long flags;
 
        *buffer_actual_len = 0;
        *requestid = 0;
 
-       spin_lock_irqsave(&channel->inbound_lock, flags);
 
        ret = hv_ringbuffer_peek(&channel->inbound, &desc,
                             sizeof(struct vmpacket_descriptor));
-       if (ret != 0) {
-               spin_unlock_irqrestore(&channel->inbound_lock, flags);
+       if (ret != 0)
                return 0;
-       }
 
 
        packetlen = desc.len8 << 3;
@@ -806,8 +799,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
        *buffer_actual_len = packetlen;
 
        if (packetlen > bufferlen) {
-               spin_unlock_irqrestore(&channel->inbound_lock, flags);
-
                pr_err("Buffer too small - needed %d bytes but "
                        "got space for only %d bytes\n",
                        packetlen, bufferlen);
@@ -819,7 +810,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
        /* Copy over the entire packet to the user buffer */
        ret = hv_ringbuffer_read(&channel->inbound, buffer, packetlen, 0);
 
-       spin_unlock_irqrestore(&channel->inbound_lock, flags);
        return 0;
 }
 EXPORT_SYMBOL_GPL(vmbus_recvpacket_raw);
index 7a3ec75bab18e3715eb6f30ac3878f847fc0fe63..6aab802ba99afd81753c4ff2f97555bd8227f207 100644 (file)
@@ -215,6 +215,7 @@ struct vmbus_channel *relid2channel(u32 relid)
 static void process_chn_event(u32 relid)
 {
        struct vmbus_channel *channel;
+       unsigned long flags;
 
        /*
         * Find the channel based on this relid and invokes the
@@ -222,11 +223,13 @@ static void process_chn_event(u32 relid)
         */
        channel = relid2channel(relid);
 
+       spin_lock_irqsave(&channel->inbound_lock, flags);
        if (channel && (channel->onchannel_callback != NULL)) {
                channel->onchannel_callback(channel->channel_callback_context);
        } else {
                pr_err("channel not found for relid - %u\n", relid);
        }
+       spin_unlock_irqrestore(&channel->inbound_lock, flags);
 }
 
 /*
index 9828f0b5960e4f5d806d394e79f4443698f7a7fb..e4cc40a3bd8659150970dea5afd3e46dee5e0288 100644 (file)
@@ -62,9 +62,7 @@ static struct netvsc_device *get_outbound_net_device(struct hv_device *device)
 static struct netvsc_device *get_inbound_net_device(struct hv_device *device)
 {
        struct netvsc_device *net_device;
-       unsigned long flags;
 
-       spin_lock_irqsave(&device->channel->inbound_lock, flags);
        net_device = device->ext;
 
        if (!net_device)
@@ -75,7 +73,6 @@ static struct netvsc_device *get_inbound_net_device(struct hv_device *device)
                net_device = NULL;
 
 get_in_err:
-       spin_unlock_irqrestore(&device->channel->inbound_lock, flags);
        return net_device;
 }
 
index d575bc92ffc4606caa494d2babc434e0eef73b12..3686d1048e313af18b6b8b40358b2d6621baf383 100644 (file)
@@ -352,9 +352,7 @@ static inline struct storvsc_device *get_in_stor_device(
                                        struct hv_device *device)
 {
        struct storvsc_device *stor_device;
-       unsigned long flags;
 
-       spin_lock_irqsave(&device->channel->inbound_lock, flags);
        stor_device = (struct storvsc_device *)device->ext;
 
        if (!stor_device)
@@ -370,7 +368,6 @@ static inline struct storvsc_device *get_in_stor_device(
                stor_device = NULL;
 
 get_in_err:
-       spin_unlock_irqrestore(&device->channel->inbound_lock, flags);
        return stor_device;
 
 }