Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/siac/priors/brdf/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@

from __future__ import annotations

from typing import Literal
from typing import Any, Literal

import numpy as np
import xarray as xr

from siac._rust import RossThickLiSparse as _RustKernels


class BRDFKernels:
"""
Expand All @@ -44,7 +42,7 @@ def __init__(
):
self.hb = hb
self.br = br
self._rust_kernels = _RustKernels(hb, br)
self._rust_kernels: Any | None = None # lazily initialized on first compute()

def compute(
self,
Expand Down Expand Up @@ -83,6 +81,10 @@ def compute(
vza_in = np.ascontiguousarray(vza_np.reshape(1, -1), dtype=np.float64)
sza_in = np.ascontiguousarray(sza_np.reshape(1, -1), dtype=np.float64)
raa_in = np.ascontiguousarray(raa_np.reshape(1, -1), dtype=np.float64)
if self._rust_kernels is None:
from siac._rust import RossThickLiSparse as _RustKernels # noqa: PLC0415 - lazy; siac._rust is optional at import time

self._rust_kernels = _RustKernels(self.hb, self.br)
k_vol, k_geo = self._rust_kernels.compute(vza_in, sza_in, raa_in)
k_vol = np.asarray(k_vol, dtype=np.float64).reshape(original_shape)
k_geo = np.asarray(k_geo, dtype=np.float64).reshape(original_shape)
Expand Down
4 changes: 2 additions & 2 deletions python/siac/priors/earthdata_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import numpy as np
import rioxarray # noqa: F401
import xarray as xr
from pyhdf.SD import SD, SDC

from siac.io.reprojection import transform_bounds

if TYPE_CHECKING:
Expand Down Expand Up @@ -457,6 +455,8 @@ def reproject_native_to_target(

def read_hdf4_dataset(path: str | Path, dataset_name: str) -> tuple[np.ndarray, dict[str, Any]]:
"""Read an HDF4 SDS plus decoded attributes."""
from pyhdf.SD import SD, SDC # noqa: PLC0415 - lazy import; pyhdf is optional

sd = SD(str(path), SDC.READ)
sds = sd.select(dataset_name)
return np.asarray(sds.get()), {key: decode_attr(value) for key, value in sds.attributes().items()}
Expand Down
14 changes: 12 additions & 2 deletions python/siac/priors/surface/brdf_whittaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import xarray as xr

from siac._rust import whittaker_smooth_cube
from siac.core.types import BRDFKernelWeights, GeometryAngles, SurfacePrior
from siac.priors.brdf.kernels import BRDFKernels, compute_reflectance
from siac.priors.surface.kernel_model import KernelModelDeriver
Expand All @@ -20,6 +19,17 @@
logger = logging.getLogger(__name__)


def _whittaker_smooth_cube(
reflectance: np.ndarray,
weights: np.ndarray,
lam: float,
) -> np.ndarray:
"""Thin wrapper that defers the Rust extension import to first use."""
from siac._rust import whittaker_smooth_cube # noqa: PLC0415 - lazy; siac._rust is optional at import time

return whittaker_smooth_cube(reflectance, weights, lam) # type: ignore[no-any-return]


class BRDFWhittakerDeriver(KernelModelDeriver):
"""Derive a sensing-date surface prior from a temporal BRDF stack."""

Expand Down Expand Up @@ -88,7 +98,7 @@ def compute_surface_prior(
np.divide(weights, max_weight, out=normalized_weights, where=max_weight > 0.0)
weights = normalized_weights

smoothed = whittaker_smooth_cube(
smoothed = _whittaker_smooth_cube(
np.ascontiguousarray(reflectance_values, dtype=np.float32),
np.ascontiguousarray(weights, dtype=np.float32),
self.temporal_lambda,
Expand Down
22 changes: 20 additions & 2 deletions python/siac/rt/emulator/two_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

import logging
from pathlib import Path
from typing import Any

import numpy as np
import xarray as xr

from siac._rust import TwoLayerNN as _RustNN
from siac.core.types import (
AtmosphericState,
GeometryAngles,
Expand Down Expand Up @@ -423,6 +423,24 @@ def load(
)


def _init_rust_nn(
w1: np.ndarray,
b1: np.ndarray,
w2: np.ndarray,
b2: np.ndarray,
w3: np.ndarray,
b3: np.ndarray,
) -> Any:
"""Lazily import the Rust NN and construct it; defers the hard dependency to first use.

Returns a ``siac._rust.TwoLayerNN`` instance. The return type is declared
as ``Any`` because the Rust type is not available at static-analysis time.
"""
from siac._rust import TwoLayerNN as _RustNN # noqa: PLC0415 - lazy; siac._rust is optional at import time

return _RustNN(w1, b1, w2, b2, w3, b3)


class _BandEmulator:
"""
Internal class for single-band emulator.
Expand Down Expand Up @@ -465,7 +483,7 @@ def _init_rust_emulator(self) -> None:
w3 = np.asarray(self.output_layers[0][0], dtype=np.float32)
b3 = np.asarray(self.output_layers[0][1], dtype=np.float32)

self._rust_nn = _RustNN(w1, b1, w2, b2, w3, b3)
self._rust_nn = _init_rust_nn(w1, b1, w2, b2, w3, b3)

def forward(
self,
Expand Down
38 changes: 31 additions & 7 deletions python/siac/satellite/sentinel2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from siac.cloud import build_cloud_classes, classes_to_bool_mask
from siac.core.types import (
SENTINEL2A_CONFIG,
SENTINEL2B_CONFIG,
SENTINEL2C_CONFIG,
GeometryAngles,
SensorConfig,
)
Expand All @@ -33,6 +35,13 @@

logger = logging.getLogger(__name__)

# Built-in fallback configs when SRF workbook cannot be downloaded
_S2_FALLBACK_CONFIGS: dict[str, SensorConfig] = {
"S2A": SENTINEL2A_CONFIG,
"S2B": SENTINEL2B_CONFIG,
"S2C": SENTINEL2C_CONFIG,
}


@register_preprocessor("s2")
class Sentinel2Preprocessor(BaseSatellitePreprocessor):
Expand Down Expand Up @@ -67,15 +76,30 @@ def __init__(self, config: dict[str, Any] | None = None):

@property
def sensor_config(self) -> SensorConfig:
"""Return sensor configuration based on satellite platform."""
"""Return sensor configuration based on satellite platform.

Attempts to load the official SRF-backed config from the local cache or
the remote SentiWiki source. When neither is available (e.g. in an
offline/test environment), falls back to the built-in nominal band
characterisation so that the preprocessor can still operate.
"""
if self._satellite_id is None:
return SENTINEL2A_CONFIG
return load_sensor_config_from_srf(
"MSI",
self._satellite_id,
cache_dir=self.config.get("srf_cache_dir"),
refresh=bool(self.config.get("refresh_srf", False)),
)
try:
return load_sensor_config_from_srf(
"MSI",
self._satellite_id,
cache_dir=self.config.get("srf_cache_dir"),
refresh=bool(self.config.get("refresh_srf", False)),
)
except Exception: # noqa: BLE001 - intentionally broad: any failure (network, I/O, parse) should fall back gracefully
fallback = _S2_FALLBACK_CONFIGS.get(self._satellite_id, SENTINEL2A_CONFIG)
logger.warning(
"Could not load SRF for %s (network or cache unavailable); "
"falling back to built-in nominal band characterisation.",
self._satellite_id,
)
return fallback

def load_toa(self, input_path: str | Path) -> xr.Dataset:
"""Load TOA reflectance from Sentinel-2 SAFE directory."""
Expand Down
16 changes: 10 additions & 6 deletions python/siac/solver/multigrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
import xarray as xr
from scipy import optimize

from siac._rust import (
evaluate_grid_search_cost_cube_with_provider,
interpolate_to_fine_grid,
quadratic_refine_grid_search,
remap_to_coarse_grid,
)
from siac.core.protocols import RTModelBackend
from siac.core.types import (
AtmosphericState,
Expand Down Expand Up @@ -447,6 +441,11 @@ def _candidate_coeff_provider(
xcp_stack[ib] = np.asarray(coeffs.xcp.values, dtype=np.float32)
return xap_stack, xbp_stack, xcp_stack

from siac._rust import ( # noqa: PLC0415 - lazy; siac._rust is optional at import time
evaluate_grid_search_cost_cube_with_provider,
quadratic_refine_grid_search,
)

costs = np.asarray(
evaluate_grid_search_cost_cube_with_provider(
_candidate_coeff_provider,
Expand Down Expand Up @@ -549,6 +548,11 @@ def _resample_field(
return field

data = np.ascontiguousarray(field, dtype=np.float64)
from siac._rust import ( # noqa: PLC0415 - lazy; siac._rust is optional at import time
interpolate_to_fine_grid,
remap_to_coarse_grid,
)

if target_shape[0] < field.shape[0]:
return np.asarray(remap_to_coarse_grid(data, target_shape[0], target_shape[1]))
return np.asarray(interpolate_to_fine_grid(data, target_shape[0], target_shape[1]))
Expand Down