Research Software Design by Example
{
"csv": "../../data/survey_tidy",
"pandas": "../../data/survey.db",
"sql": "../../data/survey.db",
"sqlmodel": "../../data/survey.db"
}
The `main` function loads each plugin `plugin_X` as a module and calls the `read_data` function in that module:
def main():
    """Main driver: load each configured plugin, cross-check, and plot.

    The JSON file named by ``--plugins`` maps a plugin stem ``X`` to a
    parameter; module ``plugin_X`` is imported and its ``read_data`` is
    called with that parameter. The per-plugin results are compared with
    ``check`` and then one plugin's tables are passed to ``make_figures``.

    Raises:
        ValueError: if the configuration lists no plugins.
    """
    args = parse_args()
    # JSON text is UTF-8 by specification; don't rely on the locale encoding.
    config = json.loads(Path(args.plugins).read_text(encoding="utf-8"))
    if not config:
        # Without this, tables.popitem() below fails with an opaque KeyError.
        raise ValueError(f"no plugins listed in {args.plugins}")
    tables = {}
    for plugin_stem, plugin_param in config.items():
        module = importlib.import_module(f"plugin_{plugin_stem}")
        tables[plugin_stem] = module.read_data(plugin_param)
    # check() only prints mismatches, so plotting proceeds regardless.
    check(tables)
    # popitem() takes the most recently inserted (last-listed) plugin.
    _, values = tables.popitem()
    make_figures(args, values["combined"], values["centers"])
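The value returned by each plugin's `read_data` is a dictionary of tables (under the keys `"combined"` and `"centers"`). The command line is parsed, and the tables from the different plugins are checked against each other, by these helpers: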
def parse_args():
    """Return the parsed command-line options.

    ``--figdir`` (optional) is the output directory; ``--plugins``
    (required) is the path of the plugin configuration file.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--figdir", {"help": "output dir"}),
        ("--plugins", {"required": True, "help": "config"}),
    )
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)
    return cli.parse_args()
def check(tables):
    """Check all tables against each other.

    The first entry of *tables* is the reference. Every other entry must
    have the same set of sub-table names, and each sub-table must have the
    same ``len`` as the reference's. Mismatches are printed rather than
    raised, so the caller continues either way.

    Args:
        tables: mapping of source name -> mapping of sub-table name ->
            sized object (anything supporting ``len``).
    """
    if not tables:
        return
    # Unpack instead of using None as a "no reference yet" sentinel, which
    # silently skipped a first key that happened to be None.
    ref_key, *other_keys = tables
    reference = tables[ref_key]
    for key in other_keys:
        candidate = tables[key]
        if set(reference.keys()) != set(candidate.keys()):
            print(f"mis-match in provided tables {ref_key} != {key}")
            continue
        for sub_key in reference:
            if len(reference[sub_key]) != len(candidate[sub_key]):
                print(f"mis-match in {sub_key}: {ref_key} != {key}")
def read_data(csvdir):
    """Read CSV files directly into dataframes.

    Every ``*.csv`` file directly inside *csvdir* is loaded with
    ``pd.read_csv`` and the results are combined with
    ``util.combine_with_pandas``.

    Raises:
        ValueError: if *csvdir* contains no CSV files.
    """
    # Path.glob yields entries in filesystem-dependent order; sort so the
    # combined table has the same row order on every machine.
    filenames = sorted(Path(csvdir).glob("*.csv"))
    if not filenames:
        raise ValueError(f"no CSV files found in {csvdir}")
    raw = [pd.read_csv(filename) for filename in filenames]
    return util.combine_with_pandas(*raw)
def combine_with_pandas(*tables):
    """Combine tables using Pandas.

    Concatenates *tables* row-wise and computes centers from the result
    with ``centers_with_pandas``.

    Returns:
        dict with the concatenated dataframe under "combined" and the
        output of ``centers_with_pandas`` under "centers".
    """
    # ignore_index: frames from pd.read_csv each carry their own 0..n-1
    # index, so plain concatenation yields duplicate labels (unlike the
    # frames built by the SQL-based plugins) and makes .loc lookups ambiguous.
    combined = pd.concat(tables, ignore_index=True)
    centers = centers_with_pandas(combined)
    return {"combined": combined, "centers": centers}
# Query selecting every sample together with the site of the survey it
# belongs to (inner join on surveys.label = samples.label, so samples with no
# matching survey are dropped). Result columns: site, lon, lat, reading --
# the same columns the other plugins put in their "combined" table.
Q_SAMPLES = """
select
surveys.site,
samples.lon,
samples.lat,
samples.reading
from surveys join samples
on surveys.label = samples.label
"""
def read_data(dbfile):
    """Read tables and do calculations directly in SQL.

    Opens the SQLite database *dbfile* and runs ``util.Q_SAMPLES`` (one row
    per sample with its site) and ``Q_CENTERS`` (per-site weighted centers).

    Returns:
        dict with dataframes under "combined" and "centers".
    """
    con = sqlite3.connect(dbfile)
    try:
        return {
            "combined": pd.read_sql(util.Q_SAMPLES, con),
            "centers": pd.read_sql(Q_CENTERS, con),
        }
    finally:
        # Always release the connection, even if a query raises; sqlite3's
        # own context manager would only commit/roll back, not close.
        con.close()
# Query computing, for each site, the reading-weighted mean longitude and
# latitude of its samples: sum(x * reading) / sum(reading).
# NOTE(review): if lon/lat/reading are all stored as INTEGER, SQLite performs
# integer division here; and a site whose readings sum to 0 gets NULL, because
# SQLite returns NULL on division by zero. Confirm the column types.
Q_CENTERS = """
select
surveys.site,
sum(samples.lon * samples.reading) / sum(samples.reading) as lon,
sum(samples.lat * samples.reading) / sum(samples.reading) as lat
from surveys join samples
on surveys.label = samples.label
group by surveys.site
"""
class Sites(SQLModel, table=True):
    """Survey sites.

    ``site`` is the primary key; Surveys.site refers to it through the
    foreign key 'sites.site'.
    """

    # Site identifier (primary key).
    site: str | None = Field(default=None, primary_key=True)
    # Site coordinates. NOTE(review): presumably degrees longitude/latitude; confirm.
    lon: float
    lat: float
    # Surveys carried out at this site; paired with Surveys.site_id.
    surveys: list['Surveys'] = Relationship(back_populates='site_id')
class Surveys(SQLModel, table=True):
    """Surveys done.

    Each survey is tied to one site (``site`` is a foreign key to
    sites.site) and has a list of samples (``samples``).
    """

    # Survey identifier (primary key); the SQL queries join samples on
    # surveys.label = samples.label.
    label: int | None = Field(default=None, primary_key=True)
    # NOTE(review): date_type is presumably datetime.date imported under an
    # alias to avoid clashing with this field name; confirm.
    date: date_type
    # Foreign key holding the site identifier (Sites.site).
    site: str | None = Field(default=None, foreign_key='sites.site')
    # Despite its name, this is the related Sites object, not an id;
    # paired with Sites.surveys.
    site_id: Sites | None = Relationship(back_populates='surveys')
    # Samples belonging to this survey; paired with Samples.label_id.
    samples: list['Samples'] = Relationship(back_populates='label_id')
The SQLModel plugin's `read_data` function is:
def read_data(dbfile):
    """Read database and do calculations with SQLModel ORM.

    Loads every Samples row, pairs it with the site of the survey it
    belongs to, and computes centers with ``util.centers_with_pandas``.

    Returns:
        dict with dataframes under "combined" (columns site, lon, lat,
        reading) and "centers".
    """
    url = f"sqlite:///{dbfile}"
    engine = create_engine(url)
    try:
        # NOTE(review): create_all creates any missing tables, so this read
        # path can write to the database file (or create it). Kept for
        # compatibility with the original behavior.
        SQLModel.metadata.create_all(engine)
        with Session(engine) as session:
            # s.label_id is the sample's Surveys object; its .site is the
            # site key. NOTE(review): this relationship access may issue extra
            # queries (lazy loading); consider eager loading for large data.
            rows = [
                (s.label_id.site, s.lon, s.lat, s.reading)
                for s in session.exec(select(Samples))
            ]
    finally:
        # Release pooled connections (and the file handle) now rather than
        # whenever the engine is garbage-collected.
        engine.dispose()
    combined = pd.DataFrame(rows, columns=["site", "lon", "lat", "reading"])
    return {
        "combined": combined,
        "centers": util.centers_with_pandas(combined),
    }