Skip to content

Commit

Permalink
Add no-edges option, useful for at least class 8
Browse files Browse the repository at this point in the history
Maybe some other classes too (but not for 0 and 1)
  • Loading branch information
lopuhin committed Mar 1, 2017
1 parent 91dfe82 commit 1757ea9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
7 changes: 4 additions & 3 deletions make_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def main():
'save masks and polygons as png')
arg('--fix', nargs='+', help='{im_id}_{poly_type} format, e.g 6100_1_1_10')
arg('--force-predict', action='store_true')
arg('--no-edges', action='store_true', help='disable prediction on edges')
args = parser.parse_args()
to_fix = set(args.fix or [])
hps = HyperParams(**json.loads(
Expand Down Expand Up @@ -73,7 +74,7 @@ def main():

if to_predict_masks:
predict_masks(args, hps, store, to_predict_masks, args.threshold,
validation=args.validation)
validation=args.validation, no_edges=args.no_edges)
if args.masks_only:
logger.info('Was building masks only, done.')
return
Expand Down Expand Up @@ -129,7 +130,7 @@ def mask_path(store: Path, im_id: str) -> Path:


def predict_masks(args, hps, store, to_predict: List[str], threshold: float,
validation: str=None):
validation: str=None, no_edges: bool=False):
logger.info('Predicting {} masks: {}'
.format(len(to_predict), ', '.join(sorted(to_predict))))
model = Model(hps=hps)
Expand All @@ -148,7 +149,7 @@ def load_im(im_id):

def predict_mask(im):
logger.info(im.id)
return im, model.predict_image_mask(im.data)
return im, model.predict_image_mask(im.data, no_edges=no_edges)

im_masks = map(predict_mask, utils.imap_fixed_output_buffer(
load_im, sorted(to_predict), threads=2))
Expand Down
29 changes: 17 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,10 @@ def _model_path(self, logdir: Path, n_epoch: int) -> Path:
return logdir.joinpath('model-{}'.format(n_epoch))

def predict_image_mask(self, im_data: np.ndarray,
rotate: bool=False) -> np.ndarray:
rotate: bool=False,
no_edges: bool=False,
average_shifts: bool=False
) -> np.ndarray:
self.net.eval()
c, w, h = im_data.shape
b = self.hps.patch_border
Expand All @@ -545,28 +548,29 @@ def predict_image_mask(self, im_data: np.ndarray,
padded[:, -b:, b:-b] = np.flip(im_data[:, -b:, :], 1)
padded[:, :, :b] = np.flip(padded[:, :, b: 2 * b], 2)
padded[:, :, -b:] = np.flip(padded[:, :, -2 * b: -b], 2)
step = s # TODO: // 3
xs = list(range(0, w - s, step)) + [w - s]
ys = list(range(0, h - s, step)) + [h - s]
step = s // 3 if average_shifts else s
margin = b if no_edges else 0
xs = list(range(margin, w - s - margin, step)) + [w - s - margin]
ys = list(range(margin, h - s - margin, step)) + [h - s - margin]
all_xy = [(x, y) for x in xs for y in ys]
out_shape = [self.hps.n_classes, w, h]
pred_mask = np.zeros(out_shape, dtype=np.float32)
pred_per_pixel = np.zeros(out_shape, dtype=np.int16)
n_rot = 4 if rotate else 1

def gen_batch(xy_batch):
inputs = []
for x, y in xy_batch:
def gen_batch(xy_batch_):
inputs_ = []
for x, y in xy_batch_:
# shifted by -b to account for padding
patch = padded[:, x: x + s + 2 * b, y: y + s + 2 * b]
inputs.append(patch)
inputs_.append(patch)
for i in range(1, n_rot):
inputs.append(utils.rotated(patch, i * 90))
return xy_batch, np.array(inputs, dtype=np.float32)
inputs_.append(utils.rotated(patch, i * 90))
return xy_batch_, np.array(inputs_, dtype=np.float32)

for xy_batch, inputs in utils.imap_fixed_output_buffer(
gen_batch, tqdm.tqdm(list(
utils.chunks(all_xy, self.hps.batch_size // (2 * n_rot)))),
utils.chunks(all_xy, self.hps.batch_size // (4 * n_rot)))),
threads=2):
y_pred = self.net(self._var(torch.from_numpy(inputs)))
for idx, mask in enumerate(y_pred.data.cpu().numpy()):
Expand All @@ -576,7 +580,8 @@ def gen_batch(xy_batch):
mask = utils.rotated(mask, -i * 90)
pred_mask[:, x: x + s, y: y + s] += mask / n_rot
pred_per_pixel[:, x: x + s, y: y + s] += 1
assert pred_per_pixel.min() >= 1
if not no_edges:
assert pred_per_pixel.min() >= 1
pred_mask /= np.maximum(pred_per_pixel, 1)
return pred_mask

Expand Down

0 comments on commit 1757ea9

Please sign in to comment.