Home

Semi-convolutional U-Net

In [ ]:
import tensorflow as tf
In [ ]:
def semi_conv_2d(inputs, **conv_kwargs):
    """Semi-convolutional operator: a 1x1 convolution plus pixel coordinates.

    Computes ``Psi_u(x) = Phi_u(x) + f(u)``, where ``Phi`` is a 1x1 convolution
    of ``inputs`` and ``f(u)`` holds the normalized (row, column) position of
    pixel ``u`` in the first two channels and zeros in all remaining channels.

    Args:
        inputs (Tensor): feature map in channels-last layout,
            batch_size x height x width x channels (the layout assumed by
            ``tf.layers.conv2d``'s default ``data_format``; passing
            ``data_format='channels_first'`` is not supported here).
        **conv_kwargs: forwarded to ``tf.layers.conv2d``; must include
            ``filters``, and the number of filters must be at least 2 so the
            two coordinate channels fit.

    Returns:
        Tensor of shape batch_size x height x width x filters.
    """
    with tf.variable_scope('semi_conv_2d'):
        embeddings = tf.layers.conv2d(inputs, kernel_size=1, **conv_kwargs)
        dtype = embeddings.dtype
        shape = tf.shape(embeddings)
        height, width, channels = shape[1], shape[2], shape[3]

        # Row / column indices scaled to [0, 1) so the offset does not grow
        # with the image size.
        rows = tf.cast(tf.range(height), dtype) / tf.cast(height, dtype)
        cols = tf.cast(tf.range(width), dtype) / tf.cast(width, dtype)
        row_grid, col_grid = tf.meshgrid(rows, cols, indexing='ij')  # height x width each

        # height x width x 2, then zero-padded to height x width x channels so
        # only the first two embedding channels are shifted by the position.
        coords = tf.stack([row_grid, col_grid], axis=-1)
        coords = tf.pad(coords, [[0, 0], [0, 0], [0, channels - 2]])

        # Broadcast the same positional offset over the batch dimension.
        return embeddings + coords[tf.newaxis]
In [ ]:
# Scratch check: the depth of tf.one_hot is a placeholder fed at run time
# (note that index 9 is >= the fed depth of 5). With axis=0 the new one-hot
# axis comes first, and the printed value is the shape of the result.
a = tf.placeholder(shape=[], dtype=tf.int32)
b = tf.one_hot(indices=[[1, 9, 1, 0]], depth=a, axis=0)
with tf.Session() as sess:
    one_hot_value = sess.run(b, feed_dict={a: 5})
    print(one_hot_value.shape)
In [ ]:
import numpy as np

# Scratch check of the broadcasting used to mask embeddings per instance:
# a boolean mask of shape (7, 4, 5, 6) and values of shape (9, 4, 5, 6) become
# (1, 7, 4, 5, 6) * (9, 1, 4, 5, 6) -> (9, 7, 4, 5, 6).
a = np.random.random(size=(7, 4, 5, 6)) > 0.5
b = np.random.random(size=(9, 4, 5, 6))
(np.expand_dims(a, axis=0) * np.expand_dims(b, axis=1)).shape

Loss function

  • The loss is defined for the set of instances $\mathcal{S}$ within an image but could also be extended to the set of instances within a mini-batch

    $$\mathcal{L}(\Psi|\mathbf{x}, \mathcal{S}) = \sum_{S \in \mathcal{S}} \frac{1}{\lvert S \rvert} \sum_{u \in S}\left\lVert\Psi_u(\mathbf{x}) - \frac{1}{\lvert S \rvert}\sum_{v \in S}\Psi_v(\mathbf{x})\right\rVert$$

  • Specifically for each instance $S$, the loss is the mean of the Euclidean distances between the embeddings for each pixel $u \in S$ and the mean embedding over all the pixels in that instance.

  • Notice that it only encourages embeddings within each instance to be similar and does not explicitly encourage embeddings of different instances to be different.

Implementing the loss for a mini-batch

  • Now we will implement the loss for a mini-batch with an arbitrary maximum number of instances in any image in the batch
  • The instances in an image $\mathbf{x}$ are labelled $0,\dots,|\mathcal{S}|-1$ for $|\mathcal{S}|$ instances.
In [ ]:
def semi_conv_loss(y_true, y_pred):
    """
    Implements equation 5 from https://arxiv.org/abs/1807.10712 for a mini-batch of images.

    For every instance S of every image, the loss adds the mean Euclidean
    distance between the embedding of each pixel in S and the mean embedding
    of S. The per-instance terms are summed over the whole batch.

    Args:
        y_true (Tensor): integer label tensor of shape batch_size x height x width,
                         with a separate number for each instance present in the image.
                         Instance ids are non-negative integers starting from 0; an id
                         that does not occur in an image contributes nothing.
        y_pred (Tensor): dense floating-point embedding tensor of shape
                         batch_size x height x width x channels.

    Returns:
        Scalar semi-convolutional loss with the dtype of y_pred.
    """
    y_pred = tf.convert_to_tensor(y_pred)

    # In simple multi-class segmentation, where instances of the same class are
    # not kept separate, pixels could be grouped together across the batch. Here
    # instances must be kept separate per image, because instance i in one image
    # is not necessarily from the same class as instance i in another image.
    # We therefore build a one-hot label map whose depth is the maximum number of
    # instances in any image of the batch. For each instance we then use it to
    # mask out all the embeddings of pixels that do not belong to that instance.

    # Labels run from 0 to max(y_true), so the depth is max + 1; a depth of
    # max(y_true) would silently drop the highest label. Depth must be int32.
    n_inst_max = tf.cast(tf.reduce_max(y_true), tf.int32) + 1

    # batch_size x height x width -> batch_size x n_inst_max x height x width
    y_true_one_hot = tf.one_hot(y_true, depth=n_inst_max, axis=1, dtype=y_pred.dtype)

    # batch_size x n_inst_max x height x width x channels; the embeddings of
    # pixels outside each instance are zeroed
    y_pred_dense = y_true_one_hot[..., tf.newaxis] * y_pred[:, tf.newaxis]

    # Some images may have fewer than n_inst_max instances, so |S| can be 0. To
    # avoid division by zero, the loss is rewritten using
    #   ||u - v/q|| = (1/q) * ||q*u - v||
    # which gives
    #   L = sum_S 1/|S|^2 * sum_{u in S} || |S| * Psi_u(x) - sum_{v in S} Psi_v(x) ||
    # We first sum the distances for each of the batch_size x n_inst_max
    # instance slots. Only then do we select the slots that contain an instance
    # and divide by |S|^2, so no division by zero can occur.

    # number of pixels in each instance: batch_size x n_inst_max
    n_inst_pixels = tf.reduce_sum(y_true_one_hot, axis=[2, 3])

    # sum of the embeddings of each instance: batch_size x n_inst_max x 1 x 1 x channels
    embeds_sum = tf.reduce_sum(y_pred_dense, axis=[2, 3], keepdims=True)

    # |S| * Psi_u - sum_v Psi_v; n_inst_pixels is expanded so that it scales
    # each instance slot, rather than broadcasting against the channel axis
    diff = (y_pred_dense * n_inst_pixels[:, :, tf.newaxis, tf.newaxis, tf.newaxis]
            - embeds_sum)

    # Euclidean norm over channels: batch_size x n_inst_max x height x width.
    # The gradient of sqrt at 0 is NaN. A zero distance occurs at every pixel of
    # an empty instance slot, and multiplying by the zero mask afterwards does
    # not remove the NaN. The double tf.where keeps the values exact and the
    # gradients finite.
    sq_dist = tf.reduce_sum(tf.square(diff), axis=-1)
    is_positive = tf.greater(sq_dist, 0)
    safe_sq_dist = tf.where(is_positive, sq_dist, tf.ones_like(sq_dist))
    dist = tf.where(is_positive, tf.sqrt(safe_sq_dist), tf.zeros_like(sq_dist))

    # keep only the distances for pixels that belong to the instance
    dist_masked = dist * y_true_one_hot

    # sum the distances for each instance slot: batch_size x n_inst_max
    dist_sum = tf.reduce_sum(dist_masked, axis=[2, 3])
    has_inst_mask = tf.greater(n_inst_pixels, 0)

    # select only the slots that correspond to an instance before dividing by |S|^2
    losses = (tf.boolean_mask(dist_sum, has_inst_mask) /
              tf.square(tf.boolean_mask(n_inst_pixels, has_inst_mask)))

    return tf.reduce_sum(losses)

IOU Score

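Implements the evaluation metric described at https://www.kaggle.com/c/airbus-ship-detection#evaluation. The metric is the F2 score averaged over the IoU thresholds 0.5, 0.55, ..., 0.95. At each threshold, a ground-truth instance counts as a hit when its IoU with some predicted instance exceeds that threshold.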

In [ ]:
def iou_score(y_true, y_pred, beta=2, thresholds=None):
    """
    Mean F-beta score over IoU thresholds, following the Airbus Ship Detection
    evaluation metric (https://www.kaggle.com/c/airbus-ship-detection#evaluation).

    At each threshold t, a ground-truth instance is a true positive when some
    predicted instance has IoU > t with it, and a false negative otherwise. A
    predicted instance that matches no ground-truth instance is a false positive.

    Args:
        y_true (Tensor): integer instance labels, batch_size x height x width.
        y_pred (Tensor): integer predicted instance labels, batch_size x height x width.
        beta (float): weight of recall relative to precision; 2 gives the F2 score.
        thresholds (sequence of float, optional): IoU thresholds. Defaults to
            0.5, 0.55, ..., 0.95.

    Returns:
        Scalar float32 tensor: the F-beta score averaged over the thresholds.

    NOTE(review): every label value, including 0, is treated as an instance.
    If 0 marks background in these label maps, it should be excluded — confirm.
    NOTE(review): TP/FN/FP are pooled over the whole batch before the F-score is
    computed, whereas the Kaggle page averages per-image scores — confirm which
    is intended.
    """
    if thresholds is None:
        thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
    thresholds = tf.convert_to_tensor(thresholds, dtype=tf.float32)
    beta_sq = beta ** 2

    # labels run from 0 to max, so each one-hot depth is max + 1 (int32 for tf.one_hot)
    n_inst_max = tf.cast(tf.reduce_max(y_true), tf.int32) + 1
    n_pred_max = tf.cast(tf.reduce_max(y_pred), tf.int32) + 1

    # B x H x W x I and B x H x W x P
    true_one_hot = tf.one_hot(y_true, depth=n_inst_max, axis=-1)
    pred_one_hot = tf.one_hot(y_pred, depth=n_pred_max, axis=-1)

    # pixel count of every slot: B x I and B x P
    true_area = tf.reduce_sum(true_one_hot, axis=[1, 2])
    pred_area = tf.reduce_sum(pred_one_hot, axis=[1, 2])

    # B x H x W x I x 1 times B x H x W x 1 x P, summed over pixels -> B x I x P
    intersection = tf.reduce_sum(
        true_one_hot[..., :, tf.newaxis] * pred_one_hot[..., tf.newaxis, :],
        axis=[1, 2])
    # B x I x P
    union = true_area[:, :, tf.newaxis] + pred_area[:, tf.newaxis, :] - intersection

    # union is a pixel count, so clamping to 1 only affects empty pairs, whose
    # intersection is 0 as well (IoU 0 instead of 0/0)
    iou = intersection / tf.maximum(union, 1.0)

    # B x I x P x T
    match = tf.greater(iou[..., tf.newaxis], thresholds)

    # B x I x T: ground-truth instance matched by any prediction
    inst_matched = tf.cast(tf.reduce_any(match, axis=2), tf.float32)
    # B x P x T: prediction matched by any ground-truth instance
    pred_matched = tf.cast(tf.reduce_any(match, axis=1), tf.float32)

    # Images with fewer labels than the batch maximum have empty slots. Those
    # slots are not instances, so they must not count as FN / FP. Shapes are
    # B x I x 1 and B x P x 1, which broadcast over T.
    true_exists = tf.cast(tf.greater(true_area, 0), tf.float32)[..., tf.newaxis]
    pred_exists = tf.cast(tf.greater(pred_area, 0), tf.float32)[..., tf.newaxis]

    # T
    tp_at_thresh = tf.reduce_sum(inst_matched, axis=[0, 1])
    fn_at_thresh = tf.reduce_sum((1.0 - inst_matched) * true_exists, axis=[0, 1])
    fp_at_thresh = tf.reduce_sum((1.0 - pred_matched) * pred_exists, axis=[0, 1])

    f_numerator = (1 + beta_sq) * tp_at_thresh
    f_denominator = f_numerator + beta_sq * fn_at_thresh + fp_at_thresh
    # A zero denominator means there is nothing to score at this threshold, and
    # the numerator is then also 0; divide by 1 instead of producing NaN.
    safe_denominator = tf.where(tf.greater(f_denominator, 0),
                                f_denominator, tf.ones_like(f_denominator))
    f_score = f_numerator / safe_denominator

    return tf.reduce_mean(f_score, axis=0)