Conversation
removed extra squeeze in _predict_batch so flowstate returns 2D arrays for h=1 with single unique_id.
There was a problem hiding this comment.
Pull request overview
Fixes a shape bug in the FlowState foundation model’s batch prediction path so h=1 forecasts don’t collapse dimensions (especially for single-series inputs), preventing downstream failures.
Changes:
- Removed an extra
.squeeze()when extracting the median (0.5) quantile forecast so outputs preserve a consistent 2D(batch, h)shape.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| fcst_mean_np = fcst_mean.detach().numpy() | ||
| fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None |
There was a problem hiding this comment.
fcst_mean (and fcst) will be on self.device (CUDA when available), so calling .detach().numpy() will raise on GPU. Convert to CPU first (e.g., .detach().cpu().numpy()) for both fcst_mean_np and fcst_quantiles_np to avoid runtime crashes when CUDA is available.
| fcst_mean_np = fcst_mean.detach().numpy() | |
| fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None | |
| fcst_mean_np = fcst_mean.detach().cpu().numpy() | |
| fcst_quantiles_np = ( | |
| fcst.detach().cpu().numpy() if quantiles is not None else None | |
| ) |
There was a problem hiding this comment.
An alternative solution is setting the force argument in the numpy() method to true.
e.g.
fcst_mean_np = fcst_mean.detach().numpy(force=True)There was a problem hiding this comment.
thanks @spolisar! @Kushagra7777 could we add this fix to the pr?
There was a problem hiding this comment.
pushed the changes.
AzulGarza
left a comment
There was a problem hiding this comment.
thanks @Kushagra7777!
could you add a small test to ensure that the fix works as expected?
sure @AzulGarza |
ruff check . ruff format .
| fcst_mean_np = fcst_mean.detach().numpy() | ||
| fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None |
There was a problem hiding this comment.
thanks @spolisar! @Kushagra7777 could we add this fix to the pr?
adjust conversion to handle device-related issues safely Co-Authored-By: spolisar <spolisar@users.noreply.github.com> Co-Authored-By: azul <azul.garza.r@gmail.com>
| fcst_mean_np = fcst_mean.detach().numpy() | ||
| fcst_mean = fcst[..., supported_quantiles.index(0.5)] | ||
| fcst_mean_np = fcst_mean.detach().numpy(force=True) | ||
| fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None |
removed extra squeeze in _predict_batch so flowstate returns 2D arrays for h=1 with single unique_id.