6.4 MiB
6.4 MiB
In [61]:
import os import cv2 import numpy import pickle import skimage from scipy.misc import imread from xmljson import badgerfish as bf from xml.etree.ElementTree import fromstring from matplotlib import pyplot, rcParams %matplotlib inline
In [41]:
def negative_regions(positive): """Estimate some negative rectangular regions, where positive is a bitmap marking diatomenes. First extract 200 to 800 (randomly) consequtive rows from positive. From that extract all column sequences without positive marks. Yield those as a list of 2 slices.""" for _ in range(3): height = numpy.random.randint(200,800) ymin = numpy.random.randint(0, positive.shape[0]-width) ymax = ymin + height has_diatomeen = positive[ymin:ymax].sum(0) == 0 for cont in cv2.findContours( has_diatomeen[:,None].astype(numpy.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE )[1]: if len(cont) != 2: continue xmin, xmax = cont[0][0][1], cont[1][0][1] if xmax - xmin > 200 and xmax - xmin < 800: region = [slice(ymin, ymax), slice(xmin, xmax)] assert positive[region].sum() == 0, "Due to a bug, some positive regions were selected as negative." yield region break
In [42]:
negatives = [] positives = [] for annotation in os.listdir('annotations/'): with open('annotations/' + annotation) as f: data = bf.data(fromstring(f.read()))['annotation'] filename = data['filename']['$'] try: image = imread('images/' + filename) except FileNotFoundError: continue positive = numpy.zeros(image.shape[:2]).astype(numpy.bool) for o in data['object'] if type(data['object'])==list else [data['object']]: region = [ slice(o['bndbox']['ymin']['$'], o['bndbox']['ymax']['$']), slice(o['bndbox']['xmin']['$'], o['bndbox']['xmax']['$']), ] positive[region] = 1 positives.append(image[region]) for region in negative_regions(positive): negatives.append(image[region])
In [45]:
X = negatives + positives y = [False] * len(negatives) + [True] * len(positives)
In [53]:
rcParams['figure.figsize'] = (20,5) axes = (axis for _ in range(10000) for axis in pyplot.subplots(1,4)[1]) for image, class_ in zip(X, y): if numpy.random.rand() < 0.05: axis = next(axes) axis.imshow(image) axis.set_xticks([]) axis.set_yticks([]) h, w, _ = image.shape axis.plot([w, 0, 0, w, w], [0, 0, h, h, 0], 'g' if class_ else 'r', linewidth=8) axis.set_xlim(0, w) axis.set_ylim(0, h) # axis.set_ylabel('negative' if not class_ else 'positive')
/home/herbert/.virtualenvs/medical/lib/python3.5/site-packages/matplotlib/pyplot.py:524: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). max_open_warning, RuntimeWarning)
In [62]:
with open('true_false_data.p3', 'wb') as file: pickle.dump((X,y), file)