251 changes: 185 additions & 66 deletions transfers/waterlevels_transducer_transfer.py
@@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
import csv
from collections import defaultdict
from io import StringIO
from typing import Any

import pandas as pd
from pandas import Timestamp
from pydantic import ValidationError
from sqlalchemy import insert
from sqlalchemy.exc import DatabaseError
from sqlalchemy.orm import Session

from db import Thing, Deployment, Sensor
from db.transducer import TransducerObservation, TransducerObservationBlock
from schemas.transducer import CreateTransducerObservation
from transfers.logger import logger
from transfers.transferer import Transferer
from transfers.util import (
@@ -43,6 +43,11 @@ def __init__(self, *args, **kw):
self.groundwater_parameter_id = get_groundwater_parameter_id()
self._itertuples_field_map = {}
self._df_columns = set()
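# Bulk-load tunables: chunk size for the deployment prefetch, chunk size for COPY, and a flag requiring COPY-based inserts.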
self._deployment_lookup_chunk_size = int(
self.flags.get("DEPLOYMENT_LOOKUP_CHUNK_SIZE", 2000)
)
self._copy_chunk_size = int(self.flags.get("COPY_CHUNK_SIZE", 10000))
self._use_copy_insert = bool(self.flags.get("USE_COPY_INSERT", True))
self._observation_columns = {
column.key for column in TransducerObservation.__table__.columns
}
@@ -68,23 +73,16 @@ def _get_dfs(self):
return input_df, cleaned_df

def _transfer_hook(self, session: Session) -> None:
gwd = self.cleaned_df.groupby(["PointID"])
n = len(gwd)
gwd = self.cleaned_df.groupby("PointID", sort=False)
n = gwd.ngroups
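# Look up deployments for every PointID up front so the per-point loop issues no further queries.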
deployments_by_pointid = self._prefetch_deployments(session)
nodeployments = {}
for i, (index, group) in enumerate(gwd):
pointid = index[0]
for i, (pointid, group) in enumerate(gwd):
logger.info(
f"Processing PointID: {pointid}. {i + 1}/{n} ({100*(i+1)/n:0.2f}) completed."
)

deployments = (
session.query(Deployment)
.join(Thing)
.join(Sensor)
.where(Sensor.sensor_type.in_(self._sensor_types))
.where(Thing.name == pointid)
.all()
)
deployments = deployments_by_pointid.get(pointid, [])

# sort rows by date measured
group = group.sort_values(by="DateMeasured")
@@ -103,6 +101,7 @@ def _transfer_hook(self, session: Session) -> None:

# Get thing_id from the first deployment
thing_id = deployments[0].thing_id
deps_sorted = deployments

qced_block = TransducerObservationBlock(
thing_id=thing_id,
@@ -119,54 +118,46 @@
(qced_block, qced, "public"),
(notqced_block, notqced, "private"),
):
block.start_datetime = rows.DateMeasured.min()
block.end_datetime = rows.DateMeasured.max()

if rows.empty:
logger.info(f"no {release_status} records for pointid {pointid}")
continue

def _install_ts(value):
if isinstance(value, Timestamp):
return value
if hasattr(value, "date"):
return Timestamp(value)
return Timestamp(pd.to_datetime(value, errors="coerce"))

deps_sorted = sorted(
deployments, key=lambda d: _install_ts(d.installation_date)
)

observations = [
self._make_observation(
pointid, row, release_status, deps_sorted, nodeployments
block.start_datetime = rows.DateMeasured.iloc[0]
block.end_datetime = rows.DateMeasured.iloc[-1]
if block.end_datetime <= block.start_datetime:
# DB check constraint requires end > start, even for singleton blocks.
block.end_datetime = block.start_datetime + pd.Timedelta(
microseconds=1
)
for row in rows.itertuples()
]

observations = [obs for obs in observations if obs is not None]
if observations:
filtered_observations = [
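# Rows in this block are already sorted by DateMeasured, so a forward-only cursor can match deployments in a single pass.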
deployment_matcher = _DeploymentMatcher(deps_sorted)

observations = []
for row in rows.itertuples():
obs = self._make_observation(
pointid,
row,
release_status,
deployment_matcher,
nodeployments,
)
if obs is None:
continue
observations.append(
{k: v for k, v in obs.items() if k in self._observation_columns}
for obs in observations
]
session.execute(
insert(TransducerObservation),
filtered_observations,
)
if observations:
self._insert_observations(session, observations)
block = self._get_or_create_block(session, block)
logger.info(
f"Added {len(observations)} water levels {release_status} block"
)
try:
session.commit()
except DatabaseError as e:
session.rollback()
logger.critical(
f"Error committing water levels {release_status} block: {e}"
)
self._capture_database_error(pointid, e)
continue
try:
session.commit()
except DatabaseError as e:
session.rollback()
logger.critical(f"Error committing water levels for {pointid}: {e}")
self._capture_database_error(pointid, e)
continue

# convert nodeployments to errors
for pointid, (min_date, max_date) in nodeployments.items():
@@ -176,15 +167,42 @@ def _install_ts(value):
"DateMeasured",
)

def _prefetch_deployments(self, session: Session) -> dict[str, list[Deployment]]:
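"""Load deployments for all PointIDs in bounded chunks, keyed by PointID and sorted by installation date."""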
pointids = self.cleaned_df["PointID"].dropna().unique().tolist()
deployments_by_pointid: dict[str, list[Deployment]] = defaultdict(list)
if not pointids:
return {}

for i in range(0, len(pointids), self._deployment_lookup_chunk_size):
chunk = pointids[i : i + self._deployment_lookup_chunk_size]
deployment_rows = (
session.query(Thing.name, Deployment)
.join(Deployment, Deployment.thing_id == Thing.id)
.join(Sensor, Sensor.id == Deployment.sensor_id)
.where(Thing.name.in_(chunk))
.where(Sensor.sensor_type.in_(self._sensor_types))
.all()
)
for pointid, deployment in deployment_rows:
deployments_by_pointid[pointid].append(deployment)

for pointid in deployments_by_pointid:
deployments_by_pointid[pointid].sort(
key=lambda deployment: _installation_timestamp(
deployment.installation_date
)
)
return dict(deployments_by_pointid)

def _make_observation(
self,
pointid: str,
row: pd.Series,
release_status: str,
deps_sorted: list,
deployment_matcher: "_DeploymentMatcher",
nodeployments: dict,
) -> dict | None:
deployment = _find_deployment(row.DateMeasured, deps_sorted)
deployment = deployment_matcher.find(row.DateMeasured)

if deployment is None:
if pointid not in nodeployments:
@@ -210,15 +228,58 @@ def _make_observation(
value=row.DepthToWaterBGS,
release_status=release_status,
)
obspayload = CreateTransducerObservation.model_validate(
payload
).model_dump()
if payload["value"] is None or pd.isna(payload["value"]):
self._capture_error(
pointid,
"DepthToWaterBGS is NULL",
"DepthToWaterBGS",
)
return None
payload["value"] = float(payload["value"])
legacy_payload = self._legacy_payload(row)
return {**obspayload, **legacy_payload}
return {**payload, **legacy_payload}

except (TypeError, ValueError) as e:
logger.critical(f"Observation build error: {e}")
self._capture_error(pointid, str(e), "DepthToWaterBGS")

except ValidationError as e:
logger.critical(f"Observation validation error: {e.errors()}")
self._capture_validation_error(pointid, e)
def _insert_observations(
self, session: Session, observations: list[dict[str, Any]]
) -> None:
if not observations:
return

if not self._use_copy_insert:
raise RuntimeError(
"USE_COPY_INSERT=False is not supported; transducer observations now require COPY inserts."
)
self._copy_insert_observations(session, observations)

def _copy_insert_observations(
self, session: Session, observations: list[dict[str, Any]]
) -> None:
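"""Stream observations into the transducer observation table via COPY, writing CSV chunks to the DBAPI cursor."""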
raw_connection = session.connection().connection
cursor = raw_connection.cursor()
table_name = TransducerObservation.__table__.name
columns = [
key for key in observations[0].keys() if key in self._observation_columns
]
if not columns:
return

copy_sql = (
f"COPY {table_name} ({', '.join(columns)}) "
"FROM STDIN WITH (FORMAT csv, NULL '\\N')"
)

for i in range(0, len(observations), self._copy_chunk_size):
chunk = observations[i : i + self._copy_chunk_size]
stream = StringIO()
writer = csv.writer(stream, lineterminator="\n")
for row in chunk:
writer.writerow([_copy_cell(row.get(column)) for column in columns])
stream.seek(0)
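# Assumes a DBAPI driver whose execute() accepts a COPY input stream (pg8000-style);
# a psycopg2 cursor would need copy_expert(copy_sql, stream) here instead.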
cursor.execute(copy_sql, stream=stream)

def _legacy_payload(self, row: pd.Series) -> dict:
return {}
@@ -356,13 +417,71 @@ def _legacy_payload(self, row: pd.Series) -> dict:
}


def _find_deployment(ts, deployments):
def _installation_timestamp(value: Any) -> Timestamp:
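"""Normalize None, date, datetime, or string installation dates to a Timestamp so deployments sort consistently."""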
if value is None:
return Timestamp.min
if isinstance(value, Timestamp):
return value
if hasattr(value, "date"):
return Timestamp(value)
return Timestamp(pd.to_datetime(value, errors="coerce"))


def _copy_cell(value: Any) -> Any:
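"""Render a value as a CSV cell for COPY: None/NaN become the NULL marker, booleans t/f, timestamps ISO 8601 text."""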
if value is None:
return r"\N"
if isinstance(value, Timestamp):
if pd.isna(value):
return r"\N"
return value.to_pydatetime().isoformat(sep=" ")
try:
if pd.isna(value):
return r"\N"
except TypeError:
pass
if isinstance(value, bool):
return "t" if value else "f"
if hasattr(value, "isoformat"):
return value.isoformat()
return value


class _DeploymentMatcher:
"""
Cursor-based matcher for monotonic time-series rows.
Assumes rows are processed in ascending DateMeasured order.
"""

def __init__(self, deployments: list[Deployment]):
self._deployments = deployments
self._cursor = 0

def find(self, ts: Any) -> Deployment | None:
date = _to_date(ts)
n = len(self._deployments)
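# Advance past deployments that ended before this date; return the first whose window contains it.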
while self._cursor < n:
deployment = self._deployments[self._cursor]
start = deployment.installation_date or Timestamp.min.date()
end = deployment.removal_date or Timestamp.max.date()
if date < start:
return None
if date <= end:
return deployment
self._cursor += 1
return None


def _to_date(ts: Any):
if hasattr(ts, "date"):
date = ts.date()
else:
date = pd.Timestamp(ts).date()
return date


def _find_deployment(ts, deployments):
date = _to_date(ts)
for d in deployments:
if d.installation_date > date:
start = d.installation_date or Timestamp.min.date()
if start > date:
break # because sorted by start
end = d.removal_date if d.removal_date else Timestamp.max.date()
if end >= date: