# Source code for seismic.inversion.wavefield_decomp.plot_nd_batch

#!/usr/bin/env python
# coding: utf-8
"""Batch plotting of MCMC solutions for a batch of stations to a single PDF file."""

import os
import logging
import json
import copy

import click
from tqdm.auto import tqdm

import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from matplotlib.table import table
import rf

from seismic.inversion.wavefield_decomp.runners import load_mcmc_solution
from seismic.inversion.wavefield_decomp.wfd_plot import plot_Nd
from seismic.stream_io import read_h5_stream
from seismic.receiver_fn.rf_deconvolution import rf_iter_deconv

# Reference figure width in inches. Used for the per-station settings page, and as the divisor
# when deriving the scale factor of the corner plot that is applied to the auxiliary-data figure.
default_fig_width_inches = 6.4
# Time window (seconds, relative to trace onset) to which receiver functions are trimmed for plotting.
RF_TRIM_WINDOW = (-10.0, 25.0)

def convert_Vs_to_k(soln, config):
    """
    Transform Vs variable into k variable in MCMC solution. Modifies soln in-place.

    For layer ``i`` the variable in column ``2*i + 1`` is replaced by ``k = Vp/Vs``. Because
    this transform is decreasing, the lower/upper bounds are swapped, and ``bins`` and
    ``distribution`` are reversed to preserve their ordering. If a layer has no ``k_range``,
    one is derived from its ``Vs_range`` and stored in the layer, so ``config`` is also
    modified in-place.

    :param soln: Solution container
    :type soln: Customized scipy.optimize.OptimizeResult
    :param config: Solution configuration
    :type config: dict
    :return: None
    """
    for i, layer in enumerate(config['layers']):
        Vp = layer['Vp']
        if 'k_range' not in layer:
            assert 'Vs_range' in layer, \
                "Layer {} requires either 'k_range' or 'Vs_range'".format(layer.get('name', i))
            Vs_range = layer['Vs_range']
            # Upper Vs bound maps to lower k bound and vice versa.
            layer['k_range'] = [Vp/Vs_range[1], Vp/Vs_range[0]]
        # end if
        col = 2*i + 1  # column of the Vs (to become k) variable of this layer
        # Scale Vs content in bounds, x, clusters, samples, bins and distribution to Vp/Vs.
        lb = soln.bounds.lb[col]
        soln.bounds.lb[col] = Vp/soln.bounds.ub[col]
        soln.bounds.ub[col] = Vp/lb
        soln.x[:, col] = Vp/soln.x[:, col]
        for cluster in soln.clusters:
            cluster[:, col] = Vp/cluster[:, col]
        # end for
        soln.samples[:, col] = Vp/soln.samples[:, col]
        soln.bins[col] = np.flip(Vp/soln.bins[col])
        soln.distribution[col] = np.flip(soln.distribution[col])
    # end for
# end func


def _compute_rf(data, config, log):
    """
    Compute receiver functions from per-event solution seismograms.

    Trace metadata (onset, channel codes, etc.) is taken from the Z component of the matching
    event in the source waveform file. The data are then re-timed to the solver's time window,
    bandpass filtered (0.05-1.0 Hz, zero phase) and deconvolved with ``rf_iter_deconv``.

    :param data: Seismogram array indexed ``data[event, component, sample]``, one row per entry of
        ``config["event_ids"]``; component 0 is used as R and component 1 as Z.
    :param config: Solution configuration. Reads ``event_ids``, ``waveform_file``, ``station_id``
        (``NET.STA.LOC``) and ``su_energy_opts`` (``time_window``, ``sampling_rate``).
    :type config: dict
    :param log: Logger used to report why an RF could not be generated
    :type log: logging.Logger
    :return: RF stream; empty if event IDs or the source waveform file are unavailable
    :rtype: rf.RFStream
    """
    st = rf.RFStream()
    event_ids = config.get("event_ids")
    src_file = config.get("waveform_file")
    if event_ids is None:
        log.error("Unable to generate RF without event IDs")
        return st
    # end if
    if src_file is None:
        log.error("Unable to generate RF without path to source file")
        return st
    # end if
    if not os.path.isfile(src_file):
        log.error("Source file {} for trace metadata not found, cannot generate RF".format(src_file))
        # Return the empty stream (not None) so callers always receive an RFStream.
        return st
    # end if

    net, sta, loc = config["station_id"].split('.')
    src_waveforms = read_h5_stream(src_file, net, sta, loc)
    assert data.shape[0] == len(event_ids)
    for evid, event_data in zip(event_ids, data):
        src_stream = rf.RFStream([tr for tr in src_waveforms if tr.stats.event_id == evid])
        # Z component: reuse source metadata, re-timed to the solver time window and sample rate.
        z_header = src_stream.select(component='Z')[0].stats
        su_opts = config["su_energy_opts"]
        z_header.starttime = z_header.onset + su_opts["time_window"][0]
        z_header.sampling_rate = su_opts["sampling_rate"]
        z_header.delta = 1.0/z_header.sampling_rate
        z_header.npts = event_data.shape[1]
        assert np.isclose(float(z_header.endtime - z_header.starttime),
                          su_opts["time_window"][1] - su_opts["time_window"][0])
        st += rf.rfstream.RFTrace(event_data[1, :], header=z_header)
        # R component: same metadata, with the last letter of the channel code set to 'R'.
        r_header = z_header.copy()
        r_header.channel = z_header.channel[:-1] + 'R'
        st += rf.rfstream.RFTrace(event_data[0, :], header=r_header)
    # end for

    st.filter('bandpass', freqmin=0.05, freqmax=1.0, corners=2, zerophase=True)
    normalize = 0  # Use Z-component for normalization
    st.rf(rotate=None, method='P', deconvolve='func', func=rf_iter_deconv, normalize=normalize,
          min_fit_threshold=75.0)
    return st
# end func
def _normalized_counts(counts):
    """
    Return histogram counts as floats scaled so that the largest count is 1.

    All-zero (or empty) counts are returned unscaled instead of being divided by zero.

    :param counts: Histogram counts
    :type counts: numpy.ndarray
    :return: Normalized counts
    :rtype: numpy.ndarray
    """
    counts = np.asarray(counts, dtype=float)
    peak = counts.max() if counts.size else 0.0
    return counts/peak if peak > 0 else counts
# end func


def _first_unique_event_ids(ranked, count):
    """
    Return up to `count` distinct event IDs from (energy, event_id) pairs, keeping rank order.

    :param ranked: Sequence of (energy, event_id) pairs already sorted in the desired rank order
    :param count: Maximum number of distinct event IDs to return
    :type count: int
    :return: Distinct event IDs, in order of first appearance in `ranked`
    :rtype: list
    """
    return list(dict.fromkeys(evid for _, evid in ranked))[:count]
# end func


def plot_aux_data(soln, config, log, scale):
    """
    Plot auxiliary diagnostics of a station solution on a new figure.

    Panels:
      * top left: normalized energy histograms of the random samples and of each solution cluster;
      * top right: ranked per-event energy for each solution, with tables of the 3 best and 3
        worst fitting event IDs;
      * bottom: stacked R-component receiver function at the base of the first configured layer
        present in ``soln.subsurface``, one panel per solution.

    :param soln: Solution container
    :type soln: Customized scipy.optimize.OptimizeResult
    :param config: Solution configuration
    :type config: dict
    :param log: Logger passed through to the receiver function computation
    :type log: logging.Logger
    :param scale: Scale factor applied to the figure size and font sizes
    :type scale: float
    :return: The new figure
    :rtype: matplotlib.figure.Figure
    """
    fig_size = default_fig_width_inches*scale
    f = plt.figure(constrained_layout=False, figsize=(fig_size, fig_size))
    f.suptitle(config["station_id"], y=0.96, fontsize=16)
    gs = f.add_gridspec(2, 1, left=0.1, right=0.9, bottom=0.1, top=0.87, hspace=0.3, wspace=0.3,
                        height_ratios=[1, 2])
    gs_top = gs[0].subgridspec(1, 2)
    ax0 = f.add_subplot(gs_top[0, 0])
    ax1 = f.add_subplot(gs_top[0, 1])

    hist_alpha = 0.5
    soln_alpha = 0.3
    axis_font_size = 6*scale
    title_font_size = 6*scale
    nbins = 100

    # Plot energy distribution of samples and solution clusters.
    energy_hist, bins = np.histogram(soln.sample_funvals, bins=nbins)
    ax0.bar(bins[:-1], _normalized_counts(energy_hist), width=np.diff(bins), align='edge',
            color='#808080', alpha=hist_alpha)
    for i, cluster_energies in enumerate(soln.cluster_funvals):
        # Reuse the sample histogram bins so cluster bars line up with the sample bars.
        cluster_hist, _ = np.histogram(cluster_energies, bins)
        ax0.bar(bins[:-1], _normalized_counts(cluster_hist), width=np.diff(bins), align='edge',
                color='C' + str(i), alpha=soln_alpha)
    # end for
    ax0.set_title('Energy distribution of random samples and solution clusters', fontsize=title_font_size)
    ax0.set_xlabel('$E_{SU}$ energy (arb. units)')
    ax0.set_ylabel('Normalized counts')
    ax0.tick_params(labelsize=axis_font_size)
    ax0.xaxis.label.set_size(axis_font_size)
    ax0.yaxis.label.set_size(axis_font_size)

    # Plot sorted per-event upwards S-wave energy at top of mantle per solution.
    # Collect event IDs of best and worst fit traces and present as tables of waveform IDs.
    event_ids = config["event_ids"]
    best_candidates = []
    worst_candidates = []
    for i, esu in enumerate(soln.esu):
        assert len(esu) == len(event_ids)
        esu_sorted = sorted(zip(esu, event_ids))
        best_candidates.extend(esu_sorted[:3])
        worst_candidates.extend(esu_sorted[-3:])
        ax1.plot([energy for energy, _ in esu_sorted], color='C' + str(i), alpha=soln_alpha)
    # end for
    # Keep rank order (best-first / worst-first) in the tables rather than arbitrary set order.
    best_events = _first_unique_event_ids(sorted(best_candidates), 3)
    worst_events = _first_unique_event_ids(sorted(worst_candidates, reverse=True), 3)
    table_opts = dict(cellLoc='left', colWidths=[0.35], edges='horizontal', fontsize=8,
                      alpha=0.6)  # alpha broken in matplotlib.table!
    # Skip empty tables: matplotlib.table.table expects at least one row of cellText.
    if best_events:
        table(ax1, cellText=[[e] for e in best_events], colLabels=['BEST'], loc='upper left',
              **table_opts)
    # end if
    if worst_events:
        table(ax1, cellText=[[e] for e in worst_events], colLabels=['WORST'], loc='upper right',
              **table_opts)
    # end if
    ax1.set_title('Ranked per-event energy for each solution point', fontsize=title_font_size)
    ax1.set_xlabel('Rank (out of # source events)')
    ax1.set_ylabel('Event $E_{SU}$ energy (arb. units)')
    ax1.tick_params(labelsize=axis_font_size)
    ax1.xaxis.label.set_size(axis_font_size)
    ax1.yaxis.label.set_size(axis_font_size)

    # Plot receiver function at base of selected layers.
    max_solutions = config["solver"].get("max_solutions", 3)
    for layer in config["layers"]:
        lname = layer["name"]
        if not (soln.subsurface and lname in soln.subsurface):
            continue
        # end if
        base_seismogms = list(soln.subsurface[lname])
        # Generate RF and plot. The grid always has room for every solution, so indexing
        # gs_bot[i] cannot run past it even if there are more than max_solutions.
        gs_bot = gs[1].subgridspec(max(max_solutions, len(base_seismogms)), 1, hspace=0.4)
        for i, seismogm in enumerate(base_seismogms):
            soln_rf = _compute_rf(seismogm, config, log)
            if soln_rf is None:
                # Treat a failed RF computation as an empty result (plotted as 'Empty RF plot').
                soln_rf = rf.RFStream()
            # end if
            assert isinstance(soln_rf, rf.RFStream)
            # Remove any traces for which deconvolution failed.
            # First, find their unique ID. Then remove all traces with that ID.
            exclude_ids = {tr.stats.event_id for tr in soln_rf if len(tr) == 0}
            soln_rf = rf.RFStream([tr for tr in soln_rf if tr.stats.event_id not in exclude_ids])
            axn = f.add_subplot(gs_bot[i])
            if soln_rf:
                rf_R = soln_rf.select(component='R').trim2(RF_TRIM_WINDOW[0], RF_TRIM_WINDOW[1],
                                                           reftime='onset')
                num_RFs = len(rf_R)
                times = rf_R[0].times() + RF_TRIM_WINDOW[0]
                data = rf_R.stack()[0].data
                axn.plot(times, data, color='C' + str(i), alpha=soln_alpha, linewidth=2)
                axn.text(0.95, 0.95, 'N = {}'.format(num_RFs), fontsize=10, ha='right', va='top',
                         transform=axn.transAxes)
                axn.set_xlabel('Time (sec)')
                axn.grid(color='#80808080', linestyle=':')
            else:
                axn.annotate('Empty RF plot', (0.5, 0.5), xycoords='axes fraction', ha='center')
            # end if
            axn.set_title(' '.join([config["station_id"], lname, 'base RF', '(soln {})'.format(i)]),
                          fontsize=title_font_size, y=0.92, va='top')
            axn.tick_params(labelsize=axis_font_size)
            axn.xaxis.label.set_size(axis_font_size)
            axn.yaxis.label.set_size(axis_font_size)
        # end for
        break  # TODO: Figure out how to add more layers if needed
    # end for

    return f
# end func
@click.command()
@click.argument('solution_file', type=click.Path(exists=True, dir_okay=False), required=True)
@click.option('--output-file', type=click.Path(dir_okay=False), required=True,
              help='Name of the output PDF file in which to save plots')
def main(solution_file, output_file):
    """Plot all the solutions found in a batch run of N-dimensional solver.

    For each non-empty station solution three pages are written: the processing settings, the
    N-dimensional solution plot and the auxiliary data plot. The solution's job ID is appended
    to the output file basename.

    Example usage:

    python seismic/inversion/wavefield_decomp/plot_nd_batch.py --output-file OA_wfd_out.pdf OA_wfd_out.h5

    :param solution_file: Input solution filename
    :param output_file: Output filename
    """
    # Without a configured handler, INFO records from this logger would never be emitted.
    logging.basicConfig()
    log = logging.getLogger(__name__)
    log.setLevel(logging.INFO)

    soln_config, job_id = load_mcmc_solution(solution_file, logger=log)

    out_basename, ext = os.path.splitext(output_file)
    output_file = '{}_{}{}'.format(out_basename, job_id, ext)

    # 'papertype' is a PostScript-only option, so it is not passed for PDF output
    # (recent matplotlib versions reject it).
    save_opts = dict(dpi=300, orientation='portrait')

    with PdfPages(output_file) as pdf:
        for soln, config in tqdm(soln_config):
            if soln.x.shape[0] == 0:
                # Empty solution, nothing to plot.
                continue
            # end if
            assert soln.x.shape[-1] == len(config['layers'])*2
            var_labels = []
            for layer in config['layers']:
                layer_name = layer['name']
                var_labels += ['$H_{{{}}}$'.format(layer_name), '$k_{{{}}}$'.format(layer_name)]
            # end for
            var_labels = tuple(var_labels)

            # Dump settings page (per station), leaving out the list of event IDs.
            config_no_evids = copy.deepcopy(config)
            config_no_evids.pop('event_ids', None)
            config_text = json.dumps(config_no_evids, indent=4)
            f = plt.figure(figsize=(default_fig_width_inches, default_fig_width_inches*1.414))
            settings_ax = plt.gca()
            settings_ax.xaxis.set_visible(False)
            settings_ax.yaxis.set_visible(False)
            settings_ax.set_title(config['station_id'] + ' Processing Settings')
            f.text(0.02, 0.98, 'Settings:\n' + config_text, fontsize=6, va='top', fontname='monospace',
                   transform=settings_ax.transAxes)
            pdf.savefig(f, **save_opts)
            plt.close(f)

            # Convert Vs parameter to k = Vp/Vs
            convert_Vs_to_k(soln, config)

            p, _, _ = plot_Nd(soln, title=config['station_id'], vars=var_labels)
            scale = p.fig.get_size_inches()[0]/default_fig_width_inches
            # Annotate top-right axes (row 0, last column) with number of events used in the solver.
            ndims = len(p.axes)
            ax = p.axes[0, ndims - 1]
            ax.text(0.95, 0.95, 'N = {}'.format(soln.num_input_seismograms), fontsize=10, ha='right',
                    va='top', transform=ax.transAxes)
            pdf.savefig(p.fig, **save_opts)
            plt.close(p.fig)

            aux_fig = plot_aux_data(soln, config, log, scale)
            pdf.savefig(aux_fig, **save_opts)
            plt.close(aux_fig)
        # end for
    # end with
    log.info('Produced file {}'.format(output_file))
# end func


if __name__ == "__main__":
    main()
# end if