Skip to content

Commit

Permalink
[tests] added executive test functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bch0w committed May 17, 2022
1 parent cc218e3 commit f16f3e8
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 18 deletions.
16 changes: 11 additions & 5 deletions pyatoa/core/executive.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ class Executive:
quantification.
"""
def __init__(self, event_ids, station_codes, config, max_stations=4,
max_events=1, cat="+", log_level="DEBUG", datasets=None,
figures=None, logs=None, adjsrcs=None, ds_fid_template=None):
max_events=1, cat="+", log_level="DEBUG", cwd=None,
datasets=None, figures=None, logs=None, adjsrcs=None,
ds_fid_template=None):
"""
The Executor needs some key information before it can run processing

Expand All @@ -60,6 +61,8 @@ def __init__(self, event_ids, station_codes, config, max_stations=4,
uncommon as it is given to str.split()
:type log_level: str
:param log_level: log level to be given to all underlying loggers
:type cwd: str
:param cwd: active working directory to look for inputs and save output
:type datasets: str
:param datasets: path to save ASDFDataSets. defaults to a subdirectory
'datasets', inside the current working directory.
Expand All @@ -76,16 +79,16 @@ def __init__(self, event_ids, station_codes, config, max_stations=4,
"""
self.config = config

self.station_codes = sorted(station_codes)
self.event_ids = sorted(event_ids)
self.station_codes = station_codes
self.event_ids = event_ids

self.max_stations = max_stations
self.max_events = max_events

# Define a rudimentary path structure to keep main dir. light.
# These can usually be defaults but if working with Pyaflowa, allow
# them to be overwritten
self.cwd = os.getcwd()
self.cwd = cwd or os.getcwd()
self.datasets = datasets or os.path.join(self.cwd, "datasets")
self.figures = figures or os.path.join(self.cwd, "figures")
self.logs = logs or os.path.join(self.cwd, "logs")
Expand Down Expand Up @@ -126,8 +129,11 @@ def check(self):
# Ensure entries are lists
if not isinstance(self.station_codes, list):
self.station_codes = [self.station_codes]
self.station_codes = sorted(self.station_codes)

if not isinstance(self.event_ids, list):
self.event_ids = [self.event_ids]
self.event_ids = sorted(self.event_ids)

for sta in self.station_codes:
assert(len(sta.split(".")) == 4), (f"station codes must be in "
Expand Down
1 change: 0 additions & 1 deletion pyatoa/core/gatherer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import traceback
import warnings


from pyasdf import ASDFWarning
from obspy.core.event import Event
from obspy.clients.fdsn import Client
Expand Down
63 changes: 51 additions & 12 deletions pyatoa/tests/test_executive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
Test the functionalities of the Executive class which runs many Managers
"""
import os
import pytest
import numpy as np
from pyatoa import Config, Executive, logger


Expand All @@ -15,7 +17,7 @@ def events():
"""
A list of GeoNet event ids used for gathering metadata from GeoNet client
"""
event_ids = ["2018p130600", "2012p242656", "2017p015402", "2228901"]
event_ids = ["2018p130600", "2012p242656"]
return event_ids


Expand All @@ -38,32 +40,69 @@ def config(events):
A preset Config object that specifies where to grab data from, which
already exists in the test data directory
"""
syn_path = "./test_data/test_executive/{}"
synthetics = [syn_path.format(_) for _ in events]
# syn_path = "./test_data/test_executive/{}"
# synthetics = [os.path.abspath(syn_path.format(_)) for _ in events]

syn_path = [os.path.abspath("./test_data/test_executive/")]

cfg = Config(iteration=1, step_count=0, min_period=10, max_peropd=30,
client="GEONET", pyflex_preset="default",
adj_src_type="cc_traveltime_misfit",
paths={"synthetics": synthetics})
paths={"synthetics": syn_path})
return cfg


def test_single_event_single_station_no_concurrent(config, events, stations):
def test_executive_single_event_single_station_no_concurrent(tmpdir, config,
events, stations):
"""
Attempt a single event single stationprocessing
Attempt a single event single station processing without using concurrency
"""
exc = Executive(event_ids=events[0], station_codes=stations[0],
config=config)
misfit = exc.process_station(f"{events[0]}-{stations[0]}")
config=config, cwd=tmpdir.strpath)
misfit = exc.process_station(f"{events[0]}{exc.cat}{stations[0]}")
assert(pytest.approx(misfit, .001) == 1.6696)


def test_executive_single_event_single_station(config, events, stations):
def test_executive_single_event_single_station(tmpdir, config, events,
stations):
"""
Attempt a single event single stationprocessing
Attempt a single event single station processing with concurrency
"""
exc = Executive(event_ids=events[0], station_codes=stations[0],
config=config)
config=config, cwd=tmpdir.strpath)
misfits = exc.process()
misfit = misfits[events[0]][stations[0]]
assert(pytest.approx(misfit, .001) == 1.6696)


def test_executive_single_event_multi_station(tmpdir, config, events,
stations):
"""
Attempt a single event multi station processing with concurrency
"""
exc = Executive(event_ids=events[0], station_codes=stations,
config=config, cwd=tmpdir.strpath, max_events=1,
max_stations=os.cpu_count())
misfits = exc.process()
assert(len(misfits) == 1)
assert(len(misfits[events[0]]) == len(stations))
misfit = misfits[events[0]][stations[4]]
assert(pytest.approx(misfit, .001) == 0.76983)


pytest.set_trace()
# def test_executive_multi_event_multi_station(tmpdir, config, events,
# stations):
# """
# !!! This test is causing my computer to crash, not sure why, must rework
#
# Attempt a single event multi station processing with concurrency.
# Only do 2 events and 2 stations max to avoid crashing out system
# """
# exc = Executive(event_ids=events, station_codes=stations,
# config=config, cwd=tmpdir.strpath, max_events=2,
# max_stations=2)
# misfits = exc.process()
# assert(len(misfits) == 2)
# assert(len(misfits[events[1]]) == len(stations))
# misfit = misfits[events[1]][stations[2]]
# assert(pytest.approx(misfit, .001) == 12.242)

0 comments on commit f16f3e8

Please sign in to comment.