To my knowledge, TensorFlow does not support max pooling with a dynamic kernel size. Here I have written a function that performs region-of-interest pooling (RoI pooling), which is used in object detection architectures such as Fast R-CNN. This was more an exercise in writing the function without a tf.while_loop than an implementation that would necessarily be useful in practice.
It is still a work in progress and I have not actually used it in a detection model yet! It works as follows:
import numpy as np
import tensorflow as tf
def roi_pooling(x, bboxes, poolsize, pool_type='max', normed=True):
    """Pool every region of interest of a feature map onto a P x P grid.

    The pooling is expressed purely with broadcasting (no ``tf.while_loop``):
    a membership mask of shape N x R x H x W x P x P x 1 is built that tells,
    for each RoI, which feature-map cell falls into which output bin, and the
    features are then reduced over the H and W axes.

    Args:
        x: feature map, batch_size(N) x H x W x C.
        bboxes: batch_size(N) x n_rois(R) x 4 boxes laid out as
            (start_0, start_1, end_0, end_1), where index 0 runs along the H
            axis and index 1 along the W axis. With ``normed=True`` the
            coordinates are fractions of H and W in [0, 1]; with
            ``normed=False`` they are cell indices.
        poolsize: number of output bins per side (P). May be a Python number
            or a scalar tensor.
        pool_type: 'max' or 'avg'.
        normed: whether ``bboxes`` are normalised by the feature-map size.

    Returns:
        (pool_mask, pool): the float mask N x R x H x W x P x P x 1 and the
        pooled features N x R x P x P x C.

    Notes:
        * Both inputs must carry the batch axis; a single image has to be
          passed with N = 1 (and R = 1 for a single box). No axes are added
          here.
        * A zero-padded box (all four coordinates 0) selects no cell, so its
          pooled output is all zeros.
        * Max pooling multiplies by the mask, so cells outside a bin count as
          0; a bin whose activations are all negative therefore pools to 0.
          NOTE(review): harmless for non-negative (e.g. post-ReLU) features,
          confirm this is acceptable for other inputs.
    """
    assert pool_type in ('max', 'avg'), "pool_type must be 'max' or 'avg'"

    # [H, W] of the feature map
    shape = tf.shape(x)[1:3]
    # H x W x 2: (index along H, index along W) of every cell
    grid = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]), indexing='ij'), axis=-1)
    grid = tf.to_float(grid)
    if normed:
        # Bring the grid into the same [0, 1) units as the boxes. When the
        # boxes are already cell indices the integer grid is kept as is, which
        # avoids a divide-then-multiply round trip that can move coordinates
        # off their exact integer values.
        grid = grid/tf.to_float(shape)

    # N x R x 1 x 4
    bboxes = bboxes[:, :, None]
    # N x R x 1 x 2: extent of one bin along each axis
    step = (bboxes[..., 2:] - bboxes[..., :2])/poolsize
    # N x R x (P + 1) x 2: the P + 1 bin edges along each axis
    ranges = bboxes[..., :2] + tf.range(poolsize+1, dtype=tf.float32)[None, None, :, None]*step
    ranges = tf.to_float(ranges)

    # grid[None,None,...,None,:]  -> 1 x 1 x H x W x 1 x 2
    # ranges[:,:,None,None,:-1]   -> N x R x 1 x 1 x P x 2
    # result                      -> N x R x H x W x P x 2
    # A cell belongs to a bin when lower edge <= coordinate < upper edge.
    cells = grid[None, None, ..., None, :]
    in_range_xy = tf.logical_and(cells >= ranges[:, :, None, None, :-1],
                                 cells < ranges[:, :, None, None, 1:])

    # in_range_xy[...,None,0]   -> N x R x H x W x P x 1 (bin along H)
    # in_range_xy[...,None,:,1] -> N x R x H x W x 1 x P (bin along W)
    # combined                  -> N x R x H x W x P x P
    pool_mask = tf.logical_and(in_range_xy[..., None, 0], in_range_xy[..., None, :, 1])
    # N x R x H x W x P x P x 1
    pool_mask = tf.to_float(pool_mask[..., None])

    # x[:,None,...,None,None,:] -> N x 1 x H x W x 1 x 1 x C
    # masked                    -> N x R x H x W x P x P x C
    masked = x[:, None, ..., None, None, :]*pool_mask
    if pool_type == 'max':
        # N x R x P x P x C
        pool = tf.reduce_max(masked, axis=(2, 3))
    else:
        # Average over the cells that belong to the bin, not over the whole
        # H x W map. Empty bins are divided by 1 so they stay at zero.
        cell_count = tf.reduce_sum(pool_mask, axis=(2, 3))
        pool = tf.reduce_sum(masked, axis=(2, 3))/tf.maximum(cell_count, 1.0)
    return pool_mask, pool
def crop_im(im, bbox):
    """Return the part of `im` inside `bbox`.

    `bbox` is (x1, y1, x2, y2); x1:x2 slices the first axis of `im` and
    y1:y2 the second one. The end indices are exclusive.
    """
    first_lo, second_lo, first_hi, second_hi = bbox
    first_axis = slice(first_lo, first_hi)
    second_axis = slice(second_lo, second_hi)
    return im[first_axis, second_axis]
# --- Demo inputs: a batch of two random single-channel 50 x 70 images -------
h = 50
w = 70
# number of pooling bins per side, fed to the pool_size placeholder below
n = 4
shape = np.array([h,w])
# Boxes in pixel coordinates, laid out as (x1, y1, x2, y2) the way crop_im
# unpacks them: x indexes the first (height) axis, y the second (width) axis.
bb1 = [10,10,18,18]
bb2 = [21,24,34,35]
bb3 = [1,2, 9, 9]
bb4 = [1,2, 8, 8]
bb5 = [9, 12, 17, 19]
# all-zero box: padding so that both images carry three boxes
bb6 = [0, 0, 0, 0]
# 2 x 3 x 4 array of boxes, normalised by dividing by [h, w, h, w]
bb = np.array([[bb1, bb2, bb3],
[ bb4, bb5, bb6]])/np.tile(shape[None,None], [1,1,2])
im = np.random.random(size=(2,h,w,1))
# --- Graph: placeholders with fully dynamic image size and RoI count --------
img = tf.placeholder(shape=[None,None,None,None], dtype=tf.float32)
bbox = tf.placeholder(shape=[None,None,4], dtype=tf.float32)
# scalar float placeholder for the number of bins per side
pool_size = tf.placeholder(shape=[], dtype=tf.float32)
pool_mask, pool = roi_pooling(img, bbox, pool_size)
# Evaluate the bin-membership mask and the pooled features for the demo batch.
with tf.Session() as sess:
    pmask, p = sess.run(
        fetches=[pool_mask, pool],
        feed_dict={img: im, bbox: bb, pool_size: n},
    )
# The bare expressions below only display a value in an interactive/notebook
# session; they have no effect when run as a plain script.
p.shape
# Crop of image 1, channel 0, under bb4 (pixel coordinates), for comparison
# with the pooled values of that box.
crop = crop_im(im[1,...,0], bb4)
crop.round(2)
# pooled values for image 1, RoI 0, channel 0
p[1,0,...,0].round(2)
p.shape
# Mask of image 1 with the trailing singleton axis dropped.
mask2 = pmask[1,...,0]
# Flatten the two bin axes into n**2 and take the argmax: for each of the
# 3 RoIs this gives, per pixel, the index of the bin it falls into (pixels
# outside the RoI have an all-zero row and therefore also map to 0).
mask2_dense = np.argmax(np.reshape(mask2, (3,h,w,n**2)), axis=-1)
# Sum the per-RoI bin-index maps into one image.
mask2_all = np.sum(mask2_dense, axis=0)
Note that bb6, the all-zero padding box, does not give rise to a non-zero mask.
import matplotlib.pyplot as plt
%matplotlib inline
# Show the summed bin-index map of image 1 (all RoIs in one picture).
plt.imshow(mask2_all)
# the trailing semicolon suppresses the return value echo in a notebook cell
plt.axis('off');
def plot_grid(values, points, color_values=None, offx=0, offy=0, th=0, maxgrid=None, maxonly=False):
    """Write each value as text at its point on the current matplotlib axes.

    values:       numbers to print (rounded to one decimal).
    points:       one (x, y) position per value; offx / offy are added to it.
    color_values: numbers compared against `th` to choose the text colour
                  ('white' below `th`, 'k' otherwise); defaults to `values`.
    maxgrid:      optional per-item reference; an item whose value is close
                  (np.isclose) to its reference is printed in red.
    maxonly:      when True, only the red items are printed.
    """
    if color_values is None:
        color_values = values

    for idx, (shade, val, pt) in enumerate(zip(color_values, values, points)):
        matches_reference = maxgrid is not None and np.isclose(val, maxgrid[idx])
        if matches_reference:
            text_colour = 'red'
        elif maxonly:
            # not a highlighted item and only highlights were requested
            continue
        else:
            text_colour = 'white' if shade < th else 'k'
        plt.text(pt[0] + offx, pt[1] + offy, np.round(val, 1), color=text_colour)
def points_from_bbox(bbox):
    """List the cell positions of a box as an array of shape (cells, 2).

    For bbox = (x1, y1, x2, y2) the box spans x2 - x1 cells along the first
    axis and y2 - y1 along the second. Each output row is
    (offset along the second axis, offset along the first axis), and the rows
    are ordered with the second-axis offset varying fastest.
    """
    x1, y1, x2, y2 = bbox
    first_offsets = np.arange(x2 - x1)
    second_offsets = np.arange(y2 - y1)
    first_grid, second_grid = np.meshgrid(first_offsets, second_offsets, indexing='ij')
    pairs = np.stack([second_grid, first_grid], axis=-1)
    return pairs.reshape(-1, 2)
# --- Five-panel figure illustrating the pooling of bb4 in image 1 -----------
# Per-pixel bin index (argmax over the flattened bin axes) of image 1, RoI 0,
# cropped to the box itself.
crop_mask = crop_im(pmask[1,0,...,0].reshape(h,w,-1).argmax(-1), bb4)
# Pooled value of the bin each pixel belongs to (flattened pooled grid indexed
# by the per-pixel bin index).
maxes = p[1,0,...,0].reshape(-1)[crop_mask]
# text position of every pixel of the crop
points = points_from_bbox(bb4)
# NOTE(review): the next lines assume the bin indices present in crop_mask
# are exactly 0 .. n_blocks-1 -- confirm for boxes that leave bins empty.
n_blocks = len(np.unique(crop_mask))
# mean position of the pixels of each bin ...
mids = np.array([points[crop_mask.ravel()==i].mean(axis=0) for i in range(n_blocks) ])
# ... broadcast back to one position per pixel
mids = mids[crop_mask].reshape([-1,2])
plt.figure(figsize=(15,3))
# Panel 1: the raw crop with its values; values equal to their bin's pooled
# value are printed in red.
plt.subplot(1,5,1)
plt.imshow(crop)
plt.axis('off')
plot_grid(crop.reshape(-1), points, offx=-0.2, offy=0, th=0.6, maxgrid=maxes.ravel())
# Panel 2: the bin index of every pixel.
plt.subplot(1,5,2)
plt.imshow(crop_mask)
plt.axis('off')
plot_grid(crop_mask.reshape(-1), points, offx=-0.2, offy=0, th=10)
# Panel 3: crop values drawn over the bin map (text colour chosen from the
# bin index), again with the per-bin matches in red.
plt.subplot(1,5,3)
plt.imshow(crop_mask)
plt.axis('off')
plot_grid(crop.reshape(-1), points, crop_mask.reshape(-1), offx=-0.2, offy=0, th=10,
maxgrid=maxes.ravel())
# Panel 4: only the red (matching) values, placed at the centre of their bin.
plt.subplot(1,5,4)
# same expression as at the top of this section; recomputed here
crop_mask = crop_im(pmask[1,0,...,0].reshape(h,w,-1).argmax(-1), bb4)
plt.imshow(crop_mask)
plt.axis('off')
plot_grid(crop.reshape(-1), mids, offx=-0.2, offy=0, th=0.6, maxgrid=maxes.ravel(), maxonly=True)
# Panel 5: the pooled output grid itself. The 16 / (4,4) literals match n = 4.
plt.subplot(1,5,5)
plt.imshow(np.arange(16).reshape((4,4)))
pool_points = points_from_bbox([0,0,n,n])
plt.axis('off')
plot_grid(p[1,0,...,0].ravel(), pool_points, np.arange(16), offx=-0.2, offy=0, th=10)
def equiv_kernel(bbox, n):
    """Return the (first-axis, second-axis) kernel size that tiles `bbox` n times.

    Each side length of the (x1, y1, x2, y2) box is floor-divided by `n`.
    """
    x1, y1, x2, y2 = bbox
    return tuple(side // n for side in (x2 - x1, y2 - y1))
# --- Sanity check against a regular max pool --------------------------------
# Two 12 x 8 boxes (pixel coordinates); both sides are divisible by n = 4.
bb_even1 = [3, 2, 15, 10]
bb_even2 = [25, 20, 37, 28]
# Boxes are kept in pixel coordinates here; the normalisation is disabled.
bb_even = np.array([bb_even1, bb_even2])#/np.tile(shape, 2)[None]
# RoI pooling does not always split the grid into equal-sized blocks, even
# when the box size is divisible by the number of bins. Using non-normalised
# coordinates seems to make it work. In practice boxes are unlikely to be
# divisible into equal-sized blocks anyway, and the differences in how
# detections get split will hopefully even out.
pool_mask_even, pool_even = roi_pooling(img, bbox, pool_size, normed=False)
# Reference: a regular max pool whose kernel and stride equal the bin size of
# bb_even1 (bb_even2 has the same size).
ker = equiv_kernel(bb_even1, n)
maxpool = tf.layers.max_pooling2d(img, pool_size=ker, strides=ker, padding='valid')
with tf.Session() as sess:
    # RoI pooling of both boxes on the first image only
    pmask_even, p_even = sess.run(
        fetches=[pool_mask_even, pool_even],
        feed_dict={img: im[:1], bbox: bb_even[None], pool_size: n},
    )
    # Reference max pool, run on each cropped box in turn (batch and channel
    # axes are added back around the 2-D crop).
    mpool1, mpool2 = (
        sess.run(maxpool, feed_dict={img: crop_im(im[0, ..., 0], roi)[None, ..., None]})
        for roi in (bb_even1, bb_even2)
    )
# Stack the two reference results so they line up with the RoI axis of p_even.
mpool = np.concatenate([mpool1, mpool2], axis=0)
# This check can come out False: roi_pooling does not always divide the box
# into equal-sized blocks even when its size is divisible by the number of
# bins (possibly a floating-point effect).
np.allclose(mpool, p_even[0])
# Same bin-index visualisation as before, now for the two even boxes.
mask2 = pmask_even[0,...,0]
mask2_dense = np.argmax(np.reshape(mask2, (2,h,w,n**2)), axis=-1)
mask2_all = np.sum(mask2_dense, axis=0)
plt.imshow(mask2_all)
plt.axis('off');
# Per-pixel bin indices inside each box (displayed only in a notebook session).
crop_im(pmask_even[0,1,...,0].reshape(h,w,-1).argmax(-1), bb_even2)
crop_im(pmask_even[0,0,...,0].reshape(h,w,-1).argmax(-1), bb_even1)