# Source code for seismic.traveltime.mpiops

import logging
import numpy as np
from mpi4py import MPI

# Apply the default logging configuration and create this module's logger.
logging.basicConfig()
log = logging.getLogger(__name__)

# Module-wide MPI state, captured once at import time: the world
# communicator, the number of processes in it, and this process's rank.
# The helper functions below read these names directly.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()


def run_once(f, *args, **kwargs):
    """
    Evaluate a callable on the root process only and share its result.

    Only the process whose ``rank`` is 0 actually calls ``f``; every other
    process passes ``None`` into the broadcast and receives the root's value
    back from it.

    :param f: The callable to evaluate. May accept arbitrary arguments and
        return anything (or nothing).
    :param args: Positional arguments forwarded to ``f`` (optional)
    :param kwargs: Keyword arguments forwarded to ``f`` (optional)
    :return: The value produced by ``f`` on rank 0, as delivered by
        ``comm.bcast`` to the calling process.
    """
    # Non-root ranks contribute a placeholder; bcast overwrites it with the
    # root's value.
    root_value = f(*args, **kwargs) if rank == 0 else None
    return comm.bcast(root_value, root=0)
def array_split(arr, process=None):
    """
    Convenience function for splitting array elements across MPI processes.

    ``arr`` is divided into ``size`` chunks with ``np.array_split`` (so the
    chunks may differ in length when ``len(arr)`` is not a multiple of
    ``size``), and the chunk belonging to the requested process is returned.

    :param ndarray arr: Numpy array
    :param int process: Process for which array members are required.
        If None, the module-level ``rank`` of the calling process is used
        instead. ``0`` is a valid value and selects the first chunk.
        (optional)
    :return: The chunk of ``arr`` assigned to the selected process.
    :rtype: ndarray
    """
    # Compare against None explicitly: a plain truthiness test would treat
    # process=0 as "not supplied" and silently substitute the local rank.
    r = rank if process is None else process
    return np.array_split(arr, size)[r]