#!/usr/bin/env python3
"""Plot training/test data. Channels appear side by side."""

import os
import argparse
import json
import random

import numpy as np
import matplotlib.pyplot as plt

from icenet import util


# Each band is a flat pixel list that is reshaped to this size for plotting.
_BAND_SHAPE = (75, 75)


def _normalize(band, lo, hi):
    """Linearly rescale *band* so that *lo* maps to 0 and *hi* maps to 1.

    A zero-width range (``lo == hi``) would divide by zero and produce
    NaN/inf pixels; in that case an all-zero image is returned instead.
    """
    span = hi - lo
    if span == 0:
        return np.zeros(np.shape(band), dtype=float)
    return (band - lo) / span


def plot_sample(outdir, sample, minmax, interactive=False):
    """Plot the two bands of one sample side by side (HH left, HV right).

    Args:
        outdir: Directory the figure is saved to as ``<id>.png`` when not
            interactive.
        sample: Dict with ``band_1``/``band_2`` pixel lists and ``id``;
            ``is_iceberg`` and ``inc_angle`` are optional (title only).
        minmax: Six-item sequence whose first four items are the band_1
            min/max and band_2 min/max used for normalisation.
        interactive: Show the figure instead of saving it.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, sharex='col', sharey='row')
    fig.suptitle('\n'.join([
        "Id: %s (original)" % sample.get('id'),
        "Iceberg? %s" % sample.get('is_iceberg', '-'),
        "Incident angle: %s" % repr(sample.get('inc_angle'))]))

    b1_min, b1_max, b2_min, b2_max, _, _ = minmax
    b1 = _normalize(np.array(sample['band_1']).reshape(_BAND_SHAPE),
                    b1_min, b1_max)
    b2 = _normalize(np.array(sample['band_2']).reshape(_BAND_SHAPE),
                    b2_min, b2_max)

    # Colours span [0.2, 0.8] of the normalised range; values outside
    # that window saturate, which boosts contrast in the mid-range.
    ax1.imshow(b1, vmin=0.2, vmax=0.8, aspect='auto')
    ax1.set_title("HH")
    ax2.imshow(b2, vmin=0.2, vmax=0.8, aspect='auto')
    ax2.set_title("HV")
    # Leave room at the top for the three-line suptitle.
    fig.subplots_adjust(top=0.80)

    if interactive:
        plt.show()
    else:
        fig.savefig(os.path.join(outdir, "%s.png" % sample['id']))
    plt.close(fig)


def plot_angle_hist(outdir, samples, interactive=False):
    """Plot a 50-bin histogram of the incidence angles in *samples*.

    Samples whose ``inc_angle`` is ``'na'``, ``None`` or missing are
    skipped; the title reports how many samples had a usable angle.
    The figure is saved to ``<outdir>/angle_hist.png`` unless
    *interactive* is true, in which case it is shown instead.
    """
    angles = [s['inc_angle'] for s in samples
              if s.get('inc_angle') not in (None, 'na')]
    hist, bins = np.histogram(angles, bins=50)
    bin_width = bins[1] - bins[0]
    centers = (bins[:-1] + bins[1:]) / 2

    fig, ax = plt.subplots()
    # Bars are narrower than their bins so neighbouring bars stay separate.
    ax.bar(centers, hist, align='center', width=0.7 * bin_width)
    ax.set_title("%d/%d valid angles" % (len(angles), len(samples)))

    if interactive:
        plt.show()
    else:
        fig.savefig(os.path.join(outdir, "angle_hist.png"))
    plt.close(fig)


def main():
    """Parse command-line arguments, then plot samples and the angle histogram."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '-o', help="Output directory for samples."
    )
    parser.add_argument(
        '-n', type=int, help="Number of samples to plot (randomized)"
    )
    parser.add_argument(
        '-i', action='store_true', help="Show plots (and don't store)"
    )
    parser.add_argument(
        'samples_file', help="JSON file with samples"
    )
    args = parser.parse_args()

    outdir = args.o or '.'
    # The output directory only matters when figures are written to disk.
    if not args.i and not os.path.isdir(outdir):
        parser.error("Output directory does not exist")
    # A negative count would make the slice below drop samples from the end.
    if args.n is not None and args.n < 0:
        parser.error("-n must not be negative")

    print("Loading samples ...")
    with open(args.samples_file, encoding='utf-8') as f:
        samples = json.load(f)
    print("%d samples in set" % len(samples))

    if args.n:
        print("Pick %d random samples ..." % args.n)
        random.shuffle(samples)
        samples = samples[:args.n]

    # Normalisation bounds are computed over the (possibly subsampled) set.
    minmax = util.get_minmax(samples)
    for i, s in enumerate(samples):
        print("Plotting sample %d/%d" % (i + 1, len(samples)))
        plot_sample(outdir, s, minmax, interactive=args.i)

    print("Plot angle histogram ..")
    plot_angle_hist(outdir, samples, interactive=args.i)
    print("Done.")


if __name__ == '__main__':
    main()