diff --git a/transfers/waterlevels_transducer_transfer.py b/transfers/waterlevels_transducer_transfer.py
index c25a9bf2..27c5255e 100644
--- a/transfers/waterlevels_transducer_transfer.py
+++ b/transfers/waterlevels_transducer_transfer.py
@@ -13,18 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===============================================================================
+import csv
+from collections import defaultdict
+from io import StringIO
 from typing import Any
 
 import pandas as pd
 from pandas import Timestamp
-from pydantic import ValidationError
-from sqlalchemy import insert
 from sqlalchemy.exc import DatabaseError
 from sqlalchemy.orm import Session
 
 from db import Thing, Deployment, Sensor
 from db.transducer import TransducerObservation, TransducerObservationBlock
-from schemas.transducer import CreateTransducerObservation
 from transfers.logger import logger
 from transfers.transferer import Transferer
 from transfers.util import (
@@ -43,6 +43,11 @@ def __init__(self, *args, **kw):
         self.groundwater_parameter_id = get_groundwater_parameter_id()
         self._itertuples_field_map = {}
         self._df_columns = set()
+        self._deployment_lookup_chunk_size = int(
+            self.flags.get("DEPLOYMENT_LOOKUP_CHUNK_SIZE", 2000)
+        )
+        self._copy_chunk_size = int(self.flags.get("COPY_CHUNK_SIZE", 10000))
+        self._use_copy_insert = bool(self.flags.get("USE_COPY_INSERT", True))
         self._observation_columns = {
             column.key for column in TransducerObservation.__table__.columns
         }
@@ -68,23 +73,16 @@ def _get_dfs(self):
         return input_df, cleaned_df
 
     def _transfer_hook(self, session: Session) -> None:
-        gwd = self.cleaned_df.groupby(["PointID"])
-        n = len(gwd)
+        gwd = self.cleaned_df.groupby("PointID", sort=False)
+        n = gwd.ngroups
+        deployments_by_pointid = self._prefetch_deployments(session)
         nodeployments = {}
-        for i, (index, group) in enumerate(gwd):
-            pointid = index[0]
+        for i, (pointid, group) in enumerate(gwd):
             logger.info(
                 f"Processing PointID: {pointid}. {i + 1}/{n} ({100*(i+1)/n:0.2f}) completed."
             )
-            deployments = (
-                session.query(Deployment)
-                .join(Thing)
-                .join(Sensor)
-                .where(Sensor.sensor_type.in_(self._sensor_types))
-                .where(Thing.name == pointid)
-                .all()
-            )
+            deployments = deployments_by_pointid.get(pointid, [])
 
             # sort rows by date measured
             group = group.sort_values(by="DateMeasured")
 
@@ -103,6 +101,7 @@ def _transfer_hook(self, session: Session) -> None:
 
             # Get thing_id from the first deployment
             thing_id = deployments[0].thing_id
+            deps_sorted = deployments  # already sorted by installation date in _prefetch_deployments
 
             qced_block = TransducerObservationBlock(
                 thing_id=thing_id,
@@ -119,54 +118,46 @@ def _transfer_hook(self, session: Session) -> None:
                 (qced_block, qced, "public"),
                 (notqced_block, notqced, "private"),
             ):
-                block.start_datetime = rows.DateMeasured.min()
-                block.end_datetime = rows.DateMeasured.max()
-
                 if rows.empty:
                     logger.info(f"no {release_status} records for pointid {pointid}")
                     continue
 
-                def _install_ts(value):
-                    if isinstance(value, Timestamp):
-                        return value
-                    if hasattr(value, "date"):
-                        return Timestamp(value)
-                    return Timestamp(pd.to_datetime(value, errors="coerce"))
-
-                deps_sorted = sorted(
-                    deployments, key=lambda d: _install_ts(d.installation_date)
-                )
-
-                observations = [
-                    self._make_observation(
-                        pointid, row, release_status, deps_sorted, nodeployments
+                block.start_datetime = rows.DateMeasured.iloc[0]
+                block.end_datetime = rows.DateMeasured.iloc[-1]
+                if block.end_datetime <= block.start_datetime:
+                    # DB check constraint requires end > start, even for singleton blocks.
+                    block.end_datetime = block.start_datetime + pd.Timedelta(
+                        microseconds=1
                     )
-                    for row in rows.itertuples()
-                ]
-
-                observations = [obs for obs in observations if obs is not None]
-                if observations:
-                    filtered_observations = [
+                deployment_matcher = _DeploymentMatcher(deps_sorted)
+
+                observations = []
+                for row in rows.itertuples():
+                    obs = self._make_observation(
+                        pointid,
+                        row,
+                        release_status,
+                        deployment_matcher,
+                        nodeployments,
+                    )
+                    if obs is None:
+                        continue
+                    observations.append(
                         {k: v for k, v in obs.items() if k in self._observation_columns}
-                        for obs in observations
-                    ]
-                    session.execute(
-                        insert(TransducerObservation),
-                        filtered_observations,
                     )
+                if observations:
+                    self._insert_observations(session, observations)
 
                 block = self._get_or_create_block(session, block)
                 logger.info(
                     f"Added {len(observations)} water levels {release_status} block"
                 )
-                try:
-                    session.commit()
-                except DatabaseError as e:
-                    session.rollback()
-                    logger.critical(
-                        f"Error committing water levels {release_status} block: {e}"
-                    )
-                    self._capture_database_error(pointid, e)
-                    continue
+            try:
+                session.commit()
+            except DatabaseError as e:
+                session.rollback()
+                logger.critical(f"Error committing water levels for {pointid}: {e}")
+                self._capture_database_error(pointid, e)
+                continue
 
         # convert nodeployments to errors
         for pointid, (min_date, max_date) in nodeployments.items():
@@ -176,15 +167,42 @@ def _install_ts(value):
                 "DateMeasured",
             )
 
+    def _prefetch_deployments(self, session: Session) -> dict[str, list[Deployment]]:
+        pointids = self.cleaned_df["PointID"].dropna().unique().tolist()
+        deployments_by_pointid: dict[str, list[Deployment]] = defaultdict(list)
+        if not pointids:
+            return {}
+
+        for i in range(0, len(pointids), self._deployment_lookup_chunk_size):
+            chunk = pointids[i : i + self._deployment_lookup_chunk_size]
+            deployment_rows = (
+                session.query(Thing.name, Deployment)
+                .join(Deployment, Deployment.thing_id == Thing.id)
+                .join(Sensor, Sensor.id == Deployment.sensor_id)
+                .where(Thing.name.in_(chunk))
+                .where(Sensor.sensor_type.in_(self._sensor_types))
+                .all()
+            )
+            for pointid, deployment in deployment_rows:
+                deployments_by_pointid[pointid].append(deployment)
+
+        for pointid in deployments_by_pointid:
+            deployments_by_pointid[pointid].sort(
+                key=lambda deployment: _installation_timestamp(
+                    deployment.installation_date
+                )
+            )
+        return dict(deployments_by_pointid)
+
     def _make_observation(
         self,
         pointid: str,
         row: pd.Series,
         release_status: str,
-        deps_sorted: list,
+        deployment_matcher: "_DeploymentMatcher",
         nodeployments: dict,
     ) -> dict | None:
-        deployment = _find_deployment(row.DateMeasured, deps_sorted)
+        deployment = deployment_matcher.find(row.DateMeasured)
 
         if deployment is None:
             if pointid not in nodeployments:
@@ -210,15 +228,58 @@ def _make_observation(
                 value=row.DepthToWaterBGS,
                 release_status=release_status,
             )
-            obspayload = CreateTransducerObservation.model_validate(
-                payload
-            ).model_dump()
+            if payload["value"] is None or pd.isna(payload["value"]):
+                self._capture_error(
+                    pointid,
+                    "DepthToWaterBGS is NULL",
+                    "DepthToWaterBGS",
+                )
+                return None
+            payload["value"] = float(payload["value"])
 
             legacy_payload = self._legacy_payload(row)
-            return {**obspayload, **legacy_payload}
+            return {**payload, **legacy_payload}
+
+        except (TypeError, ValueError) as e:
+            logger.critical(f"Observation build error: {e}")
+            self._capture_error(pointid, str(e), "DepthToWaterBGS")
 
-        except ValidationError as e:
-            logger.critical(f"Observation validation error: {e.errors()}")
-            self._capture_validation_error(pointid, e)
+    def _insert_observations(
+        self, session: Session, observations: list[dict[str, Any]]
+    ) -> None:
+        if not observations:
+            return
+
+        if not self._use_copy_insert:
+            raise RuntimeError(
+                "USE_COPY_INSERT=False is not supported; transducer observations now require COPY inserts."
+            )
+        self._copy_insert_observations(session, observations)
+
+    def _copy_insert_observations(
+        self, session: Session, observations: list[dict[str, Any]]
+    ) -> None:
+        raw_connection = session.connection().connection
+        cursor = raw_connection.cursor()
+        table_name = TransducerObservation.__table__.name
+        columns = [
+            key for key in observations[0].keys() if key in self._observation_columns
+        ]
+        if not columns:
+            return
+
+        copy_sql = (
+            f"COPY {table_name} ({', '.join(columns)}) "
+            "FROM STDIN WITH (FORMAT csv, NULL '\\N')"
+        )
+
+        for i in range(0, len(observations), self._copy_chunk_size):
+            chunk = observations[i : i + self._copy_chunk_size]
+            stream = StringIO()
+            writer = csv.writer(stream, lineterminator="\n")
+            for row in chunk:
+                writer.writerow([_copy_cell(row.get(column)) for column in columns])
+            stream.seek(0)
+            # psycopg2: stream the CSV chunk through COPY ... FROM STDIN
+            cursor.copy_expert(copy_sql, stream)
 
     def _legacy_payload(self, row: pd.Series) -> dict:
         return {}
@@ -356,13 +417,71 @@ def _legacy_payload(self, row: pd.Series) -> dict:
     }
 
 
-def _find_deployment(ts, deployments):
+def _installation_timestamp(value: Any) -> Timestamp:
+    if value is None:
+        return Timestamp.min
+    if isinstance(value, Timestamp):
+        return value
+    if hasattr(value, "date"):
+        return Timestamp(value)
+    return Timestamp(pd.to_datetime(value, errors="coerce"))
+
+
+def _copy_cell(value: Any) -> Any:
+    if value is None:
+        return r"\N"
+    if isinstance(value, Timestamp):
+        if pd.isna(value):
+            return r"\N"
+        return value.to_pydatetime().isoformat(sep=" ")
+    try:
+        if pd.isna(value):
+            return r"\N"
+    except TypeError:
+        pass
+    if isinstance(value, bool):
+        return "t" if value else "f"
+    if hasattr(value, "isoformat"):
+        return value.isoformat()
+    return value
+
+
+class _DeploymentMatcher:
+    """
+    Cursor-based matcher for monotonic time-series rows.
+    Assumes rows are processed in ascending DateMeasured order.
+    """
+
+    def __init__(self, deployments: list[Deployment]):
+        self._deployments = deployments
+        self._cursor = 0
+
+    def find(self, ts: Any) -> Deployment | None:
+        date = _to_date(ts)
+        n = len(self._deployments)
+        while self._cursor < n:
+            deployment = self._deployments[self._cursor]
+            start = deployment.installation_date or Timestamp.min.date()
+            end = deployment.removal_date or Timestamp.max.date()
+            if date < start:
+                return None
+            if date <= end:
+                return deployment
+            self._cursor += 1
+        return None
+
+
+def _to_date(ts: Any):
     if hasattr(ts, "date"):
-        date = ts.date()
-    else:
-        date = pd.Timestamp(ts).date()
+        return ts.date()
+    return pd.Timestamp(ts).date()
+
+
+def _find_deployment(ts, deployments):
+    date = _to_date(ts)
     for d in deployments:
-        if d.installation_date > date:
+        start = d.installation_date or Timestamp.min.date()
+        if start > date:
             break  # because sorted by start
         end = d.removal_date if d.removal_date else Timestamp.max.date()
         if end >= date:
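
Note on the chunked prefetch: _prefetch_deployments replaces one Deployment query
per PointID with one IN (...) query per DEPLOYMENT_LOOKUP_CHUNK_SIZE PointIDs,
which bounds both the number of round trips and the bound-parameter count per
statement. A minimal runnable sketch of the same pattern, using stdlib sqlite3 in
place of the project's SQLAlchemy session (table, column, and helper names here
are illustrative, not from this repo):

    import sqlite3
    from collections import defaultdict

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE deployment (point_id TEXT, sensor TEXT)")
    conn.executemany(
        "INSERT INTO deployment VALUES (?, ?)",
        [("P-001", "sensor-a"), ("P-001", "sensor-b"), ("P-002", "sensor-c")],
    )

    def prefetch(point_ids, chunk_size=2000):
        # One IN (...) query per chunk keeps the placeholder count below
        # driver limits (older SQLite builds cap at 999; Postgres has its
        # own practical limits on bound parameters).
        grouped = defaultdict(list)
        for i in range(0, len(point_ids), chunk_size):
            chunk = point_ids[i : i + chunk_size]
            placeholders = ", ".join("?" * len(chunk))
            rows = conn.execute(
                f"SELECT point_id, sensor FROM deployment"
                f" WHERE point_id IN ({placeholders})",
                chunk,
            )
            for point_id, sensor in rows:
                grouped[point_id].append(sensor)
        return dict(grouped)

    print(prefetch(["P-001", "P-002", "P-404"]))
    # {'P-001': ['sensor-a', 'sensor-b'], 'P-002': ['sensor-c']}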
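
The _DeploymentMatcher added at the bottom of the file replaces the old per-row
scan of _find_deployment with a forward-only cursor. This is valid because each
block's rows are sorted by DateMeasured before matching, so a whole group is
matched in O(rows + deployments) instead of O(rows * deployments). A minimal
runnable sketch of the idea with a stand-in interval type (all names here are
illustrative):

    from dataclasses import dataclass
    from datetime import date

    @dataclass
    class Interval:
        start: date
        end: date
        label: str

    class CursorMatcher:
        """Match ascending query dates against sorted, non-overlapping intervals."""

        def __init__(self, intervals: list[Interval]):
            self._intervals = sorted(intervals, key=lambda iv: iv.start)
            self._cursor = 0

        def find(self, d: date) -> Interval | None:
            while self._cursor < len(self._intervals):
                iv = self._intervals[self._cursor]
                if d < iv.start:
                    return None  # falls in a gap before the next interval
                if d <= iv.end:
                    return iv  # inside the current interval; cursor stays put
                self._cursor += 1  # past this interval; never look at it again
            return None

    m = CursorMatcher([
        Interval(date(2020, 1, 1), date(2020, 6, 30), "sensor-a"),
        Interval(date(2020, 7, 1), date(2020, 12, 31), "sensor-b"),
    ])
    assert m.find(date(2020, 3, 1)).label == "sensor-a"
    assert m.find(date(2020, 8, 1)).label == "sensor-b"  # cursor moved past sensor-a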
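
On the COPY path: _copy_insert_observations unwraps the DBAPI connection behind
the SQLAlchemy session and streams CSV through COPY ... FROM STDIN, avoiding
per-row INSERT overhead. This assumes a psycopg2-backed engine: copy_expert(sql,
file) is psycopg2's API for feeding a file-like object to COPY, and the NULL '\N'
clause must match the token the CSV cells use for missing values (what _copy_cell
emits). A standalone sketch of the same pattern, assuming psycopg2 and
illustrative table/column names:

    import csv
    from io import StringIO

    import psycopg2  # assumed driver; copy_expert is psycopg2-specific

    NULL_TOKEN = r"\N"

    def copy_rows(conn, table, columns, rows, chunk_size=10000):
        # Stream dict-shaped rows into Postgres in bounded-size CSV chunks.
        sql = (
            f"COPY {table} ({', '.join(columns)}) "
            "FROM STDIN WITH (FORMAT csv, NULL '\\N')"
        )
        with conn.cursor() as cur:
            for i in range(0, len(rows), chunk_size):
                buf = StringIO()
                writer = csv.writer(buf, lineterminator="\n")
                for row in rows[i : i + chunk_size]:
                    writer.writerow(
                        [NULL_TOKEN if row.get(c) is None else row[c] for c in columns]
                    )
                buf.seek(0)
                cur.copy_expert(sql, buf)  # psycopg2 COPY ... FROM STDIN

    # Illustrative usage (connection string is a placeholder):
    # conn = psycopg2.connect("dbname=waterlevels")
    # copy_rows(conn, "transducer_observation", ["thing_id", "value"], payloads)
    # conn.commit()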