diff --git a/README.md b/README.md index cd00f18..a89be0e 100644 --- a/README.md +++ b/README.md @@ -92,10 +92,21 @@ Create, render, and publish data visualizations from notebooks or the in-browser Combine visualizations into **drag-and-drop dashboards** with resizable panels, lock/unlock layout, and persistent configuration. Each visualization also has a full **in-browser editor** with Monaco, live preview for JSON backends, template insertion, and data/config tabs. See the [Visualizations Guide](docs/VISUALIZATIONS.md) for SDK usage. +#### Individual Visualization Editor

OpenModelStudio Visualization Framework

+

+ OpenModelStudio Visualization Framework +

+ +#### Dashboard + +

+ OpenModelStudio Visualization Framework +

+ ### Model Registry Browse, install, and manage models from the [Open Model Registry](https://github.com/GACWR/open-model-registry) -- a public GitHub repo that acts as a decentralized model package manager. diff --git a/api/Cargo.lock b/api/Cargo.lock index 58a2233..2f517c5 100644 --- a/api/Cargo.lock +++ b/api/Cargo.lock @@ -955,6 +955,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde_core", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + [[package]] name = "darling" version = "0.20.11" @@ -2341,6 +2362,7 @@ dependencies = [ "axum-extra", "base64", "chrono", + "csv", "dotenvy", "futures", "hex", @@ -2637,7 +2659,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls 0.23.36", - "socket2 0.5.10", + "socket2 0.6.2", "thiserror 2.0.18", "tokio", "tracing", @@ -2674,9 +2696,9 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] diff --git a/api/Cargo.toml b/api/Cargo.toml index e546805..2653e1e 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -33,6 +33,7 @@ password-hash = "0.5" sha2 = "0.10" hex = "0.4" rustls = { version = "0.23", default-features = false, features = ["ring"] } +csv = "1" [dev-dependencies] tower = { version = "0.5", features = ["util"] } diff --git a/api/src/auth.rs b/api/src/auth.rs index 050a4a3..14d9e09 100644 --- a/api/src/auth.rs +++ b/api/src/auth.rs @@ -46,7 +46,7 @@ pub fn create_access_token( email: email.to_string(), role, iat: now.timestamp(), - exp: (now + Duration::minutes(15)).timestamp(), + exp: (now + Duration::hours(24)).timestamp(), token_type: "access".into(), }; encode( diff --git a/api/src/main.rs b/api/src/main.rs index 8a34ab9..bc4d6c1 100644 --- a/api/src/main.rs +++ b/api/src/main.rs @@ -40,9 +40,13 @@ async fn main() { let llm = Arc::new(LlmService::new(&config)); let k8s = match K8sService::new(&config).await { - Ok(svc) => Some(Arc::new(svc)), + Ok(svc) => { + tracing::info!("K8s service initialized successfully (namespace: {})", config.k8s_namespace); + Some(Arc::new(svc)) + } Err(e) => { - tracing::warn!("K8s client not available: {e}. Running without K8s integration."); + tracing::error!("K8s service initialization FAILED: {e}"); + tracing::error!("Training jobs and workspace pods will NOT work until K8s is properly configured"); None } }; @@ -105,6 +109,7 @@ async fn main() { .route("/models/{id}/code", put(routes::models::update_code)) .route("/models/{id}/run", post(routes::models::run_model)) .route("/models/{id}/versions", get(routes::models::list_versions)) + .route("/models/{id}/experiment-runs", get(routes::models::experiment_runs)) // Training .route("/training/jobs", get(routes::training::list_all_jobs)) .route("/training/start", post(routes::training::start)) @@ -130,6 +135,7 @@ async fn main() { .route("/experiments/{id}/compare", get(routes::experiments::compare)) // Artifacts .route("/jobs/{job_id}/artifacts", get(routes::artifacts::list)) + .route("/models/{model_id}/artifacts", get(routes::artifacts::list_for_model)) .route("/artifacts", post(routes::artifacts::create)) .route("/artifacts/{id}", get(routes::artifacts::get)) .route("/artifacts/{id}", delete(routes::artifacts::delete)) diff --git a/api/src/models/artifact.rs b/api/src/models/artifact.rs index 0561311..e8439b0 100644 --- a/api/src/models/artifact.rs +++ b/api/src/models/artifact.rs @@ -6,7 +6,8 @@ use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize, FromRow)] pub struct Artifact { pub id: Uuid, - pub job_id: Uuid, + pub job_id: Option, + pub workspace_id: Option, pub name: String, pub artifact_type: String, pub s3_key: String, diff --git a/api/src/models/dataset.rs b/api/src/models/dataset.rs index 2a2ebc9..57960d4 100644 --- a/api/src/models/dataset.rs +++ b/api/src/models/dataset.rs @@ -15,9 +15,10 @@ pub struct Dataset { pub row_count: Option, pub version: i32, pub created_by: Uuid, + pub snapshots: i32, + pub schema: Option, pub created_at: DateTime, pub updated_at: DateTime, - pub snapshots: i32, } #[derive(Debug, Deserialize)] diff --git a/api/src/models/experiment.rs b/api/src/models/experiment.rs index 7e05703..3c08047 100644 --- a/api/src/models/experiment.rs +++ b/api/src/models/experiment.rs @@ -19,7 +19,8 @@ pub struct Experiment { pub struct ExperimentRun { pub id: Uuid, pub experiment_id: Uuid, - pub job_id: Uuid, + pub job_id: Option, + pub model_id: Option, pub parameters: Option, pub metrics: Option, pub created_at: DateTime, @@ -34,7 +35,8 @@ pub struct CreateExperimentRequest { #[derive(Debug, Deserialize)] pub struct AddRunRequest { - pub job_id: Uuid, + pub job_id: Option, + pub model_id: Option, pub parameters: Option, pub metrics: Option, } diff --git a/api/src/routes/artifacts.rs b/api/src/routes/artifacts.rs index 3792287..80fc36c 100644 --- a/api/src/routes/artifacts.rs +++ b/api/src/routes/artifacts.rs @@ -55,6 +55,24 @@ pub async fn create( Ok(Json(artifact)) } +/// List all artifacts for a model (via its jobs) +pub async fn list_for_model( + State(state): State, + AuthUser(_claims): AuthUser, + Path(model_id): Path, +) -> AppResult>> { + let artifacts: Vec = sqlx::query_as( + "SELECT a.* FROM artifacts a + JOIN jobs j ON a.job_id = j.id + WHERE j.model_id = $1 + ORDER BY a.created_at DESC" + ) + .bind(model_id) + .fetch_all(&state.db) + .await?; + Ok(Json(artifacts)) +} + pub async fn download( State(state): State, AuthUser(_claims): AuthUser, diff --git a/api/src/routes/datasets.rs b/api/src/routes/datasets.rs index 92566b7..761bd29 100644 --- a/api/src/routes/datasets.rs +++ b/api/src/routes/datasets.rs @@ -49,13 +49,119 @@ pub async fn get( AuthUser(_claims): AuthUser, Path(id): Path, ) -> AppResult> { - let dataset: Dataset = sqlx::query_as("SELECT * FROM datasets WHERE id = $1") + let mut dataset: Dataset = sqlx::query_as("SELECT * FROM datasets WHERE id = $1") .bind(id) .fetch_one(&state.db) .await?; + + // Lazy backfill: if schema is missing but we have a stored CSV file, extract it now + if dataset.schema.is_none() && dataset.format.eq_ignore_ascii_case("csv") { + if let Some(ref key) = dataset.s3_key { + let path = key.strip_prefix("local:").unwrap_or(key); + if let Ok(bytes) = std::fs::read(path) { + if let Some((schema, row_count)) = extract_csv_schema(&bytes) { + let _ = sqlx::query( + "UPDATE datasets SET schema = $1, row_count = COALESCE(row_count, $2), updated_at = NOW() WHERE id = $3" + ) + .bind(&schema) + .bind(row_count) + .bind(dataset.id) + .execute(&state.db) + .await; + dataset.schema = Some(schema); + if dataset.row_count.is_none() { + dataset.row_count = Some(row_count); + } + } + } + } + } + Ok(Json(dataset)) } +/// Infer the type of a CSV cell value by attempting numeric/bool parsing. +fn infer_cell_type(val: &str) -> &'static str { + if val.is_empty() { + return "string"; + } + if val.parse::().is_ok() { + return "int64"; + } + if val.parse::().is_ok() { + return "float64"; + } + if val.eq_ignore_ascii_case("true") || val.eq_ignore_ascii_case("false") { + return "boolean"; + } + "string" +} + +/// Parse a CSV byte slice and return (schema JSON, row_count). +fn extract_csv_schema(bytes: &[u8]) -> Option<(serde_json::Value, i64)> { + let mut rdr = csv::ReaderBuilder::new() + .has_headers(true) + .from_reader(bytes); + + let headers = rdr.headers().ok()?.clone(); + if headers.is_empty() { + return None; + } + + let num_cols = headers.len(); + // Track best type per column: start with unknown, refine by sampling rows + let mut col_types: Vec> = vec![None; num_cols]; + let mut row_count: i64 = 0; + let sample_limit = 100; // sample first 100 rows for type inference + + for result in rdr.records() { + let record = match result { + Ok(r) => r, + Err(_) => continue, + }; + row_count += 1; + + if row_count <= sample_limit { + for (i, field) in record.iter().enumerate() { + if i >= num_cols { + break; + } + let cell_type = infer_cell_type(field.trim()); + col_types[i] = Some(match col_types[i] { + None => cell_type, + Some(prev) => { + if prev == cell_type { + prev + } else if (prev == "int64" && cell_type == "float64") + || (prev == "float64" && cell_type == "int64") + { + "float64" // promote int ↔ float + } else { + "string" // fall back to string on conflict + } + } + }); + } + } + } + // Count remaining rows after sampling + // (rdr already consumed all records in the loop above) + + let columns: Vec = headers + .iter() + .enumerate() + .map(|(i, name)| { + serde_json::json!({ + "name": name, + "type": col_types.get(i).and_then(|t| *t).unwrap_or("string"), + "nullable": true + }) + }) + .collect(); + + Some((serde_json::Value::Array(columns), row_count)) +} + pub async fn create( State(state): State, AuthUser(claims): AuthUser, @@ -64,7 +170,7 @@ pub async fn create( let dataset_id = Uuid::new_v4(); // If file data is provided (base64), store it to local PVC - let (s3_key, size_bytes) = if let Some(ref data_b64) = req.data { + let (s3_key, size_bytes, inferred_schema, inferred_row_count) = if let Some(ref data_b64) = req.data { use base64::Engine; let bytes = base64::engine::general_purpose::STANDARD .decode(data_b64) @@ -80,14 +186,24 @@ pub async fn create( std::fs::write(&file_path, &bytes) .map_err(|e| AppError::Internal(format!("Failed to write file: {e}")))?; - (Some(format!("local:{}", file_path)), Some(size)) + // Extract schema from CSV files + let (schema, row_count) = if ext == "csv" { + extract_csv_schema(&bytes).unwrap_or((serde_json::Value::Null, 0)) + } else { + (serde_json::Value::Null, 0) + }; + + let schema_opt = if schema.is_null() { None } else { Some(schema) }; + let row_count_opt = if row_count > 0 { Some(row_count) } else { req.row_count }; + + (Some(format!("local:{}", file_path)), Some(size), schema_opt, row_count_opt) } else { - (None, None) + (None, None, None, req.row_count) }; let dataset: Dataset = sqlx::query_as( - "INSERT INTO datasets (id, project_id, name, description, format, s3_key, size_bytes, row_count, version, created_by, created_at, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 1, $9, NOW(), NOW()) RETURNING *" + "INSERT INTO datasets (id, project_id, name, description, format, s3_key, size_bytes, row_count, version, created_by, schema, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 1, $9, $10, NOW(), NOW()) RETURNING *" ) .bind(dataset_id) .bind(req.project_id) @@ -96,8 +212,9 @@ pub async fn create( .bind(&req.format) .bind(&s3_key) .bind(size_bytes) - .bind(req.row_count) + .bind(inferred_row_count) .bind(claims.sub) + .bind(&inferred_schema) .fetch_one(&state.db) .await?; notify(&state.db, claims.sub, "Dataset Created", &format!("Dataset '{}' ({}) uploaded", dataset.name, dataset.format), NotifyType::Success, Some(&format!("/datasets/{}", dataset.id))).await; diff --git a/api/src/routes/experiments.rs b/api/src/routes/experiments.rs index 0a4b378..488ca0a 100644 --- a/api/src/routes/experiments.rs +++ b/api/src/routes/experiments.rs @@ -81,12 +81,13 @@ pub async fn add_run( Json(req): Json, ) -> AppResult> { let run: ExperimentRun = sqlx::query_as( - "INSERT INTO experiment_runs (id, experiment_id, job_id, parameters, metrics, created_at) - VALUES ($1, $2, $3, $4, $5, NOW()) RETURNING *" + "INSERT INTO experiment_runs (id, experiment_id, job_id, model_id, parameters, metrics, created_at) + VALUES ($1, $2, $3, $4, $5, $6, NOW()) RETURNING *" ) .bind(Uuid::new_v4()) .bind(experiment_id) .bind(req.job_id) + .bind(req.model_id) .bind(&req.parameters) .bind(&req.metrics) .fetch_one(&state.db) diff --git a/api/src/routes/models.rs b/api/src/routes/models.rs index a6b3fa8..ecce069 100644 --- a/api/src/routes/models.rs +++ b/api/src/routes/models.rs @@ -248,6 +248,22 @@ pub async fn list_versions( Ok(Json(versions)) } +/// GET /models/{id}/experiment-runs +/// Returns experiment runs linked to this model (from notebook training). +pub async fn experiment_runs( + State(state): State, + AuthUser(_claims): AuthUser, + Path(id): Path, +) -> AppResult>> { + let runs: Vec = sqlx::query_as( + "SELECT * FROM experiment_runs WHERE model_id = $1 ORDER BY created_at DESC" + ) + .bind(id) + .fetch_all(&state.db) + .await?; + Ok(Json(runs)) +} + pub async fn run_model( State(state): State, AuthUser(claims): AuthUser, diff --git a/api/src/routes/training.rs b/api/src/routes/training.rs index c82b7a9..ec73146 100644 --- a/api/src/routes/training.rs +++ b/api/src/routes/training.rs @@ -46,33 +46,33 @@ pub async fn start( .fetch_one(&state.db) .await?; - // Create K8s job (best-effort) - if let Some(ref k8s) = state.k8s { - match k8s - .create_training_job( - job_id, - req.model_id, - &model.framework, - &hardware_tier, - req.dataset_id, - req.hyperparameters.as_ref(), - "training", - ) - .await - { - Ok(k8s_name) => { - sqlx::query("UPDATE jobs SET k8s_job_name = $1, status = $2, started_at = NOW(), updated_at = NOW() WHERE id = $3") - .bind(&k8s_name) - .bind(JobStatus::Running) - .bind(job_id) - .execute(&state.db) - .await?; - } - Err(e) => { - tracing::warn!("K8s job creation failed: {e}"); - } - } - } + // Create K8s job — fail if K8s is not available + let k8s = state.k8s.as_ref().ok_or_else(|| { + AppError::Internal("K8s service not available. Training jobs cannot be created.".into()) + })?; + + let k8s_name = k8s + .create_training_job( + job_id, + req.model_id, + &model.framework, + &hardware_tier, + req.dataset_id, + req.hyperparameters.as_ref(), + "training", + ) + .await + .map_err(|e| { + tracing::error!("K8s job creation failed: {e}"); + AppError::Internal(format!("Failed to create K8s job: {e}")) + })?; + + sqlx::query("UPDATE jobs SET k8s_job_name = $1, status = $2, started_at = NOW(), updated_at = NOW() WHERE id = $3") + .bind(&k8s_name) + .bind(JobStatus::Running) + .bind(job_id) + .execute(&state.db) + .await?; notify(&state.db, claims.sub, "Training Started", &format!("Training job started on {}", hardware_tier), NotifyType::Info, Some(&format!("/training/{}", job_id))).await; Ok(Json(job)) diff --git a/api/src/routes/workspaces.rs b/api/src/routes/workspaces.rs index 95af714..7aa124e 100644 --- a/api/src/routes/workspaces.rs +++ b/api/src/routes/workspaces.rs @@ -58,7 +58,12 @@ pub async fn launch( .map_err(|e| AppError::Internal(format!("Failed to create workspace token: {e}")))?; let (pod_name, access_url) = if let Some(ref k8s) = state.k8s { - k8s.create_workspace_pod(ws_id, &docker_image, &hardware_tier, req.project_id, &workspace_token) + // Create a persistent volume for workspace files + let pvc_name = k8s.create_workspace_pvc(ws_id) + .await + .map_err(|e| AppError::Internal(format!("K8s PVC error: {e}")))?; + + k8s.create_workspace_pod(ws_id, &docker_image, &hardware_tier, req.project_id, &workspace_token, &pvc_name) .await .map_err(|e| AppError::Internal(format!("K8s error: {e}")))? } else { @@ -111,9 +116,11 @@ pub async fn stop( if let (Some(ref k8s), Some(ref pod_name)) = (&state.k8s, &ws.pod_name) { let _ = k8s.delete_pod(pod_name).await; + // Also clean up the PVC since this is a permanent delete + let _ = k8s.delete_workspace_pvc(ws.id).await; } - sqlx::query("UPDATE workspaces SET status = 'stopped', updated_at = NOW() WHERE id = $1") + sqlx::query("DELETE FROM workspaces WHERE id = $1") .bind(id) .execute(&state.db) .await?; diff --git a/api/src/services/k8s.rs b/api/src/services/k8s.rs index 0f4ba8a..fcaa4b6 100644 --- a/api/src/services/k8s.rs +++ b/api/src/services/k8s.rs @@ -2,6 +2,8 @@ use k8s_openapi::api::batch::v1::Job as K8sJob; use k8s_openapi::api::batch::v1::JobSpec; use k8s_openapi::api::core::v1::{ Container, EnvVar, PodSpec, PodTemplateSpec, Pod, ResourceRequirements, + PersistentVolumeClaim, PersistentVolumeClaimSpec, Volume, VolumeMount, + PersistentVolumeClaimVolumeSource, Service, ServicePort, ServiceSpec, }; use k8s_openapi::apimachinery::pkg::api::resource::Quantity; @@ -194,6 +196,44 @@ impl K8sService { Ok(0) } + /// Create a PersistentVolumeClaim for a workspace's working directory + pub async fn create_workspace_pvc(&self, workspace_id: Uuid) -> Result { + let pvc_name = format!("oms-ws-{}-data", workspace_id); + let pvc = PersistentVolumeClaim { + metadata: k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta { + name: Some(pvc_name.clone()), + namespace: Some(self.namespace.clone()), + labels: Some(BTreeMap::from([ + ("app".to_string(), "openmodelstudio".to_string()), + ("workspace-id".to_string(), workspace_id.to_string()), + ])), + ..Default::default() + }, + spec: Some(PersistentVolumeClaimSpec { + access_modes: Some(vec!["ReadWriteOnce".to_string()]), + resources: Some(k8s_openapi::api::core::v1::VolumeResourceRequirements { + requests: Some(BTreeMap::from([ + ("storage".to_string(), Quantity("5Gi".to_string())), + ])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let pvcs: Api = Api::namespaced(self.client.clone(), &self.namespace); + pvcs.create(&PostParams::default(), &pvc).await?; + Ok(pvc_name) + } + + /// Delete a workspace PVC (only on permanent workspace deletion) + pub async fn delete_workspace_pvc(&self, workspace_id: Uuid) -> Result<(), kube::Error> { + let pvc_name = format!("oms-ws-{}-data", workspace_id); + let pvcs: Api = Api::namespaced(self.client.clone(), &self.namespace); + let _ = pvcs.delete(&pvc_name, &DeleteParams::default()).await; + Ok(()) + } + /// Create a workspace pod (JupyterLab) with a NodePort Service pub async fn create_workspace_pod( &self, @@ -202,6 +242,7 @@ impl K8sService { hardware_tier: &str, project_id: Uuid, workspace_token: &str, + pvc_name: &str, ) -> Result<(String, String), kube::Error> { let pod_name = format!("oms-ws-{}", workspace_id); let svc_name = format!("oms-ws-{}-svc", workspace_id); @@ -288,8 +329,21 @@ impl K8sService { container_port: 8888, ..Default::default() }]), + volume_mounts: Some(vec![VolumeMount { + name: "workspace-data".to_string(), + mount_path: "/home/jovyan/work".to_string(), + ..Default::default() + }]), ..Default::default() }], + volumes: Some(vec![Volume { + name: "workspace-data".to_string(), + persistent_volume_claim: Some(PersistentVolumeClaimVolumeSource { + claim_name: pvc_name.to_string(), + read_only: Some(false), + }), + ..Default::default() + }]), ..Default::default() }), ..Default::default() diff --git a/db/init.sql b/db/init.sql index 53c74c8..b2adcad 100644 --- a/db/init.sql +++ b/db/init.sql @@ -112,6 +112,7 @@ CREATE TABLE IF NOT EXISTS datasets ( version INT NOT NULL DEFAULT 1, created_by UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, snapshots INT NOT NULL DEFAULT 0, + schema JSONB, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now() ); @@ -194,7 +195,8 @@ CREATE TABLE IF NOT EXISTS experiments ( CREATE TABLE IF NOT EXISTS experiment_runs ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), experiment_id UUID NOT NULL REFERENCES experiments(id) ON DELETE CASCADE, - job_id UUID NOT NULL REFERENCES jobs(id) ON DELETE CASCADE, + job_id UUID REFERENCES jobs(id) ON DELETE SET NULL, + model_id UUID REFERENCES models(id) ON DELETE SET NULL, parameters JSONB, metrics JSONB, created_at TIMESTAMPTZ NOT NULL DEFAULT now() diff --git a/db/migrations/005_dataset_schema.sql b/db/migrations/005_dataset_schema.sql new file mode 100644 index 0000000..34fde45 --- /dev/null +++ b/db/migrations/005_dataset_schema.sql @@ -0,0 +1,2 @@ +-- Add schema JSONB column to datasets table for storing inferred column info. +ALTER TABLE datasets ADD COLUMN IF NOT EXISTS schema JSONB; diff --git a/db/migrations/006_experiment_run_nullable_job.sql b/db/migrations/006_experiment_run_nullable_job.sql new file mode 100644 index 0000000..0816945 --- /dev/null +++ b/db/migrations/006_experiment_run_nullable_job.sql @@ -0,0 +1,6 @@ +-- Make job_id nullable in experiment_runs for notebook-based runs without a K8s job. +ALTER TABLE experiment_runs ALTER COLUMN job_id DROP NOT NULL; +-- Change cascade to SET NULL so deleting a job doesn't delete experiment runs. +ALTER TABLE experiment_runs DROP CONSTRAINT IF EXISTS experiment_runs_job_id_fkey; +ALTER TABLE experiment_runs ADD CONSTRAINT experiment_runs_job_id_fkey + FOREIGN KEY (job_id) REFERENCES jobs(id) ON DELETE SET NULL; diff --git a/db/migrations/007_experiment_run_model_id.sql b/db/migrations/007_experiment_run_model_id.sql new file mode 100644 index 0000000..3be2dd4 --- /dev/null +++ b/db/migrations/007_experiment_run_model_id.sql @@ -0,0 +1,2 @@ +-- Add model_id to experiment_runs so in-process training (no K8s job) can link to the model +ALTER TABLE experiment_runs ADD COLUMN IF NOT EXISTS model_id UUID REFERENCES models(id) ON DELETE SET NULL; diff --git a/deploy/k8s/rbac.yaml b/deploy/k8s/rbac.yaml index a7c408e..8c30c04 100644 --- a/deploy/k8s/rbac.yaml +++ b/deploy/k8s/rbac.yaml @@ -21,6 +21,9 @@ rules: - apiGroups: [""] resources: ["services"] verbs: ["create", "get", "list", "delete"] + - apiGroups: [""] + resources: ["persistentvolumeclaims"] + verbs: ["create", "get", "list", "delete"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding diff --git a/docs/screenshots/oms-screenshot4.png b/docs/screenshots/oms-screenshot4.png new file mode 100644 index 0000000..7bc9d13 Binary files /dev/null and b/docs/screenshots/oms-screenshot4.png differ diff --git a/docs/screenshots/oms-screenshot5.png b/docs/screenshots/oms-screenshot5.png new file mode 100644 index 0000000..d811ed4 Binary files /dev/null and b/docs/screenshots/oms-screenshot5.png differ diff --git a/sdk/python/openmodelstudio/client.py b/sdk/python/openmodelstudio/client.py index db87917..7473a08 100644 --- a/sdk/python/openmodelstudio/client.py +++ b/sdk/python/openmodelstudio/client.py @@ -1599,6 +1599,7 @@ def add_experiment_run( self, experiment_id: str, job_id: str = None, + model_id: str = None, parameters: dict = None, metrics: dict = None, ) -> dict: @@ -1609,15 +1610,22 @@ def add_experiment_run( openmodelstudio.add_experiment_run(exp["id"], job_id=job["id"], parameters={"lr": 0.001}, metrics={"accuracy": 0.95}) + # For in-process training (no K8s job), use model_id instead: + openmodelstudio.add_experiment_run(exp["id"], model_id=mid, + parameters={"lr": 0.001}, metrics={"accuracy": 0.95}) + Args: experiment_id: UUID of the experiment - job_id: UUID of the associated training job + job_id: UUID of the associated training job (for K8s jobs) + model_id: UUID of the associated model (for in-process training) parameters: Dict of hyperparameters used in this run metrics: Dict of final metrics for this run """ body = {} if job_id: body["job_id"] = job_id + if model_id: + body["model_id"] = model_id if parameters: body["parameters"] = parameters if metrics: diff --git a/sdk/python/openmodelstudio/model.py b/sdk/python/openmodelstudio/model.py index c0914e5..ff5e084 100644 --- a/sdk/python/openmodelstudio/model.py +++ b/sdk/python/openmodelstudio/model.py @@ -356,7 +356,7 @@ def get_experiment(experiment_id: str) -> dict: def add_experiment_run( - experiment_id: str, job_id: str = None, + experiment_id: str, job_id: str = None, model_id: str = None, parameters: dict = None, metrics: dict = None, ) -> dict: """Add a run to an experiment. @@ -365,9 +365,14 @@ def add_experiment_run( openmodelstudio.add_experiment_run(exp["id"], job_id=job["id"], parameters={"lr": 0.001}, metrics={"accuracy": 0.95}) + + # For in-process training (no K8s job), use model_id: + openmodelstudio.add_experiment_run(exp["id"], model_id=mid, + parameters={"lr": 0.001}, metrics={"accuracy": 0.95}) """ return _get_client().add_experiment_run( - experiment_id, job_id=job_id, parameters=parameters, metrics=metrics, + experiment_id, job_id=job_id, model_id=model_id, + parameters=parameters, metrics=metrics, ) diff --git a/web/src/app/dashboards/[id]/page.tsx b/web/src/app/dashboards/[id]/page.tsx index 6735768..213b754 100644 --- a/web/src/app/dashboards/[id]/page.tsx +++ b/web/src/app/dashboards/[id]/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect, useCallback, useMemo } from "react"; +import { useState, useEffect, useCallback, useMemo, useRef } from "react"; import { useParams, useRouter } from "next/navigation"; import { AppShell } from "@/components/layout/app-shell"; import { AnimatedPage } from "@/components/shared/animated-page"; @@ -8,7 +8,7 @@ import { GlassCard } from "@/components/shared/glass-card"; import { EmptyState } from "@/components/shared/empty-state"; import { ErrorState } from "@/components/shared/error-state"; import { CardSkeleton } from "@/components/shared/loading-skeleton"; -import { VizRenderer } from "@/components/shared/viz-renderer"; +import { VizRenderer, downloadVisualization } from "@/components/shared/viz-renderer"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { motion, AnimatePresence } from "framer-motion"; @@ -23,6 +23,7 @@ import { Maximize2, Lock, Unlock, + Download, } from "lucide-react"; import { api } from "@/lib/api"; import { toast } from "sonner"; @@ -114,6 +115,9 @@ export default function DashboardDetailPage() { // Dashboard panel layout const [panels, setPanels] = useState([]); + // Panel container refs for download + const panelRefs = useRef>(new Map()); + // Add panel dialog const [addPanelOpen, setAddPanelOpen] = useState(false); const [selectedVizId, setSelectedVizId] = useState(""); @@ -442,6 +446,22 @@ export default function DashboardDetailPage() { )}
+ + + + {/* Recent Jobs Summary */} {jobs.length > 0 && ( @@ -481,55 +551,251 @@ export default function ModelDetailPage() { + + + + {artifacts.length === 0 ? ( + + ) : ( +
+ {artifacts.map((a, i) => ( + +
+

{a.name}

+

+ {a.artifact_type} · {formatBytes(a.size_bytes)} ·{" "} + {new Date(a.created_at).toLocaleDateString(undefined, { + year: "numeric", month: "short", day: "numeric", + })} +

+
+ + + +
+ ))} +
+ )} +
+
+
+ - {metricNames.length === 0 ? ( -
- - - Training Loss - - - - - - - - Accuracy - - - - - -
- ) : ( -
- {metricNames.map((name) => { - const color = name.toLowerCase().includes("loss") - ? "#ef4444" - : name.toLowerCase().includes("acc") - ? "#10b981" - : "#d4d4d4"; - return ( - + {(() => { + // Collect experiment run metrics for this model + const expMetrics = experimentRuns + .filter((r) => r.metrics && Object.keys(r.metrics).length > 0) + .map((r) => r.metrics as Record); + const hasJobMetrics = metricNames.length > 0; + const hasExpMetrics = expMetrics.length > 0; + + if (!hasJobMetrics && !hasExpMetrics) { + return ( +
+ - {name.replace(/_/g, " ")} + Training Loss - + - ); - })} -
- )} + + + Accuracy + + + + + +
+ ); + } + + // Merge: get all unique metric keys from experiment runs + const expMetricKeys = hasExpMetrics + ? [...new Set(expMetrics.flatMap((m) => Object.keys(m)))] + .sort((a, b) => { + // Sort: train first, then val, then test, then others + const order = (k: string) => + k.startsWith("train") ? 0 : k.startsWith("val") ? 1 : k.startsWith("test") ? 2 : k.startsWith("best") ? 3 : 4; + return order(a) - order(b); + }) + : []; + + // Color function + const metricColor = (key: string) => { + const k = key.toLowerCase(); + if (k.includes("loss")) return "#ef4444"; + if (k.includes("train")) return "#8b5cf6"; + if (k.includes("val")) return "#f59e0b"; + if (k.includes("test")) return "#10b981"; + if (k.includes("f1")) return "#3b82f6"; + if (k.includes("acc")) return "#10b981"; + return "#d4d4d4"; + }; + + return ( +
+ {/* Experiment run metrics (snapshot from notebook) */} + {hasExpMetrics && ( + <> + + + + Experiment Metrics + + + +
+ {expMetricKeys.map((key) => { + const val = expMetrics[0][key]; + if (val === undefined) return null; + const isPercent = key.includes("accuracy") || key.includes("acc") || key.includes("f1"); + return ( + +

+ {key.replace(/_/g, " ")} +

+

+ {isPercent ? `${(val * 100).toFixed(1)}%` : val.toFixed(4)} +

+
+ ); + })} +
+
+
+ + {/* Bar chart for accuracy/f1 metrics */} + {(() => { + const accKeys = expMetricKeys.filter( + (k) => k.includes("accuracy") || k.includes("acc") || k.includes("f1") + ); + if (accKeys.length === 0) return null; + const maxVal = Math.max(...accKeys.map((k) => expMetrics[0][k] ?? 0)); + return ( + + + Performance Summary + + +
+ {accKeys.map((key, i) => { + const val = expMetrics[0][key] ?? 0; + const pct = maxVal > 0 ? (val / maxVal) * 100 : 0; + return ( + +
+ + {key.replace(/_/g, " ")} + + + {(val * 100).toFixed(2)}% + +
+
+ +
+
+ ); + })} +
+ + {/* Parameters if available */} + {experimentRuns[0]?.parameters && Object.keys(experimentRuns[0].parameters).length > 0 && ( +
+

Hyperparameters

+
+ {Object.entries(experimentRuns[0].parameters as Record).map(([k, v]) => ( +
+ {k} +

{String(v)}

+
+ ))} +
+
+ )} +
+
+ ); + })()} + + )} + + {/* Time-series job metrics (from K8s training) */} + {hasJobMetrics && ( +
+ {metricNames.map((name) => { + const color = name.toLowerCase().includes("loss") + ? "#ef4444" + : name.toLowerCase().includes("acc") + ? "#10b981" + : "#d4d4d4"; + return ( + + + {name.replace(/_/g, " ")} + + + + + + ); + })} +
+ )} +
+ ); + })()}
diff --git a/web/src/app/settings/page.tsx b/web/src/app/settings/page.tsx index d236d8c..feafffe 100644 --- a/web/src/app/settings/page.tsx +++ b/web/src/app/settings/page.tsx @@ -94,6 +94,7 @@ export default function SettingsPage() { const [copiedKey, setCopiedKey] = useState(null); const [genKeyOpen, setGenKeyOpen] = useState(false); const [genKeyName, setGenKeyName] = useState(""); + const [generatedKey, setGeneratedKey] = useState(null); const [generating, setGenerating] = useState(false); const [profileName, setProfileName] = useState(""); const [profileEmail, setProfileEmail] = useState(""); @@ -138,8 +139,24 @@ export default function SettingsPage() { } catch {} }, [user]); - const copyKey = (id: string) => { - setCopiedKey(id); + const copyToClipboard = async (text: string, id?: string) => { + try { + await navigator.clipboard.writeText(text); + setCopiedKey(id ?? text); + toast.success("Copied to clipboard"); + } catch { + // Fallback for non-HTTPS / older browsers + const ta = document.createElement("textarea"); + ta.value = text; + ta.style.position = "fixed"; + ta.style.opacity = "0"; + document.body.appendChild(ta); + ta.select(); + document.execCommand("copy"); + document.body.removeChild(ta); + setCopiedKey(id ?? text); + toast.success("Copied to clipboard"); + } setTimeout(() => setCopiedKey(null), 2000); }; @@ -149,9 +166,7 @@ export default function SettingsPage() { try { // eslint-disable-next-line @typescript-eslint/no-explicit-any const res = await api.post("/api-keys", { name: genKeyName.trim() }); - toast.success(`Key created: ${res.key}`); - setGenKeyOpen(false); - setGenKeyName(""); + setGeneratedKey(res.key); fetchKeys(); } catch (err) { toast.error(err instanceof Error ? err.message : "Failed to generate key"); @@ -160,6 +175,12 @@ export default function SettingsPage() { } }; + const closeKeyDialog = () => { + setGenKeyOpen(false); + setGenKeyName(""); + setGeneratedKey(null); + }; + const handleDeleteKey = async (id: string) => { try { await api.delete(`/api-keys/${id}`); @@ -319,21 +340,43 @@ export default function SettingsPage() {

API Keys

Manage your API access tokens

- + { if (!open) closeKeyDialog(); else setGenKeyOpen(true); }}> - Generate API Key -
-
- - setGenKeyName(e.target.value)} className="border bg-muted" /> + {generatedKey ? "API Key Created" : "Generate API Key"} + {generatedKey ? ( +
+
+

Copy your key now. You won't be able to see it again.

+
+ + {generatedKey} + + +
+
+
- -
+ ) : ( +
+
+ + setGenKeyName(e.target.value)} className="border bg-muted" /> +
+ +
+ )}
@@ -352,11 +395,6 @@ export default function SettingsPage() {
Last used {k.lastUsed} - - -
diff --git a/web/src/app/visualizations/[id]/page.tsx b/web/src/app/visualizations/[id]/page.tsx index 9b39470..02dbe2f 100644 --- a/web/src/app/visualizations/[id]/page.tsx +++ b/web/src/app/visualizations/[id]/page.tsx @@ -1,13 +1,13 @@ "use client"; -import { useState, useEffect, useCallback } from "react"; +import { useState, useEffect, useCallback, useRef } from "react"; import { useParams, useRouter } from "next/navigation"; import dynamic from "next/dynamic"; import { AppShell } from "@/components/layout/app-shell"; import { AnimatedPage } from "@/components/shared/animated-page"; import { GlassCard } from "@/components/shared/glass-card"; import { ErrorState } from "@/components/shared/error-state"; -import { VizRenderer } from "@/components/shared/viz-renderer"; +import { VizRenderer, downloadVisualization } from "@/components/shared/viz-renderer"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -25,6 +25,7 @@ import { BarChart3, Loader2, Check, + Download, } from "lucide-react"; import { api } from "@/lib/api"; import { toast } from "sonner"; @@ -311,7 +312,9 @@ export default function VisualizationDetailPage() { const [description, setDescription] = useState(""); const [refreshInterval, setRefreshInterval] = useState("0"); const [showPreview, setShowPreview] = useState(true); + const [showCode, setShowCode] = useState(true); const [activeTab, setActiveTab] = useState("code"); + const previewContainerRef = useRef(null); // Live preview state for interactive backends (plotly, vega-lite) const [previewOutput, setPreviewOutput] = useState(null); @@ -541,7 +544,29 @@ export default function VisualizationDetailPage() { variant="outline" size="sm" className="gap-1.5 border text-xs" - onClick={() => setShowPreview(!showPreview)} + onClick={() => { + if (showCode && !showPreview) { + setShowPreview(true); + } + setShowCode(!showCode); + }} + > + + {showCode ? "Hide Code" : "Show Code"} + + + + + + + + + {!viz.published && ( {/* Left: Code Editor */} + {showCode && ( @@ -760,6 +809,7 @@ export default function VisualizationDetailPage() { + )} {/* Right: Preview */} {showPreview && ( @@ -794,7 +844,7 @@ export default function VisualizationDetailPage() { )} -
+
{ + if (blob) triggerDownload(blob, `${filename}.png`); + }); + } + } +} + // ── CDN Script Loading ───────────────────────────────────────────── // // All UMD libraries (Plotly, Vega, Bokeh) have the same problem: diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index ea3493f..51e784e 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -2,10 +2,13 @@ const API_BASE = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8080"; class ApiClient { private token: string | null = null; + private refreshToken: string | null = null; + private refreshing: Promise | null = null; constructor() { if (typeof window !== "undefined") { this.token = localStorage.getItem("auth_token"); + this.refreshToken = localStorage.getItem("refresh_token"); } } @@ -16,10 +19,19 @@ class ApiClient { } } + setRefreshToken(token: string) { + this.refreshToken = token; + if (typeof window !== "undefined") { + localStorage.setItem("refresh_token", token); + } + } + clearToken() { this.token = null; + this.refreshToken = null; if (typeof window !== "undefined") { localStorage.removeItem("auth_token"); + localStorage.removeItem("refresh_token"); } } @@ -27,6 +39,24 @@ class ApiClient { return this.token; } + private async tryRefresh(): Promise { + if (!this.refreshToken) return false; + try { + const res = await fetch(`${API_BASE}/auth/refresh`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ refresh_token: this.refreshToken }), + }); + if (!res.ok) return false; + const data = await res.json(); + this.setToken(data.access_token); + if (data.refresh_token) this.setRefreshToken(data.refresh_token); + return true; + } catch { + return false; + } + } + private async request( path: string, options: RequestInit = {} @@ -40,6 +70,17 @@ class ApiClient { } const res = await fetch(`${API_BASE}${path}`, { ...options, headers }); if (res.status === 401 && !path.startsWith("/auth/")) { + // Try to refresh the token once + if (!this.refreshing) { + this.refreshing = this.tryRefresh().finally(() => { this.refreshing = null; }); + } + const refreshed = await this.refreshing; + if (refreshed) { + // Retry the original request with the new token + headers["Authorization"] = `Bearer ${this.token}`; + const retry = await fetch(`${API_BASE}${path}`, { ...options, headers }); + if (retry.ok) return retry.json(); + } this.clearToken(); if (typeof window !== "undefined") window.location.href = "/login"; throw new Error("Unauthorized"); diff --git a/web/src/lib/auth.ts b/web/src/lib/auth.ts index 01fa312..a3faaa1 100644 --- a/web/src/lib/auth.ts +++ b/web/src/lib/auth.ts @@ -21,6 +21,7 @@ export async function login( ): Promise { const res = await api.post("/auth/login", { email, password }); api.setToken(res.access_token); + api.setRefreshToken(res.refresh_token); return res; } @@ -36,6 +37,7 @@ export async function register(data: { name: data.display_name, }); api.setToken(res.access_token); + api.setRefreshToken(res.refresh_token); return res; }