diff --git a/crates/paimon/Cargo.toml b/crates/paimon/Cargo.toml index d2768b4..d9ecce8 100644 --- a/crates/paimon/Cargo.toml +++ b/crates/paimon/Cargo.toml @@ -58,7 +58,17 @@ arrow-array = { workspace = true } futures = "0.3" parquet = { workspace = true, features = ["async", "zstd"] } async-stream = "0.3.6" -reqwest = { version = "0.12", features = ["json"] } +reqwest = { version = "0.12", features = ["json", "blocking"] } +# DLF authentication dependencies +base64 = "0.22" +hex = "0.4" +hmac = "0.12" +sha1 = "0.10" +sha2 = "0.10" +md-5 = "0.10" +regex = "1" +uuid = { version = "1", features = ["v4"] } +urlencoding = "2.1" [dev-dependencies] axum = { version = "0.7", features = ["macros", "tokio", "http1", "http2"] } diff --git a/crates/paimon/examples/rest_list_databases_example.rs b/crates/paimon/examples/rest_list_databases_example.rs deleted file mode 100644 index 2a5e435..0000000 --- a/crates/paimon/examples/rest_list_databases_example.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Example: List databases using RESTApi -//! -//! This example demonstrates how to create a RESTApi instance -//! and call the list_databases() API to retrieve all databases. -//! -//! 
# Usage -//! ```bash -//! cargo run -p paimon --example list_databases_example -//! ``` - -use paimon::api::rest_api::RESTApi; -use paimon::common::{CatalogOptions, Options}; - -#[tokio::main] -async fn main() { - // Create configuration options - let mut options = Options::new(); - - // Basic configuration - replace with your actual server URL - options.set(CatalogOptions::METASTORE, "rest"); - options.set(CatalogOptions::WAREHOUSE, "your_warehouse"); - options.set(CatalogOptions::URI, "http://localhost:8080/"); - - // Bearer token authentication (optional) - // options.set(CatalogOptions::TOKEN_PROVIDER, "bear"); - // options.set(CatalogOptions::TOKEN, "your_token"); - - // Create RESTApi instance - // config_required = true means it will fetch config from server - println!("Creating RESTApi instance..."); - let api = match RESTApi::new(options, true).await { - Ok(api) => api, - Err(e) => { - eprintln!("Failed to create RESTApi: {e}"); - return; - } - }; - - // Call list_databases() API - println!("Calling list_databases()..."); - match api.list_databases().await { - Ok(databases) => { - println!("Databases found: {databases:?}"); - println!("Total count: {}", databases.len()); - } - Err(e) => { - eprintln!("Failed to list databases: {e}"); - } - } -} diff --git a/crates/paimon/src/api/api_request.rs b/crates/paimon/src/api/api_request.rs new file mode 100644 index 0000000..33a52fd --- /dev/null +++ b/crates/paimon/src/api/api_request.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! REST API request types for Paimon. +//! +//! This module contains all request structures used in REST API calls. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::{catalog::Identifier, spec::Schema}; + +/// Request to create a new database. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateDatabaseRequest { + /// The name of the database to create. + pub name: String, + /// Optional configuration options for the database. + pub options: HashMap, +} + +impl CreateDatabaseRequest { + /// Create a new CreateDatabaseRequest. + pub fn new(name: String, options: HashMap) -> Self { + Self { name, options } + } +} + +/// Request to alter a database's configuration. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AlterDatabaseRequest { + /// Keys to remove from the database options. + pub removals: Vec, + /// Key-value pairs to update in the database options. + pub updates: HashMap, +} + +impl AlterDatabaseRequest { + /// Create a new AlterDatabaseRequest. + pub fn new(removals: Vec, updates: HashMap) -> Self { + Self { removals, updates } + } +} + +/// Request to rename a table. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RenameTableRequest { + /// The source table identifier. + pub source: Identifier, + /// The destination table identifier. 
+ pub destination: Identifier, +} + +impl RenameTableRequest { + /// Create a new RenameTableRequest. + pub fn new(source: Identifier, destination: Identifier) -> Self { + Self { + source, + destination, + } + } +} + +/// Request to create a new table. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateTableRequest { + /// The identifier for the table to create. + pub identifier: Identifier, + /// The schema definition for the table. + pub schema: Schema, +} + +impl CreateTableRequest { + /// Create a new CreateTableRequest. + pub fn new(identifier: Identifier, schema: Schema) -> Self { + Self { identifier, schema } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_database_request_serialization() { + let mut options = HashMap::new(); + options.insert("key".to_string(), "value".to_string()); + let req = CreateDatabaseRequest::new("test_db".to_string(), options); + + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"name\":\"test_db\"")); + assert!(json.contains("\"options\"")); + } + + #[test] + fn test_alter_database_request_serialization() { + let mut updates = HashMap::new(); + updates.insert("key".to_string(), "new_value".to_string()); + let req = AlterDatabaseRequest::new(vec!["old_key".to_string()], updates); + + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"removals\":[\"old_key\"]")); + assert!(json.contains("\"updates\"")); + } + + #[test] + fn test_rename_table_request_serialization() { + let source = Identifier::new("db1".to_string(), "table1".to_string()); + let destination = Identifier::new("db2".to_string(), "table2".to_string()); + let req = RenameTableRequest::new(source, destination); + + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"source\"")); + assert!(json.contains("\"destination\"")); + } +} diff --git a/crates/paimon/src/api/api_response.rs 
b/crates/paimon/src/api/api_response.rs index 83296c9..e282080 100644 --- a/crates/paimon/src/api/api_response.rs +++ b/crates/paimon/src/api/api_response.rs @@ -22,8 +22,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; -/// Base trait for REST responses. -pub trait RESTResponse {} +use crate::spec::Schema; /// Error response from REST API calls. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -39,7 +38,6 @@ pub struct ErrorResponse { pub code: Option, } -impl RESTResponse for ErrorResponse {} impl ErrorResponse { /// Create a new ErrorResponse. pub fn new( @@ -57,6 +55,141 @@ impl ErrorResponse { } } +/// Base response containing audit information. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AuditRESTResponse { + /// The owner of the resource. + pub owner: Option, + /// Timestamp when the resource was created. + pub created_at: Option, + /// User who created the resource. + pub created_by: Option, + /// Timestamp when the resource was last updated. + pub updated_at: Option, + /// User who last updated the resource. + pub updated_by: Option, +} + +impl AuditRESTResponse { + /// Create a new AuditRESTResponse. + pub fn new( + owner: Option, + created_at: Option, + created_by: Option, + updated_at: Option, + updated_by: Option, + ) -> Self { + Self { + owner, + created_at, + created_by, + updated_at, + updated_by, + } + } + + /// Put audit options into the provided dictionary. 
+ pub fn put_audit_options_to(&self, options: &mut HashMap) { + if let Some(owner) = &self.owner { + options.insert("owner".to_string(), owner.clone()); + } + if let Some(created_by) = &self.created_by { + options.insert("createdBy".to_string(), created_by.clone()); + } + if let Some(created_at) = self.created_at { + options.insert("createdAt".to_string(), created_at.to_string()); + } + if let Some(updated_by) = &self.updated_by { + options.insert("updatedBy".to_string(), updated_by.clone()); + } + if let Some(updated_at) = self.updated_at { + options.insert("updatedAt".to_string(), updated_at.to_string()); + } + } +} + +/// Response for getting a table. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetTableResponse { + /// Audit information. + #[serde(flatten)] + pub audit: AuditRESTResponse, + /// The unique identifier of the table. + pub id: Option, + /// The name of the table. + pub name: Option, + /// The path to the table. + pub path: Option, + /// Whether the table is external. + pub is_external: Option, + /// The schema ID of the table. + pub schema_id: Option, + /// The schema of the table. + pub schema: Option, +} + +impl GetTableResponse { + /// Create a new GetTableResponse. + #[allow(clippy::too_many_arguments)] + pub fn new( + id: Option, + name: Option, + path: Option, + is_external: Option, + schema_id: Option, + schema: Option, + audit: AuditRESTResponse, + ) -> Self { + Self { + audit, + id, + name, + path, + is_external, + schema_id, + schema, + } + } +} + +/// Response for getting a database. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetDatabaseResponse { + /// Audit information. + #[serde(flatten)] + pub audit: AuditRESTResponse, + /// The unique identifier of the database. + pub id: Option, + /// The name of the database. + pub name: Option, + /// The location of the database. 
+ pub location: Option, + /// Configuration options for the database. + pub options: HashMap, +} + +impl GetDatabaseResponse { + /// Create a new GetDatabaseResponse. + pub fn new( + id: Option, + name: Option, + location: Option, + options: HashMap, + audit: AuditRESTResponse, + ) -> Self { + Self { + audit, + id, + name, + location, + options, + } + } +} + /// Response containing configuration defaults. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -65,8 +198,6 @@ pub struct ConfigResponse { pub defaults: HashMap, } -impl RESTResponse for ConfigResponse {} - impl ConfigResponse { /// Create a new ConfigResponse. pub fn new(defaults: HashMap) -> Self { @@ -97,8 +228,6 @@ pub struct ListDatabasesResponse { pub next_page_token: Option, } -impl RESTResponse for ListDatabasesResponse {} - impl ListDatabasesResponse { /// Create a new ListDatabasesResponse. pub fn new(databases: Vec, next_page_token: Option) -> Self { @@ -109,6 +238,26 @@ impl ListDatabasesResponse { } } +/// Response for listing tables. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListTablesResponse { + /// List of table names. + pub tables: Option>, + /// Token for the next page. + pub next_page_token: Option, +} + +impl ListTablesResponse { + /// Create a new ListTablesResponse. + pub fn new(tables: Option>, next_page_token: Option) -> Self { + Self { + tables, + next_page_token, + } + } +} + /// A paginated list of elements with an optional next page token. 
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -128,7 +277,6 @@ impl PagedList { } } } - #[cfg(test)] mod tests { use super::*; @@ -160,4 +308,24 @@ mod tests { assert!(json.contains("\"databases\":[\"db1\",\"db2\"]")); assert!(json.contains("\"nextPageToken\":\"token123\"")); } + + #[test] + fn test_audit_response_options() { + let audit = AuditRESTResponse::new( + Some("owner1".to_string()), + Some(1000), + Some("creator".to_string()), + Some(2000), + Some("updater".to_string()), + ); + + let mut options = HashMap::new(); + audit.put_audit_options_to(&mut options); + + assert_eq!(options.get("owner"), Some(&"owner1".to_string())); + assert_eq!(options.get("createdBy"), Some(&"creator".to_string())); + assert_eq!(options.get("createdAt"), Some(&"1000".to_string())); + assert_eq!(options.get("updatedBy"), Some(&"updater".to_string())); + assert_eq!(options.get("updatedAt"), Some(&"2000".to_string())); + } } diff --git a/crates/paimon/src/api/auth/base.rs b/crates/paimon/src/api/auth/base.rs index 77e9f76..7b0ae33 100644 --- a/crates/paimon/src/api/auth/base.rs +++ b/crates/paimon/src/api/auth/base.rs @@ -19,6 +19,10 @@ use std::collections::HashMap; +use async_trait::async_trait; + +use crate::Result; + /// Parameter for REST authentication. /// /// Contains information about the request being authenticated. @@ -70,7 +74,8 @@ impl RESTAuthParameter { /// /// Implement this trait to provide custom authentication mechanisms /// for REST API requests. -pub trait AuthProvider { +#[async_trait] +pub trait AuthProvider: Send { /// Merge authentication headers into the base headers. 
/// /// # Arguments @@ -78,13 +83,14 @@ pub trait AuthProvider { /// * `parameter` - Information about the request being authenticated /// /// # Returns - fn merge_auth_header( - &self, + async fn merge_auth_header( + &mut self, base_header: HashMap, parameter: &RESTAuthParameter, - ) -> HashMap; + ) -> Result>; } - +/// Authorization header key. +pub const AUTHORIZATION_HEADER_KEY: &str = "Authorization"; /// Function wrapper for REST authentication. /// /// This struct combines an initial set of headers with an authentication provider @@ -114,8 +120,12 @@ impl RESTAuthFunction { /// /// # Returns /// A HashMap containing the authenticated headers. - pub fn apply(&self, parameter: &RESTAuthParameter) -> HashMap { + pub async fn apply( + &mut self, + parameter: &RESTAuthParameter, + ) -> Result> { self.auth_provider .merge_auth_header(self.init_header.clone(), parameter) + .await } } diff --git a/crates/paimon/src/api/auth/bear_provider.rs b/crates/paimon/src/api/auth/bearer_provider.rs similarity index 77% rename from crates/paimon/src/api/auth/bear_provider.rs rename to crates/paimon/src/api/auth/bearer_provider.rs index 96bdfe9..2b9f0e1 100644 --- a/crates/paimon/src/api/auth/bear_provider.rs +++ b/crates/paimon/src/api/auth/bearer_provider.rs @@ -19,6 +19,8 @@ use std::collections::HashMap; +use async_trait::async_trait; + use super::base::{AuthProvider, RESTAuthParameter}; /// Authentication provider using Bearer token. 
@@ -41,17 +43,18 @@ impl BearerTokenAuthProvider { } } +#[async_trait] impl AuthProvider for BearerTokenAuthProvider { - fn merge_auth_header( - &self, + async fn merge_auth_header( + &mut self, mut base_header: HashMap, _parameter: &RESTAuthParameter, - ) -> HashMap { + ) -> crate::Result> { base_header.insert( "Authorization".to_string(), format!("Bearer {}", self.token), ); - base_header + Ok(base_header) } } @@ -59,13 +62,16 @@ impl AuthProvider for BearerTokenAuthProvider { mod tests { use super::*; - #[test] - fn test_bearer_token_auth() { - let provider = BearerTokenAuthProvider::new("test-token"); + #[tokio::test] + async fn test_bearer_token_auth() { + let mut provider = BearerTokenAuthProvider::new("test-token"); let base_header = HashMap::new(); let parameter = RESTAuthParameter::for_get("/test", HashMap::new()); - let headers = provider.merge_auth_header(base_header, ¶meter); + let headers = provider + .merge_auth_header(base_header, ¶meter) + .await + .unwrap(); assert_eq!( headers.get("Authorization"), @@ -73,14 +79,17 @@ mod tests { ); } - #[test] - fn test_bearer_token_with_base_headers() { - let provider = BearerTokenAuthProvider::new("my-token"); + #[tokio::test] + async fn test_bearer_token_with_base_headers() { + let mut provider = BearerTokenAuthProvider::new("my-token"); let mut base_header = HashMap::new(); base_header.insert("Content-Type".to_string(), "application/json".to_string()); let parameter = RESTAuthParameter::for_get("/test", HashMap::new()); - let headers = provider.merge_auth_header(base_header, ¶meter); + let headers = provider + .merge_auth_header(base_header, ¶meter) + .await + .unwrap(); assert_eq!( headers.get("Authorization"), diff --git a/crates/paimon/src/api/auth/dlf_provider.rs b/crates/paimon/src/api/auth/dlf_provider.rs new file mode 100644 index 0000000..8655bd2 --- /dev/null +++ b/crates/paimon/src/api/auth/dlf_provider.rs @@ -0,0 +1,474 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DLF Authentication Provider for Alibaba Cloud Data Lake Formation. + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use chrono::Utc; +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +use super::base::{AuthProvider, RESTAuthParameter, AUTHORIZATION_HEADER_KEY}; +use super::dlf_signer::{DLFRequestSigner, DLFSignerFactory}; +use crate::common::{CatalogOptions, Options}; +use crate::error::Error; +use crate::Result; + +// ============================================================================ +// DLF Token and Token Loader +// ============================================================================ + +/// DLF Token containing access credentials for Alibaba Cloud Data Lake Formation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DLFToken { + /// Access key ID for Alibaba Cloud. + #[serde(rename = "AccessKeyId")] + pub access_key_id: String, + /// Access key secret for Alibaba Cloud. + #[serde(rename = "AccessKeySecret")] + pub access_key_secret: String, + /// Security token for temporary credentials (optional). + #[serde(rename = "SecurityToken")] + pub security_token: Option, + /// Expiration time string (ISO 8601 format). 
+ #[serde( + rename = "Expiration", + default, + skip_serializing_if = "Option::is_none" + )] + pub expiration: Option, + /// Expiration timestamp in milliseconds. + #[serde(rename = "ExpirationAt", default)] + pub expiration_at_millis: Option, +} + +impl DLFToken { + /// Token date format for parsing expiration. + const TOKEN_DATE_FORMAT: &'static str = "%Y-%m-%dT%H:%M:%SZ"; + + /// Create a new DLFToken. + /// + /// # Arguments + /// * `access_key_id` - The access key ID + /// * `access_key_secret` - The access key secret + /// * `security_token` - Optional security token + /// * `expiration` - Optional expiration time string (ISO 8601 format) + /// * `expiration_at_millis` - Optional expiration timestamp in milliseconds. + /// If provided, this value is used directly. Otherwise, it will be parsed from `expiration`. + pub fn new( + access_key_id: impl Into, + access_key_secret: impl Into, + security_token: Option, + expiration_at_millis: Option, + expiration: Option, + ) -> Self { + let access_key_id = access_key_id.into(); + let access_key_secret = access_key_secret.into(); + + // Use provided expiration_at_millis, or parse from expiration string if not provided + let expiration_at_millis = expiration_at_millis.or_else(|| { + expiration + .as_ref() + .and_then(|exp| Self::parse_expiration_to_millis(exp)) + }); + + Self { + access_key_id, + access_key_secret, + security_token, + expiration_at_millis, + expiration, + } + } + + /// Create a DLFToken from configuration options. + pub fn from_options(options: &Options) -> Option { + let access_key_id = options.get(CatalogOptions::DLF_ACCESS_KEY_ID)?; + let access_key_secret = options.get(CatalogOptions::DLF_ACCESS_KEY_SECRET)?; + let security_token = options + .get(CatalogOptions::DLF_ACCESS_SECURITY_TOKEN) + .cloned(); + + Some(Self::new( + access_key_id.clone(), + access_key_secret.clone(), + security_token, + None, + None, + )) + } + + /// Parse expiration string to milliseconds timestamp. 
+ pub fn parse_expiration_to_millis(expiration: &str) -> Option { + let datetime = chrono::NaiveDateTime::parse_from_str(expiration, Self::TOKEN_DATE_FORMAT) + .ok()? + .and_utc(); + Some(datetime.timestamp_millis()) + } +} +/// Trait for DLF token loaders. +#[async_trait] +pub trait DLFTokenLoader: Send + Sync { + /// Load a DLF token. + async fn load_token(&self) -> Result; + + /// Get a description of the loader. + fn description(&self) -> &str; +} + +/// DLF ECS Token Loader. +/// +/// Loads DLF tokens from ECS metadata service. +/// +/// This implementation mirrors the Python DLFECSTokenLoader class, +/// using class-level HTTP client for connection reuse and retry logic. +pub struct DLFECSTokenLoader { + ecs_metadata_url: String, + role_name: Option, + http_client: TokenHTTPClient, +} + +impl DLFECSTokenLoader { + /// Create a new DLFECSTokenLoader. + /// + /// # Arguments + /// * `ecs_metadata_url` - ECS metadata service URL + /// * `role_name` - Optional role name. If None, will be fetched from metadata service + pub fn new(ecs_metadata_url: impl Into, role_name: Option) -> Self { + Self { + ecs_metadata_url: ecs_metadata_url.into(), + role_name, + http_client: TokenHTTPClient::new(), + } + } + + /// Get the role name from ECS metadata service. + async fn get_role(&self) -> Result { + self.http_client.get(&self.ecs_metadata_url).await + } + + /// Get the token from ECS metadata service. + async fn get_token(&self, url: &str) -> Result { + let token_json = self.http_client.get(url).await?; + serde_json::from_str(&token_json).map_err(|e| Error::DataInvalid { + message: format!("Failed to parse token JSON: {}", e), + source: None, + }) + } + + /// Build the token URL from base URL and role name. 
+ fn build_token_url(&self, role_name: &str) -> String { + let base_url = self.ecs_metadata_url.trim_end_matches('/'); + format!("{}/{}", base_url, role_name) + } +} + +#[async_trait] +impl DLFTokenLoader for DLFECSTokenLoader { + async fn load_token(&self) -> Result { + let role_name = match &self.role_name { + Some(name) => name.clone(), + None => { + // Fetch role name from metadata service + self.get_role().await? + } + }; + + // Build token URL + let token_url = self.build_token_url(&role_name); + + // Get token + self.get_token(&token_url).await + } + + fn description(&self) -> &str { + &self.ecs_metadata_url + } +} +/// Factory for creating DLF token loaders. +pub struct DLFTokenLoaderFactory; + +impl DLFTokenLoaderFactory { + /// Create a token loader based on options. + pub fn create_token_loader(options: &Options) -> Option> { + let loader = options.get(CatalogOptions::DLF_TOKEN_LOADER)?; + + if loader == "ecs" { + let ecs_metadata_url = options + .get(CatalogOptions::DLF_TOKEN_ECS_METADATA_URL) + .cloned() + .unwrap_or_else(|| { + "http://100.100.100.200/latest/meta-data/Ram/security-credentials/".to_string() + }); + let role_name = options + .get(CatalogOptions::DLF_TOKEN_ECS_ROLE_NAME) + .cloned(); + Some( + Arc::new(DLFECSTokenLoader::new(ecs_metadata_url, role_name)) + as Arc, + ) + } else { + None + } + } +} +// ============================================================================ +// DLF Auth Provider +// ============================================================================ + +/// Token expiration safe time in milliseconds (1 hour). +/// Token will be refreshed if it expires within this time. +const TOKEN_EXPIRATION_SAFE_TIME_MILLIS: i64 = 3_600_000; + +/// DLF Authentication Provider for Alibaba Cloud Data Lake Formation. +/// +/// This provider implements authentication for Alibaba Cloud DLF service, +/// supporting both VPC endpoints (DLF4-HMAC-SHA256) and public endpoints +/// (ROA v2 HMAC-SHA1). 
+pub struct DLFAuthProvider { + uri: String, + token: Option, + token_loader: Option>, + signer: Box, +} + +impl DLFAuthProvider { + /// Create a new DLFAuthProvider. + /// + /// # Arguments + /// * `uri` - The DLF service URI + /// * `token` - Optional DLF token containing access credentials + /// * `token_loader` - Optional token loader for dynamic token retrieval + /// + /// # Errors + /// Returns an error if both `token` and `token_loader` are `None`. + pub fn new( + uri: impl Into, + region: impl Into, + signing_algorithm: impl Into, + token: Option, + token_loader: Option>, + ) -> Result { + if token.is_none() && token_loader.is_none() { + return Err(Error::ConfigInvalid { + message: "Either token or token_loader must be provided".to_string(), + }); + } + + let uri = uri.into(); + let region = region.into(); + let signing_algorithm = signing_algorithm.into(); + let signer = DLFSignerFactory::create_signer(&signing_algorithm, ®ion); + + Ok(Self { + uri, + token, + token_loader, + signer, + }) + } + + /// Get or refresh the token. + /// + /// If token_loader is configured, this method will: + /// - Load a new token if current token is None + /// - Refresh the token if it's about to expire (within TOKEN_EXPIRATION_SAFE_TIME_MILLIS) + async fn get_or_refresh_token(&mut self) -> Result { + if let Some(loader) = &self.token_loader { + let need_reload = match &self.token { + None => true, + Some(token) => match token.expiration_at_millis { + Some(expiration_at_millis) => { + let now = chrono::Utc::now().timestamp_millis(); + expiration_at_millis - now < TOKEN_EXPIRATION_SAFE_TIME_MILLIS + } + None => false, + }, + }; + + if need_reload { + let new_token = loader.load_token().await?; + self.token = Some(new_token); + } + } + + self.token.clone().ok_or_else(|| Error::DataInvalid { + message: "Either token or token_loader must be provided".to_string(), + source: None, + }) + } + + /// Extract host from URI. 
+ fn extract_host(uri: &str) -> String { + let without_protocol = uri + .strip_prefix("https://") + .or_else(|| uri.strip_prefix("http://")) + .unwrap_or(uri); + + let path_index = without_protocol.find('/').unwrap_or(without_protocol.len()); + without_protocol[..path_index].to_string() + } +} + +#[async_trait] +impl AuthProvider for DLFAuthProvider { + async fn merge_auth_header( + &mut self, + mut base_header: HashMap, + rest_auth_parameter: &RESTAuthParameter, + ) -> crate::Result> { + // Get token (will auto-refresh if needed via token_loader) + let token = self.get_or_refresh_token().await?; + + let now = Utc::now(); + let host = Self::extract_host(&self.uri); + + // Generate signature headers + let sign_headers = self.signer.sign_headers( + rest_auth_parameter.data.as_deref(), + &now, + token.security_token.as_deref(), + &host, + ); + + // Generate authorization header + let authorization = + self.signer + .authorization(rest_auth_parameter, &token, &host, &sign_headers); + + // Merge all headers + base_header.extend(sign_headers); + base_header.insert(AUTHORIZATION_HEADER_KEY.to_string(), authorization); + + Ok(base_header) + } +} + +// ============================================================================ +// DLF Token Loader Implementation +// ============================================================================ + +/// HTTP client for token loading with retry and timeout configuration. +struct TokenHTTPClient { + max_retries: u32, + client: Client, +} + +impl TokenHTTPClient { + /// Create a new HTTP client with default settings. + fn new() -> Self { + let connect_timeout = std::time::Duration::from_secs(180); // 3 minutes + let read_timeout = std::time::Duration::from_secs(180); // 3 minutes + + let client = Client::builder() + .timeout(read_timeout) + .connect_timeout(connect_timeout) + .build() + .expect("Failed to create HTTP client"); + + Self { + max_retries: 3, + client, + } + } + + /// Perform HTTP GET request with retry logic. 
+ async fn get(&self, url: &str) -> Result { + let mut last_error = String::new(); + for attempt in 0..self.max_retries { + match self.client.get(url).send().await { + Ok(response) if response.status().is_success() => { + return response.text().await.map_err(|e| Error::DataInvalid { + message: format!("Failed to read response: {}", e), + source: None, + }); + } + Ok(response) => { + last_error = format!("HTTP error: {}", response.status()); + } + Err(e) => { + last_error = format!("Request failed: {}", e); + } + } + + if attempt < self.max_retries - 1 { + // Exponential backoff + let delay = std::time::Duration::from_millis(100 * 2u64.pow(attempt)); + tokio::time::sleep(delay).await; + } + } + + Err(Error::DataInvalid { + message: last_error, + source: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_host() { + let uri = "http://dlf-abcdfgerrf.net/api/v1"; + let host = DLFAuthProvider::extract_host(uri); + assert_eq!(host, "dlf-abcdfgerrf.net"); + } + + #[test] + fn test_extract_host_no_path() { + let uri = "https://dlf.cn-abcdfgerrf.aliyuncs.com"; + let host = DLFAuthProvider::extract_host(uri); + assert_eq!(host, "dlf.cn-abcdfgerrf.aliyuncs.com"); + } + + #[test] + fn test_dlf_token_from_options() { + let mut options = Options::new(); + options.set(CatalogOptions::DLF_ACCESS_KEY_ID, "test_key_id"); + options.set(CatalogOptions::DLF_ACCESS_KEY_SECRET, "test_key_secret"); + options.set( + CatalogOptions::DLF_ACCESS_SECURITY_TOKEN, + "test_security_token", + ); + + let token = DLFToken::from_options(&options).unwrap(); + assert_eq!(token.access_key_id, "test_key_id"); + assert_eq!(token.access_key_secret, "test_key_secret"); + assert_eq!( + token.security_token, + Some("test_security_token".to_string()) + ); + } + + #[test] + fn test_dlf_token_missing_credentials() { + let options = Options::new(); + assert!(DLFToken::from_options(&options).is_none()); + } + + #[test] + fn test_parse_expiration() { + let expiration = 
"2024-12-31T23:59:59Z"; + let millis = DLFToken::parse_expiration_to_millis(expiration); + assert!(millis.is_some()); + } +} diff --git a/crates/paimon/src/api/auth/dlf_signer.rs b/crates/paimon/src/api/auth/dlf_signer.rs new file mode 100644 index 0000000..a97aac8 --- /dev/null +++ b/crates/paimon/src/api/auth/dlf_signer.rs @@ -0,0 +1,665 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DLF Request Signer implementations for Alibaba Cloud Data Lake Formation. +//! +//! This module provides two signature algorithms for authenticating requests +//! to Alibaba Cloud Data Lake Formation (DLF) service: +//! +//! # Signature Algorithms +//! +//! ## 1. DLF4-HMAC-SHA256 (Default Signer) +//! +//! Used for VPC endpoints (e.g., `*-vpc.dlf.aliyuncs.com`). +//! +//! **Algorithm Overview:** +//! 1. Build a canonical request string from HTTP method, path, query params, and headers +//! 2. Create a string-to-sign with algorithm, timestamp, credential scope, and hashed canonical request +//! 3. Derive a signing key through multiple HMAC-SHA256 operations +//! 4. Calculate the signature and construct the Authorization header +//! +//! **Signing Key Derivation:** +//! ```text +//! kSecret = "aliyun_v4" + AccessKeySecret +//! 
kDate = HMAC-SHA256(kSecret, Date)
+//! kRegion = HMAC-SHA256(kDate, Region)
+//! kService = HMAC-SHA256(kRegion, "DlfNext")
+//! kSigning = HMAC-SHA256(kService, "aliyun_v4_request")
+//! ```
+//!
+//! **Authorization Header Format:**
+//! ```text
+//! DLF4-HMAC-SHA256 Credential=AccessKeyId/Date/Region/DlfNext/aliyun_v4_request,Signature=hex_signature
+//! ```
+//!
+//! ## 2. HMAC-SHA1 (OpenAPI Signer)
+//!
+//! Used for public network endpoints (e.g., `dlfnext.*.aliyuncs.com`).
+//! Follows Alibaba Cloud ROA v2 signature style.
+//!
+//! **Algorithm Overview:**
+//! 1. Build canonicalized headers from x-acs-* headers
+//! 2. Build canonicalized resource from path and query params
+//! 3. Create string-to-sign with method, headers, and resource
+//! 4. Calculate signature using HMAC-SHA1
+//!
+//! **Authorization Header Format:**
+//! ```text
+//! acs AccessKeyId:base64_signature
+//! ```
+//!
+//! # References
+//!
+//! - [Alibaba Cloud API Signature](https://help.aliyun.com/document_detail/315526.html)
+//! - [DLF OpenAPI](https://help.aliyun.com/document_detail/197826.html)
+//!
+//! # Usage
+//!
+//! The signer is automatically selected based on the endpoint URI:
+//! - VPC endpoints → `DLFDefaultSigner` (DLF4-HMAC-SHA256)
+//! - Public endpoints with "dlfnext" or "openapi" → `DLFOpenApiSigner` (HMAC-SHA1)
+
+use std::collections::HashMap;
+
+use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine};
+use chrono::{DateTime, Utc};
+use hmac::{Hmac, Mac};
+use md5::Md5;
+use sha1::Sha1;
+use sha2::{Digest, Sha256};
+use uuid::Uuid;
+
+use super::base::RESTAuthParameter;
+use super::dlf_provider::DLFToken;
+
+type HmacSha256 = Hmac<Sha256>;
+type HmacSha1 = Hmac<Sha1>;
+
+/// Trait for DLF request signers.
+///
+/// Different signers implement different signature algorithms for
+/// authenticating requests to Alibaba Cloud DLF service.
+///
+/// # Implementations
+///
+/// - [`DLFDefaultSigner`]: Uses DLF4-HMAC-SHA256 for VPC endpoints
+/// - [`DLFOpenApiSigner`]: Uses HMAC-SHA1 for public endpoints
+pub trait DLFRequestSigner: Send + Sync {
+    /// Generate signature headers for the request.
+    fn sign_headers(
+        &self,
+        body: Option<&str>,
+        now: &DateTime<Utc>,
+        security_token: Option<&str>,
+        host: &str,
+    ) -> HashMap<String, String>;
+
+    /// Generate the Authorization header value.
+    fn authorization(
+        &self,
+        rest_auth_parameter: &RESTAuthParameter,
+        token: &DLFToken,
+        host: &str,
+        sign_headers: &HashMap<String, String>,
+    ) -> String;
+    #[allow(dead_code)]
+    /// Get the identifier for this signer.
+    fn identifier(&self) -> &str;
+}
+
+/// Default DLF signer using DLF4-HMAC-SHA256 algorithm.
+///
+/// This signer is used for VPC endpoints (e.g., `cn-hangzhou-vpc.dlf.aliyuncs.com`).
+///
+/// # Algorithm Details
+///
+/// The DLF4-HMAC-SHA256 algorithm is similar to AWS Signature Version 4:
+///
+/// 1. **Canonical Request**: Combine HTTP method, URI, query string, headers, and payload hash
+/// 2. **String-to-Sign**: Include algorithm, timestamp, credential scope, and canonical request hash
+/// 3. **Signing Key**: Derived through chained HMAC operations
+/// 4. 
**Signature**: HMAC-SHA256 of string-to-sign using the signing key +/// +/// # Required Headers +/// +/// The following headers are included in the signature calculation: +/// - `content-md5`: MD5 hash of request body (if present) +/// - `content-type`: Media type (if body present) +/// - `x-dlf-content-sha256`: Always "UNSIGNED-PAYLOAD" +/// - `x-dlf-date`: Request timestamp in format `%Y%m%dT%H%M%SZ` +/// - `x-dlf-version`: API version ("v1") +/// - `x-dlf-security-token`: Security token for temporary credentials (optional) +/// +/// # Example +/// +/// ```ignore +/// use paimon::api::auth::{DLFDefaultSigner, DLFRequestSigner}; +/// +/// let signer = DLFDefaultSigner::new("cn-hangzhou"); +/// let headers = signer.sign_headers(Some(r#"{"key":"value"}"#), &Utc::now(), None, "dlf.aliyuncs.com"); +/// ``` +pub struct DLFDefaultSigner { + region: String, +} + +impl DLFDefaultSigner { + pub const IDENTIFIER: &'static str = "default"; + const VERSION: &'static str = "v1"; + const SIGNATURE_ALGORITHM: &'static str = "DLF4-HMAC-SHA256"; + const PRODUCT: &'static str = "DlfNext"; + const REQUEST_TYPE: &'static str = "aliyun_v4_request"; + const SIGNATURE_KEY: &'static str = "Signature"; + const NEW_LINE: &'static str = "\n"; + + // Header keys + const DLF_CONTENT_MD5_HEADER_KEY: &'static str = "Content-MD5"; + const DLF_CONTENT_TYPE_KEY: &'static str = "Content-Type"; + const DLF_DATE_HEADER_KEY: &'static str = "x-dlf-date"; + const DLF_SECURITY_TOKEN_HEADER_KEY: &'static str = "x-dlf-security-token"; + const DLF_AUTH_VERSION_HEADER_KEY: &'static str = "x-dlf-version"; + const DLF_CONTENT_SHA256_HEADER_KEY: &'static str = "x-dlf-content-sha256"; + const DLF_CONTENT_SHA256_VALUE: &'static str = "UNSIGNED-PAYLOAD"; + + const AUTH_DATE_TIME_FORMAT: &'static str = "%Y%m%dT%H%M%SZ"; + const MEDIA_TYPE: &'static str = "application/json"; + + const SIGNED_HEADERS: &'static [&'static str] = &[ + "content-md5", + "content-type", + "x-dlf-content-sha256", + "x-dlf-date", + 
"x-dlf-version",
+        "x-dlf-security-token",
+    ];
+
+    /// Create a new DLFDefaultSigner with the given region.
+    pub fn new(region: impl Into<String>) -> Self {
+        Self {
+            region: region.into(),
+        }
+    }
+
+    fn md5_base64(raw: &str) -> String {
+        let mut hasher = Md5::new();
+        hasher.update(raw.as_bytes());
+        let hash = hasher.finalize();
+        BASE64_STANDARD.encode(hash)
+    }
+
+    fn hmac_sha256(key: &[u8], data: &str) -> Vec<u8> {
+        let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
+        mac.update(data.as_bytes());
+        mac.finalize().into_bytes().to_vec()
+    }
+
+    fn sha256_hex(raw: &str) -> String {
+        let mut hasher = Sha256::new();
+        hasher.update(raw.as_bytes());
+        hex::encode(hasher.finalize())
+    }
+
+    fn hex_encode(raw: &[u8]) -> String {
+        hex::encode(raw)
+    }
+
+    fn trim(value: &str) -> &str {
+        value.trim()
+    }
+
+    fn get_canonical_request(
+        &self,
+        rest_auth_parameter: &RESTAuthParameter,
+        headers: &HashMap<String, String>,
+    ) -> String {
+        let mut parts = vec![
+            rest_auth_parameter.method.clone(),
+            rest_auth_parameter.path.clone(),
+        ];
+
+        let canonical_query_string =
+            self.build_canonical_query_string(&rest_auth_parameter.parameters);
+        parts.push(canonical_query_string);
+
+        let sorted_headers = self.build_sorted_signed_headers_map(headers);
+        for (key, value) in sorted_headers {
+            parts.push(format!("{}:{}", key, value));
+        }
+
+        let content_sha256 = headers
+            .get(Self::DLF_CONTENT_SHA256_HEADER_KEY)
+            .map(|s| s.as_str())
+            .unwrap_or(Self::DLF_CONTENT_SHA256_VALUE);
+        parts.push(content_sha256.to_string());
+
+        parts.join(Self::NEW_LINE)
+    }
+
+    fn build_canonical_query_string(&self, parameters: &HashMap<String, String>) -> String {
+        if parameters.is_empty() {
+            return String::new();
+        }
+
+        let mut sorted_params: Vec<_> = parameters.iter().collect();
+        sorted_params.sort_by(|a, b| a.0.cmp(b.0));
+
+        let query_parts: Vec<String> = sorted_params
+            .iter()
+            .map(|(key, value)| {
+                let key = Self::trim(key);
+                if !value.is_empty() {
+                    let value = Self::trim(value);
+                    format!("{}={}", key, value)
+                } else {
+                    key.to_string()
+                }
+            })
+            .collect();
+
+        query_parts.join("&")
+    }
+
+    fn build_sorted_signed_headers_map(
+        &self,
+        headers: &HashMap<String, String>,
+    ) -> Vec<(String, String)> {
+        let mut sorted_headers: Vec<(String, String)> = headers
+            .iter()
+            .filter(|(key, _)| {
+                let lower_key = key.to_lowercase();
+                Self::SIGNED_HEADERS.contains(&lower_key.as_str())
+            })
+            .map(|(key, value)| (key.to_lowercase(), Self::trim(value).to_string()))
+            .collect();
+
+        sorted_headers.sort_by(|a, b| a.0.cmp(&b.0));
+        sorted_headers
+    }
+}
+
+impl DLFRequestSigner for DLFDefaultSigner {
+    fn sign_headers(
+        &self,
+        body: Option<&str>,
+        now: &DateTime<Utc>,
+        security_token: Option<&str>,
+        _host: &str,
+    ) -> HashMap<String, String> {
+        let mut sign_headers = HashMap::new();
+
+        let date_time = now.format(Self::AUTH_DATE_TIME_FORMAT).to_string();
+        sign_headers.insert(Self::DLF_DATE_HEADER_KEY.to_string(), date_time);
+        sign_headers.insert(
+            Self::DLF_CONTENT_SHA256_HEADER_KEY.to_string(),
+            Self::DLF_CONTENT_SHA256_VALUE.to_string(),
+        );
+        sign_headers.insert(
+            Self::DLF_AUTH_VERSION_HEADER_KEY.to_string(),
+            Self::VERSION.to_string(),
+        );
+
+        if let Some(body_content) = body {
+            if !body_content.is_empty() {
+                sign_headers.insert(
+                    Self::DLF_CONTENT_TYPE_KEY.to_string(),
+                    Self::MEDIA_TYPE.to_string(),
+                );
+                sign_headers.insert(
+                    Self::DLF_CONTENT_MD5_HEADER_KEY.to_string(),
+                    Self::md5_base64(body_content),
+                );
+            }
+        }
+
+        if let Some(token) = security_token {
+            sign_headers.insert(
+                Self::DLF_SECURITY_TOKEN_HEADER_KEY.to_string(),
+                token.to_string(),
+            );
+        }
+
+        sign_headers
+    }
+
+    fn authorization(
+        &self,
+        rest_auth_parameter: &RESTAuthParameter,
+        token: &DLFToken,
+        _host: &str,
+        sign_headers: &HashMap<String, String>,
+    ) -> String {
+        let date_time = sign_headers.get(Self::DLF_DATE_HEADER_KEY).unwrap();
+        let date = &date_time[..8];
+
+        let canonical_request = self.get_canonical_request(rest_auth_parameter, sign_headers);
+
+        let string_to_sign = [
+            
Self::SIGNATURE_ALGORITHM.to_string(), + date_time.clone(), + format!( + "{}/{}/{}/{}", + date, + self.region, + Self::PRODUCT, + Self::REQUEST_TYPE + ), + Self::sha256_hex(&canonical_request), + ] + .join(Self::NEW_LINE); + + // Derive signing key + let date_key = Self::hmac_sha256( + format!("aliyun_v4{}", token.access_key_secret).as_bytes(), + date, + ); + let date_region_key = Self::hmac_sha256(&date_key, &self.region); + let date_region_service_key = Self::hmac_sha256(&date_region_key, Self::PRODUCT); + let signing_key = Self::hmac_sha256(&date_region_service_key, Self::REQUEST_TYPE); + + let signature_bytes = Self::hmac_sha256(&signing_key, &string_to_sign); + let signature = Self::hex_encode(&signature_bytes); + + format!( + "{} Credential={}/{}/{}/{}/{},{}={}", + Self::SIGNATURE_ALGORITHM, + token.access_key_id, + date, + self.region, + Self::PRODUCT, + Self::REQUEST_TYPE, + Self::SIGNATURE_KEY, + signature + ) + } + + fn identifier(&self) -> &str { + Self::IDENTIFIER + } +} + +/// DLF OpenAPI signer using HMAC-SHA1 algorithm. +/// +/// This signer follows the Alibaba Cloud ROA v2 signature style and is used +/// for public network endpoints (e.g., `dlfnext.cn-asdnbhwf.aliyuncs.com`). +/// +/// # Algorithm Details +/// +/// The HMAC-SHA1 algorithm is the traditional Alibaba Cloud API signature method: +/// +/// 1. **Canonicalized Headers**: Sort and format all `x-acs-*` headers +/// 2. **Canonicalized Resource**: URL-decoded path with sorted query parameters +/// 3. **String-to-Sign**: Combine HTTP method, standard headers, canonicalized headers, and resource +/// 4. 
**Signature**: Base64-encoded HMAC-SHA1 of string-to-sign +/// +/// # Required Headers +/// +/// The following headers are included in requests: +/// - `Date`: Request timestamp in RFC 1123 format +/// - `Accept`: Always "application/json" +/// - `Content-MD5`: MD5 hash of request body (if present) +/// - `Content-Type`: Always "application/json" (if body present) +/// - `Host`: Endpoint host +/// - `x-acs-signature-method`: Always "HMAC-SHA1" +/// - `x-acs-signature-nonce`: Unique UUID for each request +/// - `x-acs-signature-version`: Always "1.0" +/// - `x-acs-version`: API version ("2026-01-18") +/// - `x-acs-security-token`: Security token for temporary credentials (optional) +/// +/// # Example +/// +/// ```ignore +/// use paimon::api::auth::{DLFOpenApiSigner, DLFRequestSigner}; +/// +/// let signer = DLFOpenApiSigner; +/// let headers = signer.sign_headers(Some(r#"{"key":"value"}"#), &Utc::now(), None, "dlfnext.aliyuncs.com"); +/// ``` +pub struct DLFOpenApiSigner; + +impl DLFOpenApiSigner { + pub const IDENTIFIER: &'static str = "openapi"; + + // Header constants + const DATE_HEADER: &'static str = "Date"; + const ACCEPT_HEADER: &'static str = "Accept"; + const CONTENT_MD5_HEADER: &'static str = "Content-MD5"; + const CONTENT_TYPE_HEADER: &'static str = "Content-Type"; + const HOST_HEADER: &'static str = "Host"; + const X_ACS_SIGNATURE_METHOD: &'static str = "x-acs-signature-method"; + const X_ACS_SIGNATURE_NONCE: &'static str = "x-acs-signature-nonce"; + const X_ACS_SIGNATURE_VERSION: &'static str = "x-acs-signature-version"; + const X_ACS_VERSION: &'static str = "x-acs-version"; + const X_ACS_SECURITY_TOKEN: &'static str = "x-acs-security-token"; + + // Values + const DATE_FORMAT: &'static str = "%a, %d %b %Y %H:%M:%S GMT"; + const ACCEPT_VALUE: &'static str = "application/json"; + const CONTENT_TYPE_VALUE: &'static str = "application/json"; + const SIGNATURE_METHOD_VALUE: &'static str = "HMAC-SHA1"; + const SIGNATURE_VERSION_VALUE: &'static str = "1.0"; 
+    const API_VERSION: &'static str = "2026-01-18";
+
+    fn md5_base64(data: &str) -> String {
+        let mut hasher = Md5::new();
+        hasher.update(data.as_bytes());
+        let hash = hasher.finalize();
+        BASE64_STANDARD.encode(hash)
+    }
+
+    fn hmac_sha1_base64(key: &str, data: &str) -> String {
+        let mut mac =
+            HmacSha1::new_from_slice(key.as_bytes()).expect("HMAC can take key of any size");
+        mac.update(data.as_bytes());
+        BASE64_STANDARD.encode(mac.finalize().into_bytes())
+    }
+
+    fn trim(value: &str) -> &str {
+        value.trim()
+    }
+
+    fn build_canonicalized_headers(&self, headers: &HashMap<String, String>) -> String {
+        let mut sorted_headers: Vec<(String, String)> = headers
+            .iter()
+            .filter(|(key, _)| key.to_lowercase().starts_with("x-acs-"))
+            .map(|(key, value)| (key.to_lowercase(), Self::trim(value).to_string()))
+            .collect();
+
+        sorted_headers.sort_by(|a, b| a.0.cmp(&b.0));
+
+        let mut result = String::new();
+        for (key, value) in sorted_headers {
+            result.push_str(&format!("{}:{}\n", key, value));
+        }
+        result
+    }
+
+    fn build_canonicalized_resource(&self, rest_auth_parameter: &RESTAuthParameter) -> String {
+        let path = urlencoding::decode(&rest_auth_parameter.path).unwrap_or_default();
+
+        if rest_auth_parameter.parameters.is_empty() {
+            return path.to_string();
+        }
+
+        let mut sorted_params: Vec<_> = rest_auth_parameter.parameters.iter().collect();
+        sorted_params.sort_by(|a, b| a.0.cmp(b.0));
+
+        let query_parts: Vec<String> = sorted_params
+            .iter()
+            .map(|(key, value)| {
+                let decoded_value = urlencoding::decode(value).unwrap_or_default();
+                if !decoded_value.is_empty() {
+                    format!("{}={}", key, decoded_value)
+                } else {
+                    key.to_string()
+                }
+            })
+            .collect();
+
+        format!("{}?{}", path, query_parts.join("&"))
+    }
+
+    fn build_string_to_sign(
+        &self,
+        rest_auth_parameter: &RESTAuthParameter,
+        headers: &HashMap<String, String>,
+        canonicalized_headers: &str,
+        canonicalized_resource: &str,
+    ) -> String {
+        let parts = [
+            rest_auth_parameter.method.clone(),
+            headers
+                .get(Self::ACCEPT_HEADER)
+                .cloned()
+                .unwrap_or_default(),
+            headers
+                .get(Self::CONTENT_MD5_HEADER)
+                .cloned()
+                .unwrap_or_default(),
+            headers
+                .get(Self::CONTENT_TYPE_HEADER)
+                .cloned()
+                .unwrap_or_default(),
+            headers.get(Self::DATE_HEADER).cloned().unwrap_or_default(),
+            canonicalized_headers.to_string(),
+        ];
+
+        parts.join("\n") + canonicalized_resource
+    }
+}
+
+impl DLFRequestSigner for DLFOpenApiSigner {
+    fn sign_headers(
+        &self,
+        body: Option<&str>,
+        now: &DateTime<Utc>,
+        security_token: Option<&str>,
+        host: &str,
+    ) -> HashMap<String, String> {
+        let mut headers = HashMap::new();
+
+        // Date header in RFC 1123 format
+        headers.insert(
+            Self::DATE_HEADER.to_string(),
+            now.format(Self::DATE_FORMAT).to_string(),
+        );
+
+        // Accept header
+        headers.insert(
+            Self::ACCEPT_HEADER.to_string(),
+            Self::ACCEPT_VALUE.to_string(),
+        );
+
+        // Content-MD5 and Content-Type (if body exists)
+        if let Some(body_content) = body {
+            if !body_content.is_empty() {
+                headers.insert(
+                    Self::CONTENT_MD5_HEADER.to_string(),
+                    Self::md5_base64(body_content),
+                );
+                headers.insert(
+                    Self::CONTENT_TYPE_HEADER.to_string(),
+                    Self::CONTENT_TYPE_VALUE.to_string(),
+                );
+            }
+        }
+
+        // Host header
+        headers.insert(Self::HOST_HEADER.to_string(), host.to_string());
+
+        // x-acs-* headers
+        headers.insert(
+            Self::X_ACS_SIGNATURE_METHOD.to_string(),
+            Self::SIGNATURE_METHOD_VALUE.to_string(),
+        );
+        headers.insert(
+            Self::X_ACS_SIGNATURE_NONCE.to_string(),
+            Uuid::new_v4().to_string(),
+        );
+        headers.insert(
+            Self::X_ACS_SIGNATURE_VERSION.to_string(),
+            Self::SIGNATURE_VERSION_VALUE.to_string(),
+        );
+        headers.insert(
+            Self::X_ACS_VERSION.to_string(),
+            Self::API_VERSION.to_string(),
+        );
+
+        // Security token (if present)
+        if let Some(token) = security_token {
+            headers.insert(Self::X_ACS_SECURITY_TOKEN.to_string(), token.to_string());
+        }
+
+        headers
+    }
+
+    fn authorization(
+        &self,
+        rest_auth_parameter: &RESTAuthParameter,
+        token: &DLFToken,
+        _host: &str,
+        sign_headers: &HashMap<String, String>,
+    ) -> String {
+        let 
canonicalized_headers = self.build_canonicalized_headers(sign_headers);
+        let canonicalized_resource = self.build_canonicalized_resource(rest_auth_parameter);
+        let string_to_sign = self.build_string_to_sign(
+            rest_auth_parameter,
+            sign_headers,
+            &canonicalized_headers,
+            &canonicalized_resource,
+        );
+
+        let signature = Self::hmac_sha1_base64(&token.access_key_secret, &string_to_sign);
+        format!("acs {}:{}", token.access_key_id, signature)
+    }
+
+    fn identifier(&self) -> &str {
+        Self::IDENTIFIER
+    }
+}
+
+/// Factory for creating DLF signers based on endpoint configuration.
+///
+/// This factory automatically selects the appropriate signer based on the
+/// endpoint URI:
+///
+/// | Endpoint Pattern | Signer | Algorithm |
+/// |-----------------|--------|-----------|
+/// | `*-vpc.dlf.aliyuncs.com` | `DLFDefaultSigner` | DLF4-HMAC-SHA256 |
+/// | `dlfnext.*.aliyuncs.com` | `DLFOpenApiSigner` | HMAC-SHA1 |
+/// | `*openapi*` | `DLFOpenApiSigner` | HMAC-SHA1 |
+/// | Other | `DLFDefaultSigner` | DLF4-HMAC-SHA256 |
+///
+/// # Example
+///
+/// ```ignore
+/// use paimon::api::auth::DLFSignerFactory;
+///
+/// // Auto-detect from URI
+/// let signer = DLFSignerFactory::create_signer("default", "cn-hangzhou");
+/// let algo = DLFSignerFactory::parse_signing_algo_from_uri(Some("http://dlfnext.ajinnbjug.aliyuncs.com"));
+/// assert_eq!(algo, "openapi");
+/// ```
+pub struct DLFSignerFactory;
+
+impl DLFSignerFactory {
+    /// Create a signer based on the signing algorithm.
+    pub fn create_signer(signing_algorithm: &str, region: &str) -> Box<dyn DLFRequestSigner> {
+        if signing_algorithm == DLFOpenApiSigner::IDENTIFIER {
+            Box::new(DLFOpenApiSigner)
+        } else {
+            Box::new(DLFDefaultSigner::new(region))
+        }
+    }
+}
diff --git a/crates/paimon/src/api/auth/factory.rs b/crates/paimon/src/api/auth/factory.rs
index 58234b2..a4f430c 100644
--- a/crates/paimon/src/api/auth/factory.rs
+++ b/crates/paimon/src/api/auth/factory.rs
@@ -17,10 +17,98 @@
 
 //! Authentication provider factory.
 
+use crate::api::auth::dlf_provider::DLFTokenLoaderFactory;
+use crate::api::auth::{BearerTokenAuthProvider, DLFAuthProvider, DLFToken};
+use crate::api::AuthProvider;
 use crate::common::{CatalogOptions, Options};
 use crate::Error;
+use regex::Regex;
 
-use super::{AuthProvider, BearerTokenAuthProvider};
+/// Factory for creating DLF authentication providers.
+pub struct DLFAuthProviderFactory;
+
+impl DLFAuthProviderFactory {
+    /// OpenAPI identifier.
+    pub const OPENAPI_IDENTIFIER: &'static str = "openapi";
+    /// Default identifier.
+    pub const DEFAULT_IDENTIFIER: &'static str = "default";
+    /// Region pattern for parsing from URI.
+    const REGION_PATTERN: &'static str = r"(?:pre-)?([a-z]+-[a-z]+(?:-\d+)?)";
+
+    /// Parse region from DLF endpoint URI.
+    pub fn parse_region_from_uri(uri: Option<&str>) -> Option<String> {
+        let uri = uri?;
+        let re = Regex::new(Self::REGION_PATTERN).ok()?;
+        let caps = re.captures(uri)?;
+        caps.get(1).map(|m| m.as_str().to_string())
+    }
+
+    /// Parse signing algorithm from URI.
+    ///
+    /// Returns "openapi" for public endpoints (dlfnext or openapi in host),
+    /// otherwise returns "default".
+    pub fn parse_signing_algo_from_uri(uri: Option<&str>) -> &'static str {
+        if let Some(uri) = uri {
+            let host = uri.to_lowercase();
+            let host = host
+                .strip_prefix("http://")
+                .unwrap_or(host.strip_prefix("https://").unwrap_or(&host));
+            let host = host.split('/').next().unwrap_or("");
+            let host = host.split(':').next().unwrap_or("");
+
+            if host.starts_with("dlfnext") {
+                return Self::OPENAPI_IDENTIFIER;
+            }
+        }
+        Self::DEFAULT_IDENTIFIER
+    }
+
+    /// Create a DLF authentication provider from options.
+    ///
+    /// # Arguments
+    /// * `options` - The configuration options.
+    ///
+    /// # Returns
+    /// A boxed AuthProvider trait object.
+    ///
+    /// # Errors
+    /// Returns an error if required configuration is missing.
+    pub fn create_provider(options: &Options) -> Result<Box<dyn AuthProvider>, Error> {
+        let uri = options
+            .get(CatalogOptions::URI)
+            .ok_or_else(|| Error::ConfigInvalid {
+                message: "URI is required for DLF authentication".to_string(),
+            })?
+            .clone();
+
+        // Get region from options or parse from URI
+        let region = options
+            .get(CatalogOptions::DLF_REGION)
+            .cloned()
+            .or_else(|| Self::parse_region_from_uri(Some(&uri)))
+            .ok_or_else(|| Error::ConfigInvalid {
+                message: "Could not get region from config or URI. Please set 'dlf.region' or use a standard DLF endpoint URI.".to_string(),
+            })?;
+
+        // Get signing algorithm from options, or auto-detect from URI
+        let signing_algorithm = options
+            .get(CatalogOptions::DLF_SIGNING_ALGORITHM)
+            .map(|s| s.as_str())
+            .filter(|s| *s != "default")
+            .unwrap_or_else(|| Self::parse_signing_algo_from_uri(Some(&uri)))
+            .to_string();
+
+        let dlf_provider = DLFAuthProvider::new(
+            uri,
+            region,
+            signing_algorithm,
+            DLFToken::from_options(options),
+            DLFTokenLoaderFactory::create_token_loader(options),
+        )?;
+
+        Ok(Box::new(dlf_provider))
+    }
+}
 
 /// Factory for creating authentication providers.
 pub struct AuthProviderFactory;
@@ -49,33 +137,39 @@ impl AuthProviderFactory {
                 })?;
                 Ok(Box::new(BearerTokenAuthProvider::new(token)))
             }
-            None => Err(Error::ConfigInvalid {
-                message: "auth provider is required".to_string(),
-            }),
+            Some("dlf") => DLFAuthProviderFactory::create_provider(options),
             Some(unknown) => Err(Error::ConfigInvalid {
                 message: format!("Unknown auth provider: {unknown}"),
             }),
+            None => Err(Error::ConfigInvalid {
+                message: "auth provider is required".to_string(),
+            }),
         }
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use crate::api::auth::base::AUTHORIZATION_HEADER_KEY;
+
     use super::super::RESTAuthParameter;
     use super::*;
     use std::collections::HashMap;
 
-    #[test]
-    fn test_create_bearer_provider() {
+    #[tokio::test]
+    async fn test_create_bearer_provider() {
         let mut options = Options::new();
         options.set(CatalogOptions::TOKEN_PROVIDER, "bear");
         options.set(CatalogOptions::TOKEN, "test-token");
 
-        let provider = AuthProviderFactory::create_auth_provider(&options).unwrap();
+        let mut provider = AuthProviderFactory::create_auth_provider(&options).unwrap();
 
         let base_header = HashMap::new();
         let param = RESTAuthParameter::new("GET", "/test", None, HashMap::new());
-        let result = provider.merge_auth_header(base_header, &param);
+        let result = provider
+            .merge_auth_header(base_header, &param)
+            .await
+            .unwrap();
 
         assert_eq!(
             result.get("Authorization"),
@@ -98,4 +192,58 @@ mod tests {
         let result = AuthProviderFactory::create_auth_provider(&options);
         assert!(result.is_err());
     }
+
+    #[tokio::test]
+    async fn test_create_dlf_provider() {
+        let mut options = Options::new();
+        options.set(CatalogOptions::TOKEN_PROVIDER, "dlf");
+        options.set(CatalogOptions::URI, "http://dlf-asdaswfnb.net/");
+        options.set(CatalogOptions::DLF_REGION, "cn-hangzhou");
+        options.set(CatalogOptions::DLF_ACCESS_KEY_ID, "test_key_id");
+        options.set(CatalogOptions::DLF_ACCESS_KEY_SECRET, "test_key_secret");
+
+        let mut provider = AuthProviderFactory::create_auth_provider(&options).unwrap();
+
+        let base_header = HashMap::new();
+        let param = RESTAuthParameter::new("GET", "/test", None, HashMap::new());
+        let result = provider
+            .merge_auth_header(base_header, &param)
+            .await
+            .unwrap();
+
+        assert!(result.contains_key(AUTHORIZATION_HEADER_KEY));
+    }
+
+    #[test]
+    fn test_dlf_provider_missing_region() {
+        let mut options = Options::new();
+        options.set(CatalogOptions::TOKEN_PROVIDER, "dlf");
+        options.set(CatalogOptions::URI, "http://example.com/");
+        options.set(CatalogOptions::DLF_ACCESS_KEY_ID, "test_key_id");
+        options.set(CatalogOptions::DLF_ACCESS_KEY_SECRET, "test_key_secret");
+
+        let result = AuthProviderFactory::create_auth_provider(&options);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_region_from_uri() {
+        let region = DLFAuthProviderFactory::parse_region_from_uri(Some(
+            "http://cn-hangzhou-vpc.dlf.aliyuncs.com",
+        ));
+        assert_eq!(region, Some("cn-hangzhou".to_string()));
+    }
+
+    #[test]
+    fn test_parse_signing_algo_from_uri() {
+        let algo = DLFAuthProviderFactory::parse_signing_algo_from_uri(Some(
+            "http://dlfnext.cn-hangzhou.aliyuncs.com",
+        ));
+        assert_eq!(algo, "openapi");
+
+        let algo = DLFAuthProviderFactory::parse_signing_algo_from_uri(Some(
+            "http://cn-hangzhou-vpc.dlf.aliyuncs.com",
+        ));
+        assert_eq!(algo, "default");
+    }
 }
diff --git a/crates/paimon/src/api/auth/mod.rs b/crates/paimon/src/api/auth/mod.rs
index 219d343..7afb9d8 100644
--- a/crates/paimon/src/api/auth/mod.rs
+++ b/crates/paimon/src/api/auth/mod.rs
@@ -18,9 +18,12 @@
 
 //! Authentication module for REST API.
 
mod base; -mod bear_provider; +mod bearer_provider; +mod dlf_provider; +mod dlf_signer; mod factory; pub use base::{AuthProvider, RESTAuthFunction, RESTAuthParameter}; -pub use bear_provider::BearerTokenAuthProvider; -pub use factory::AuthProviderFactory; +pub use bearer_provider::BearerTokenAuthProvider; +pub use dlf_provider::{DLFAuthProvider, DLFECSTokenLoader, DLFToken, DLFTokenLoader}; +pub use factory::{AuthProviderFactory, DLFAuthProviderFactory}; diff --git a/crates/paimon/src/api/mod.rs b/crates/paimon/src/api/mod.rs index e1c205a..958323e 100644 --- a/crates/paimon/src/api/mod.rs +++ b/crates/paimon/src/api/mod.rs @@ -19,6 +19,7 @@ //! //! This module provides REST API client, request, and response types. +pub mod api_request; pub mod auth; pub mod resource_paths; pub mod rest_api; @@ -28,9 +29,15 @@ pub mod rest_util; mod api_response; +// Re-export request types +pub use api_request::{ + AlterDatabaseRequest, CreateDatabaseRequest, CreateTableRequest, RenameTableRequest, +}; + // Re-export response types pub use api_response::{ - ConfigResponse, ErrorResponse, ListDatabasesResponse, PagedList, RESTResponse, + AuditRESTResponse, ConfigResponse, ErrorResponse, GetDatabaseResponse, GetTableResponse, + ListDatabasesResponse, ListTablesResponse, PagedList, }; // Re-export error types diff --git a/crates/paimon/src/api/resource_paths.rs b/crates/paimon/src/api/resource_paths.rs index 43577b8..1cbc88f 100644 --- a/crates/paimon/src/api/resource_paths.rs +++ b/crates/paimon/src/api/resource_paths.rs @@ -19,6 +19,8 @@ use crate::common::{CatalogOptions, Options}; +use super::rest_util::RESTUtil; + /// Resource paths for REST API endpoints. 
#[derive(Clone)] pub struct ResourcePaths { @@ -28,6 +30,8 @@ pub struct ResourcePaths { impl ResourcePaths { const V1: &'static str = "v1"; const DATABASES: &'static str = "databases"; + const TABLES: &'static str = "tables"; + const TABLE_DETAILS: &'static str = "table-details"; /// Create a new ResourcePaths with the given prefix. pub fn new(prefix: &str) -> Self { @@ -62,6 +66,71 @@ impl ResourcePaths { pub fn databases(&self) -> String { format!("{}/{}", self.base_path, Self::DATABASES) } + + /// Get a specific database endpoint path. + pub fn database(&self, name: &str) -> String { + format!( + "{}/{}/{}", + self.base_path, + Self::DATABASES, + RESTUtil::encode_string(name) + ) + } + + /// Get the tables endpoint path. + pub fn tables(&self, database_name: Option<&str>) -> String { + if let Some(db_name) = database_name { + format!( + "{}/{}/{}/{}", + self.base_path, + Self::DATABASES, + RESTUtil::encode_string(db_name), + Self::TABLES + ) + } else { + format!("{}/{}", self.base_path, Self::TABLES) + } + } + + /// Get a specific table endpoint path. + pub fn table(&self, database_name: &str, table_name: &str) -> String { + format!( + "{}/{}/{}/{}/{}", + self.base_path, + Self::DATABASES, + RESTUtil::encode_string(database_name), + Self::TABLES, + RESTUtil::encode_string(table_name) + ) + } + + /// Get the table details endpoint path. + pub fn table_details(&self, database_name: &str) -> String { + format!( + "{}/{}/{}/{}", + self.base_path, + Self::DATABASES, + RESTUtil::encode_string(database_name), + Self::TABLE_DETAILS + ) + } + + /// Get the table token endpoint path. + pub fn table_token(&self, database_name: &str, table_name: &str) -> String { + format!( + "{}/{}/{}/{}/{}/token", + self.base_path, + Self::DATABASES, + RESTUtil::encode_string(database_name), + Self::TABLES, + RESTUtil::encode_string(table_name) + ) + } + + /// Get the rename table endpoint path. 
+ pub fn rename_table(&self) -> String { + format!("{}/{}/rename", self.base_path, Self::TABLES) + } } #[cfg(test)] @@ -72,12 +141,25 @@ mod tests { fn test_resource_paths_basic() { let paths = ResourcePaths::new(""); assert_eq!(paths.databases(), "/v1/databases"); + assert_eq!(paths.tables(None), "/v1/tables"); } #[test] fn test_resource_paths_with_prefix() { let paths = ResourcePaths::new("my-catalog"); assert_eq!(paths.databases(), "/v1/my-catalog/databases"); + assert_eq!( + paths.database("test-db"), + "/v1/my-catalog/databases/test-db" + ); + } + + #[test] + fn test_resource_paths_table() { + let paths = ResourcePaths::new(""); + let table_path = paths.table("my-db", "my-table"); + assert!(table_path.contains("my-db")); + assert!(table_path.contains("my-table")); } #[test] diff --git a/crates/paimon/src/api/rest_api.rs b/crates/paimon/src/api/rest_api.rs index d975d99..f785111 100644 --- a/crates/paimon/src/api/rest_api.rs +++ b/crates/paimon/src/api/rest_api.rs @@ -23,17 +23,56 @@ use std::collections::HashMap; use crate::api::rest_client::HttpClient; +use crate::catalog::Identifier; use crate::common::{CatalogOptions, Options}; +use crate::spec::Schema; use crate::Result; -use super::api_response::{ConfigResponse, ListDatabasesResponse, PagedList}; +use super::api_request::{ + AlterDatabaseRequest, CreateDatabaseRequest, CreateTableRequest, RenameTableRequest, +}; +use super::api_response::{ + ConfigResponse, GetDatabaseResponse, GetTableResponse, ListDatabasesResponse, + ListTablesResponse, PagedList, +}; use super::auth::{AuthProviderFactory, RESTAuthFunction}; use super::resource_paths::ResourcePaths; use super::rest_util::RESTUtil; +/// Validate that a string is not empty after trimming. +/// +/// # Arguments +/// * `value` - The string to validate. +/// * `field_name` - The name of the field for error messages. +/// +/// # Returns +/// `Ok(())` if valid, `Err` if empty. 
+fn validate_non_empty(value: &str, field_name: &str) -> Result<()> { + if value.trim().is_empty() { + return Err(crate::Error::ConfigInvalid { + message: format!("{} cannot be empty", field_name), + }); + } + Ok(()) +} + +/// Validate that multiple strings are not empty after trimming. +/// +/// # Arguments +/// * `values` - Slice of (value, field_name) pairs to validate. +/// +/// # Returns +/// `Ok(())` if all valid, `Err` if any is empty. +fn validate_non_empty_multi(values: &[(&str, &str)]) -> Result<()> { + for (value, field_name) in values { + validate_non_empty(value, field_name)?; + } + Ok(()) +} + /// REST API wrapper for Paimon catalog operations. /// -/// This struct provides methods for database CRUD operations +/// This struct provides methods for database and table CRUD operations /// through a REST API client. pub struct RESTApi { client: HttpClient, @@ -48,6 +87,8 @@ impl RESTApi { pub const MAX_RESULTS: &'static str = "maxResults"; pub const PAGE_TOKEN: &'static str = "pageToken"; pub const DATABASE_NAME_PATTERN: &'static str = "databaseNamePattern"; + pub const TABLE_NAME_PATTERN: &'static str = "tableNamePattern"; + pub const TABLE_TYPE: &'static str = "tableType"; /// Create a new RESTApi from options. /// @@ -99,7 +140,7 @@ impl RESTApi { RESTUtil::encode_string(warehouse), )]; let config_response: ConfigResponse = client - .get_with_params(&ResourcePaths::config(), &query_params) + .get(&ResourcePaths::config(), Some(&query_params)) .await?; // Merge config response with options (client config takes priority) @@ -130,7 +171,7 @@ impl RESTApi { // ==================== Database Operations ==================== /// List all databases. - pub async fn list_databases(&self) -> Result> { + pub async fn list_databases(&mut self) -> Result> { let mut results = Vec::new(); let mut page_token: Option = None; @@ -151,7 +192,7 @@ impl RESTApi { /// List databases with pagination. 
     pub async fn list_databases_paged(
-        &self,
+        &mut self,
         max_results: Option<u32>,
         page_token: Option<&str>,
         database_name_pattern: Option<&str>,
@@ -172,11 +213,166 @@ impl RESTApi {
         }

         let response: ListDatabasesResponse = if params.is_empty() {
-            self.client.get(&path).await?
+            self.client.get(&path, None::<&[(&str, &str)]>).await?
         } else {
-            self.client.get_with_params(&path, &params).await?
+            self.client.get(&path, Some(&params)).await?
         };

         Ok(PagedList::new(response.databases, response.next_page_token))
     }
+
+    /// Create a new database.
+    pub async fn create_database(
+        &mut self,
+        name: &str,
+        options: Option<HashMap<String, String>>,
+    ) -> Result<()> {
+        validate_non_empty(name, "database name")?;
+        let path = self.resource_paths.databases();
+        let request = CreateDatabaseRequest::new(name.to_string(), options.unwrap_or_default());
+        let _resp: serde_json::Value = self.client.post(&path, &request).await?;
+        Ok(())
+    }
+
+    /// Get database information.
+    pub async fn get_database(&mut self, name: &str) -> Result<GetDatabaseResponse> {
+        validate_non_empty(name, "database name")?;
+        let path = self.resource_paths.database(name);
+        self.client.get(&path, None::<&[(&str, &str)]>).await
+    }
+
+    /// Alter database configuration.
+    pub async fn alter_database(
+        &mut self,
+        name: &str,
+        removals: Vec<String>,
+        updates: std::collections::HashMap<String, String>,
+    ) -> Result<()> {
+        validate_non_empty(name, "database name")?;
+        let path = self.resource_paths.database(name);
+        let request = AlterDatabaseRequest::new(removals, updates);
+        let _resp: serde_json::Value = self.client.post(&path, &request).await?;
+        Ok(())
+    }
+
+    /// Drop a database.
+    pub async fn drop_database(&mut self, name: &str) -> Result<()> {
+        validate_non_empty(name, "database name")?;
+        let path = self.resource_paths.database(name);
+        let _resp: serde_json::Value = self.client.delete(&path, None::<&[(&str, &str)]>).await?;
+        Ok(())
+    }
+
+    // ==================== Table Operations ====================
+
+    /// List all tables in a database.
+ pub async fn list_tables(&mut self, database: &str) -> Result> { + validate_non_empty(database, "database name")?; + + let mut results = Vec::new(); + let mut page_token: Option = None; + + loop { + let paged = self + .list_tables_paged(database, None, page_token.as_deref(), None, None) + .await?; + let is_empty = paged.elements.is_empty(); + results.extend(paged.elements); + page_token = paged.next_page_token; + if page_token.is_none() || is_empty { + break; + } + } + + Ok(results) + } + + /// List tables with pagination. + pub async fn list_tables_paged( + &mut self, + database: &str, + max_results: Option, + page_token: Option<&str>, + table_name_pattern: Option<&str>, + table_type: Option<&str>, + ) -> Result> { + validate_non_empty(database, "database name")?; + let path = self.resource_paths.tables(Some(database)); + let mut params: Vec<(&str, String)> = Vec::new(); + + if let Some(max) = max_results { + params.push((Self::MAX_RESULTS, max.to_string())); + } + + if let Some(token) = page_token { + params.push((Self::PAGE_TOKEN, token.to_string())); + } + + if let Some(pattern) = table_name_pattern { + params.push((Self::TABLE_NAME_PATTERN, pattern.to_string())); + } + + if let Some(ttype) = table_type { + params.push((Self::TABLE_TYPE, ttype.to_string())); + } + + let response: ListTablesResponse = if params.is_empty() { + self.client.get(&path, None::<&[(&str, &str)]>).await? + } else { + self.client.get(&path, Some(¶ms)).await? + }; + + Ok(PagedList::new( + response.tables.unwrap_or_default(), + response.next_page_token, + )) + } + + /// Create a new table. 
+    pub async fn create_table(&mut self, identifier: &Identifier, schema: Schema) -> Result<()> {
+        let database = identifier.database();
+        let table = identifier.object();
+        validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?;
+        let path = self.resource_paths.tables(Some(database));
+        let request = CreateTableRequest::new(identifier.clone(), schema);
+        let _resp: serde_json::Value = self.client.post(&path, &request).await?;
+        Ok(())
+    }
+
+    /// Get table information.
+    pub async fn get_table(&mut self, identifier: &Identifier) -> Result<GetTableResponse> {
+        let database = identifier.database();
+        let table = identifier.object();
+        validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?;
+        let path = self.resource_paths.table(database, table);
+        self.client.get(&path, None::<&[(&str, &str)]>).await
+    }
+
+    /// Rename a table.
+    pub async fn rename_table(
+        &mut self,
+        source: &Identifier,
+        destination: &Identifier,
+    ) -> Result<()> {
+        validate_non_empty_multi(&[
+            (source.database(), "source database name"),
+            (source.object(), "source table name"),
+            (destination.database(), "destination database name"),
+            (destination.object(), "destination table name"),
+        ])?;
+        let path = self.resource_paths.rename_table();
+        let request = RenameTableRequest::new(source.clone(), destination.clone());
+        let _resp: serde_json::Value = self.client.post(&path, &request).await?;
+        Ok(())
+    }
+
+    /// Drop a table.
+ pub async fn drop_table(&mut self, identifier: &Identifier) -> Result<()> { + let database = identifier.database(); + let table = identifier.object(); + validate_non_empty_multi(&[(database, "database name"), (table, "table name")])?; + let path = self.resource_paths.table(database, table); + let _resp: serde_json::Value = self.client.delete(&path, None::<&[(&str, &str)]>).await?; + Ok(()) + } } diff --git a/crates/paimon/src/api/rest_client.rs b/crates/paimon/src/api/rest_client.rs index b348918..76db48f 100644 --- a/crates/paimon/src/api/rest_client.rs +++ b/crates/paimon/src/api/rest_client.rs @@ -85,17 +85,40 @@ impl HttpClient { Ok(normalized_url.trim_end_matches('/').to_string()) } - /// Perform a GET request and parse the response as JSON. + /// Perform a GET request with optional query parameters. /// /// # Arguments /// * `path` - The path to append to the base URL. + /// * `params` - Optional query parameters as key-value pairs. /// /// # Returns /// The parsed JSON response. 
-    pub async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
+    pub async fn get<T: DeserializeOwned>(
+        &mut self,
+        path: &str,
+        params: Option<&[(impl AsRef<str>, impl AsRef<str>)]>,
+    ) -> Result<T> {
         let url = self.request_url(path);
-        let headers = self.build_auth_headers("GET", path, None, HashMap::new());
-        let request = self.client.get(&url);
+
+        let params_map: HashMap<String, String> = match params {
+            Some(p) => p
+                .iter()
+                .map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string()))
+                .collect(),
+            None => HashMap::new(),
+        };
+
+        let headers = self
+            .build_auth_headers("GET", path, None, params_map)
+            .await?;
+
+        let mut request = self.client.get(&url);
+        if let Some(p) = params {
+            for (key, value) in p {
+                request = request.query(&[(key.as_ref(), value.as_ref())]);
+            }
+        }
+
         let request = Self::apply_headers(request, &headers);
         let resp = request.send().await.map_err(|e| Error::UnexpectedError {
             message: "http get failed".to_string(),
@@ -104,34 +127,70 @@ impl HttpClient {
         self.parse_response(resp).await
     }

-    /// Perform a GET request with query parameters.
+    /// Perform a POST request with a JSON body.
     ///
     /// # Arguments
     /// * `path` - The path to append to the base URL.
-    /// * `params` - Query parameters as key-value pairs (supports both `&str` and `String`).
+    /// * `body` - The JSON body to send.
     ///
     /// # Returns
     /// The parsed JSON response.
- pub async fn get_with_params( - &self, + pub async fn post( + &mut self, path: &str, - params: &[(impl AsRef, impl AsRef)], + body: &B, ) -> Result { let url = self.request_url(path); - let params_map: HashMap = params - .iter() - .map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string())) - .collect(); - let headers = self.build_auth_headers("GET", path, None, params_map.clone()); + let body_str = serde_json::to_string(body).ok(); + let headers = self + .build_auth_headers("POST", path, body_str.as_deref(), HashMap::new()) + .await?; + let request = self.client.post(&url).json(body); + let request = Self::apply_headers(request, &headers); + let resp = request.send().await.map_err(|e| Error::UnexpectedError { + message: "http post failed".to_string(), + source: Some(Box::new(e)), + })?; + self.parse_response(resp).await + } - let mut request = self.client.get(&url); - for (key, value) in params { - request = request.query(&[(key.as_ref(), value.as_ref())]); + /// Perform a DELETE request with optional query parameters. + /// + /// # Arguments + /// * `path` - The path to append to the base URL. + /// * `params` - Optional query parameters as key-value pairs. + /// + /// # Returns + /// The parsed JSON response. 
+ pub async fn delete( + &mut self, + path: &str, + params: Option<&[(impl AsRef, impl AsRef)]>, + ) -> Result { + let url = self.request_url(path); + + let params_map: HashMap = match params { + Some(p) => p + .iter() + .map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string())) + .collect(), + None => HashMap::new(), + }; + + let headers = self + .build_auth_headers("DELETE", path, None, params_map) + .await?; + + let mut request = self.client.delete(&url); + if let Some(p) = params { + for (key, value) in p { + request = request.query(&[(key.as_ref(), value.as_ref())]); + } } let request = Self::apply_headers(request, &headers); let resp = request.send().await.map_err(|e| Error::UnexpectedError { - message: "http get failed".to_string(), + message: "http delete failed".to_string(), source: Some(Box::new(e)), })?; self.parse_response(resp).await @@ -143,19 +202,19 @@ impl HttpClient { } /// Build auth headers for a request. - fn build_auth_headers( - &self, + async fn build_auth_headers( + &mut self, method: &str, path: &str, data: Option<&str>, params: HashMap, - ) -> HashMap { - if let Some(ref auth_fn) = self.auth_function { + ) -> Result> { + if let Some(ref mut auth_fn) = self.auth_function { let parameter = RESTAuthParameter::new(method, path, data.map(|s| s.to_string()), params); - auth_fn.apply(¶meter) + auth_fn.apply(¶meter).await } else { - HashMap::new() + Ok(HashMap::new()) } } @@ -203,6 +262,14 @@ impl HttpClient { source: Some(Box::new(e)), })?; + // Handle empty response body - return null as default for types like serde_json::Value + if text.trim().is_empty() { + return serde_json::from_str("null").map_err(|e| Error::UnexpectedError { + message: "failed to parse empty response".to_string(), + source: Some(Box::new(e)), + }); + } + serde_json::from_str(&text).map_err(|e| Error::UnexpectedError { message: "failed to parse json".to_string(), source: Some(Box::new(e)), diff --git a/crates/paimon/src/common/options.rs 
b/crates/paimon/src/common/options.rs index a6adf69..a469a07 100644 --- a/crates/paimon/src/common/options.rs +++ b/crates/paimon/src/common/options.rs @@ -38,8 +38,37 @@ impl CatalogOptions { /// Authentication token. pub const TOKEN: &'static str = "token"; + /// Data token enabled flag. + pub const DATA_TOKEN_ENABLED: &'static str = "data-token.enabled"; + /// Prefix for catalog resources. pub const PREFIX: &'static str = "prefix"; + + // DLF (Data Lake Formation) configuration options + + /// DLF region. + pub const DLF_REGION: &'static str = "dlf.region"; + + /// DLF access key ID. + pub const DLF_ACCESS_KEY_ID: &'static str = "dlf.access-key-id"; + + /// DLF access key secret. + pub const DLF_ACCESS_KEY_SECRET: &'static str = "dlf.access-key-secret"; + + /// DLF security token (optional, for temporary credentials). + pub const DLF_ACCESS_SECURITY_TOKEN: &'static str = "dlf.security-token"; + + /// DLF signing algorithm (default or openapi). + pub const DLF_SIGNING_ALGORITHM: &'static str = "dlf.signing-algorithm"; + + /// DLF token loader type (e.g., "ecs"). + pub const DLF_TOKEN_LOADER: &'static str = "dlf.token-loader"; + + /// DLF ECS metadata URL. + pub const DLF_TOKEN_ECS_METADATA_URL: &'static str = "dlf.token-ecs-metadata-url"; + + /// DLF ECS role name. + pub const DLF_TOKEN_ECS_ROLE_NAME: &'static str = "dlf.token-ecs-role-name"; } /// Configuration options container. diff --git a/crates/paimon/tests/mock_server.rs b/crates/paimon/tests/mock_server.rs index 32ced4d..fee16c7 100644 --- a/crates/paimon/tests/mock_server.rs +++ b/crates/paimon/tests/mock_server.rs @@ -21,22 +21,33 @@ //! for testing purposes. 
use axum::{
-    extract::{Extension, Json, Query},
+    extract::{Extension, Json, Path, Query},
     http::StatusCode,
     response::IntoResponse,
     routing::get,
     Router,
 };
-use std::collections::HashMap;
+use serde_json::json;
+use std::collections::{HashMap, HashSet};
 use std::net::SocketAddr;
 use std::sync::{Arc, Mutex};
 use tokio::task::JoinHandle;

-use paimon::api::{ConfigResponse, ErrorResponse, ListDatabasesResponse, ResourcePaths};
+use paimon::api::{
+    AlterDatabaseRequest, AuditRESTResponse, ConfigResponse, ErrorResponse, GetDatabaseResponse,
+    GetTableResponse, ListDatabasesResponse, ListTablesResponse, RenameTableRequest, ResourcePaths,
+};

 #[derive(Clone, Debug, Default)]
 struct MockState {
-    databases: HashMap<String, ()>,
+    databases: HashMap<String, GetDatabaseResponse>,
+    tables: HashMap<String, GetTableResponse>,
+    no_permission_databases: HashSet<String>,
+    no_permission_tables: HashSet<String>,
+    /// ECS metadata role name (for token loader testing)
+    ecs_role_name: Option<String>,
+    /// ECS metadata token (for token loader testing)
+    ecs_token: Option<serde_json::Value>,
 }

 #[derive(Clone)]
@@ -52,7 +63,7 @@ pub struct RESTServer {
 }

 impl RESTServer {
-    /// Create a new RESTServer with initial databases (backward compatibility).
+    /// Create a new RESTServer with initial databases.
pub fn new( warehouse: String, data_path: String, @@ -62,14 +73,28 @@ impl RESTServer { let prefix = config.defaults.get("prefix").cloned().unwrap_or_default(); // Create database set for initial databases - let databases: HashMap = - initial_dbs.into_iter().map(|name| (name, ())).collect(); + let databases: HashMap = initial_dbs + .into_iter() + .map(|name| { + let response = GetDatabaseResponse::new( + Some(name.clone()), + Some(name.clone()), + None, + HashMap::new(), + AuditRESTResponse::new(None, None, None, None, None), + ); + (name, response) + }) + .collect(); RESTServer { data_path, config, warehouse, - inner: Arc::new(Mutex::new(MockState { databases })), + inner: Arc::new(Mutex::new(MockState { + databases, + ..Default::default() + })), resource_paths: ResourcePaths::new(&prefix), addr: None, server_handle: None, @@ -90,7 +115,7 @@ impl RESTServer { let err = ErrorResponse::new( None, None, - Some(format!("Warehouse {warehouse} not found")), + Some(format!("Warehouse {} not found", warehouse)), Some(404), ); return (StatusCode::NOT_FOUND, Json(err)).into_response(); @@ -107,37 +132,516 @@ impl RESTServer { let response = ListDatabasesResponse::new(dbs, None); (StatusCode::OK, Json(response)) } + /// Handle POST /databases - create a new database. + pub async fn create_database( + Extension(state): Extension>, + Json(payload): Json, + ) -> impl IntoResponse { + let name = match payload.get("name").and_then(|n| n.as_str()) { + Some(n) => n.to_string(), + None => { + let err = + ErrorResponse::new(None, None, Some("Missing name".to_string()), Some(400)); + return (StatusCode::BAD_REQUEST, Json(err)).into_response(); + } + }; - // ====================== Server Control ==================== - /// Get the warehouse path. 
- #[allow(dead_code)] - pub fn warehouse(&self) -> &str { - &self.warehouse + let mut s = state.inner.lock().unwrap(); + if let std::collections::hash_map::Entry::Vacant(e) = s.databases.entry(name.clone()) { + let response = GetDatabaseResponse::new( + Some(name.clone()), + Some(name.clone()), + None, + HashMap::new(), + AuditRESTResponse::new(None, None, None, None, None), + ); + e.insert(response); + (StatusCode::OK, Json(serde_json::json!(""))).into_response() + } else { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(name), + Some("Already Exists".to_string()), + Some(409), + ); + (StatusCode::CONFLICT, Json(err)).into_response() + } } + /// Handle GET /databases/:name - get a specific database. + pub async fn get_database( + Path(name): Path, + Extension(state): Extension>, + ) -> impl IntoResponse { + let s = state.inner.lock().unwrap(); - /// Get the resource paths. - pub fn resource_paths(&self) -> &ResourcePaths { - &self.resource_paths + if s.no_permission_databases.contains(&name) { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(name.clone()), + Some("No Permission".to_string()), + Some(403), + ); + return (StatusCode::FORBIDDEN, Json(err)).into_response(); + } + + if let Some(response) = s.databases.get(&name) { + (StatusCode::OK, Json(response.clone())).into_response() + } else { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(name.clone()), + Some("Not Found".to_string()), + Some(404), + ); + (StatusCode::NOT_FOUND, Json(err)).into_response() + } + } + + /// Handle POST /databases/:name - alter database configuration. 
+ pub async fn alter_database( + Path(name): Path, + Extension(state): Extension>, + Json(request): Json, + ) -> impl IntoResponse { + let mut s = state.inner.lock().unwrap(); + + if s.no_permission_databases.contains(&name) { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(name.clone()), + Some("No Permission".to_string()), + Some(403), + ); + return (StatusCode::FORBIDDEN, Json(err)).into_response(); + } + + if let Some(response) = s.databases.get_mut(&name) { + // Apply removals + for key in &request.removals { + response.options.remove(key); + } + // Apply updates + response.options.extend(request.updates); + (StatusCode::OK, Json(serde_json::json!(""))).into_response() + } else { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(name.clone()), + Some("Not Found".to_string()), + Some(404), + ); + (StatusCode::NOT_FOUND, Json(err)).into_response() + } + } + + /// Handle DELETE /databases/:name - drop a database. + pub async fn drop_database( + Path(name): Path, + Extension(state): Extension>, + ) -> impl IntoResponse { + let mut s = state.inner.lock().unwrap(); + + if s.no_permission_databases.contains(&name) { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(name.clone()), + Some("No Permission".to_string()), + Some(403), + ); + return (StatusCode::FORBIDDEN, Json(err)).into_response(); + } + + if s.databases.remove(&name).is_some() { + // Also remove all tables in this database + let prefix = format!("{}.", name); + s.tables.retain(|key, _| !key.starts_with(&prefix)); + s.no_permission_tables + .retain(|key| !key.starts_with(&prefix)); + (StatusCode::OK, Json(serde_json::json!(""))).into_response() + } else { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(name.clone()), + Some("Not Found".to_string()), + Some(404), + ); + (StatusCode::NOT_FOUND, Json(err)).into_response() + } + } + + /// Handle GET /databases/:db/tables - list all tables in a database. 
+ pub async fn list_tables( + Path(db): Path, + Extension(state): Extension>, + ) -> impl IntoResponse { + let s = state.inner.lock().unwrap(); + + if s.no_permission_databases.contains(&db) { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(db.clone()), + Some("No Permission".to_string()), + Some(403), + ); + return (StatusCode::FORBIDDEN, Json(err)).into_response(); + } + + if !s.databases.contains_key(&db) { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(db.clone()), + Some("Not Found".to_string()), + Some(404), + ); + return (StatusCode::NOT_FOUND, Json(err)).into_response(); + } + + let prefix = format!("{}.", db); + let mut tables: Vec = s + .tables + .keys() + .filter_map(|key| { + if key.starts_with(&prefix) { + Some(key[prefix.len()..].to_string()) + } else { + None + } + }) + .collect(); + tables.sort(); + + let response = ListTablesResponse::new(Some(tables), None); + (StatusCode::OK, Json(response)).into_response() + } + + /// Handle POST /databases/:db/tables - create a new table. 
+ pub async fn create_table( + Path(db): Path, + Extension(state): Extension>, + Json(payload): Json, + ) -> impl IntoResponse { + // Extract table name from payload + let table_name = payload + .get("identifier") + .and_then(|id| id.get("object")) + .and_then(|o| o.as_str()) + .map(|s| s.to_string()); + + let table_name = match table_name { + Some(name) => name, + None => { + let err = ErrorResponse::new( + None, + None, + Some("Missing table name in identifier".to_string()), + Some(400), + ); + return (StatusCode::BAD_REQUEST, Json(err)).into_response(); + } + }; + + let mut s = state.inner.lock().unwrap(); + + // Check database exists + if !s.databases.contains_key(&db) { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(db.clone()), + Some("Not Found".to_string()), + Some(404), + ); + return (StatusCode::NOT_FOUND, Json(err)).into_response(); + } + + let key = format!("{}.{}", db, table_name); + if s.tables.contains_key(&key) { + let err = ErrorResponse::new( + Some("table".to_string()), + Some(table_name), + Some("Already Exists".to_string()), + Some(409), + ); + return (StatusCode::CONFLICT, Json(err)).into_response(); + } + + // Create table response + let response = GetTableResponse::new( + Some(table_name.clone()), + Some(table_name), + None, + Some(true), + None, + None, + AuditRESTResponse::new(None, None, None, None, None), + ); + s.tables.insert(key, response); + (StatusCode::OK, Json(serde_json::json!(""))).into_response() + } + + /// Handle GET /databases/:db/tables/:table - get a specific table. 
+ pub async fn get_table( + Path((db, table)): Path<(String, String)>, + Extension(state): Extension>, + ) -> impl IntoResponse { + let s = state.inner.lock().unwrap(); + + let key = format!("{}.{}", db, table); + if s.no_permission_tables.contains(&key) { + let err = ErrorResponse::new( + Some("table".to_string()), + Some(table.clone()), + Some("No Permission".to_string()), + Some(403), + ); + return (StatusCode::FORBIDDEN, Json(err)).into_response(); + } + + if let Some(response) = s.tables.get(&key) { + return (StatusCode::OK, Json(response.clone())).into_response(); + } + + if !s.databases.contains_key(&db) { + let err = ErrorResponse::new( + Some("database".to_string()), + Some(db), + Some("Not Found".to_string()), + Some(404), + ); + return (StatusCode::NOT_FOUND, Json(err)).into_response(); + } + + let err = ErrorResponse::new( + Some("table".to_string()), + Some(table), + Some("Not Found".to_string()), + Some(404), + ); + (StatusCode::NOT_FOUND, Json(err)).into_response() } + /// Handle DELETE /databases/:db/tables/:table - drop a table. + pub async fn drop_table( + Path((db, table)): Path<(String, String)>, + Extension(state): Extension>, + ) -> impl IntoResponse { + let mut s = state.inner.lock().unwrap(); + + let key = format!("{}.{}", db, table); + if s.no_permission_tables.contains(&key) { + let err = ErrorResponse::new( + Some("table".to_string()), + Some(table.clone()), + Some("No Permission".to_string()), + Some(403), + ); + return (StatusCode::FORBIDDEN, Json(err)).into_response(); + } + + if s.tables.remove(&key).is_some() { + s.no_permission_tables.remove(&key); + (StatusCode::OK, Json(serde_json::json!(""))).into_response() + } else { + let err = ErrorResponse::new( + Some("table".to_string()), + Some(table), + Some("Not Found".to_string()), + Some(404), + ); + (StatusCode::NOT_FOUND, Json(err)).into_response() + } + } + + /// Handle POST /rename-table - rename a table. 
+ pub async fn rename_table( + Extension(state): Extension>, + Json(request): Json, + ) -> impl IntoResponse { + let mut s = state.inner.lock().unwrap(); + + let source_key = format!("{}.{}", request.source.database(), request.source.object()); + let dest_key = format!( + "{}.{}", + request.destination.database(), + request.destination.object() + ); + + // Check source table permission + if s.no_permission_tables.contains(&source_key) { + let err = ErrorResponse::new( + Some("table".to_string()), + Some(request.source.object().to_string()), + Some("No Permission".to_string()), + Some(403), + ); + return (StatusCode::FORBIDDEN, Json(err)).into_response(); + } + + // Check if source table exists + if let Some(table_response) = s.tables.remove(&source_key) { + // Check if destination already exists + if s.tables.contains_key(&dest_key) { + // Restore source table + s.tables.insert(source_key, table_response); + let err = ErrorResponse::new( + Some("table".to_string()), + Some(dest_key.clone()), + Some("Already Exists".to_string()), + Some(409), + ); + return (StatusCode::CONFLICT, Json(err)).into_response(); + } + + // Update the table name in response and insert at new location + let new_table_response = GetTableResponse::new( + Some(request.destination.object().to_string()), + Some(request.destination.object().to_string()), + table_response.path, + table_response.is_external, + table_response.schema_id, + table_response.schema, + table_response.audit, + ); + s.tables.insert(dest_key.clone(), new_table_response); + + // Update permission tracking if needed + if s.no_permission_tables.remove(&source_key) { + s.no_permission_tables.insert(dest_key.clone()); + } + + (StatusCode::OK, Json(serde_json::json!(""))).into_response() + } else { + let err = ErrorResponse::new( + Some("table".to_string()), + Some(source_key), + Some("Not Found".to_string()), + Some(404), + ); + (StatusCode::NOT_FOUND, Json(err)).into_response() + } + } + // ====================== Server Control 
==================== /// Add a database to the server state. + #[allow(dead_code)] pub fn add_database(&self, name: &str) { let mut s = self.inner.lock().unwrap(); - if !s.databases.contains_key(name) { - s.databases.insert(name.to_string(), ()); - } + s.databases.entry(name.to_string()).or_insert_with(|| { + GetDatabaseResponse::new( + Some(name.to_string()), + Some(name.to_string()), + None, + HashMap::new(), + AuditRESTResponse::new(None, None, None, None, None), + ) + }); + } + /// Add a no-permission database to the server state. + #[allow(dead_code)] + pub fn add_no_permission_database(&self, name: &str) { + let mut s = self.inner.lock().unwrap(); + s.no_permission_databases.insert(name.to_string()); } + /// Add a table to the server state. + #[allow(dead_code)] + pub fn add_table(&self, database: &str, table: &str) { + let mut s = self.inner.lock().unwrap(); + s.databases.entry(database.to_string()).or_insert_with(|| { + // Auto-create database if not exists + GetDatabaseResponse::new( + Some(database.to_string()), + Some(database.to_string()), + None, + HashMap::new(), + AuditRESTResponse::new(None, None, None, None, None), + ) + }); + + let key = format!("{}.{}", database, table); + s.tables.entry(key).or_insert_with(|| { + GetTableResponse::new( + Some(table.to_string()), + Some(table.to_string()), + None, + Some(true), + None, + None, + AuditRESTResponse::new(None, None, None, None, None), + ) + }); + } + + /// Add a no-permission table to the server state. + #[allow(dead_code)] + pub fn add_no_permission_table(&self, database: &str, table: &str) { + let mut s = self.inner.lock().unwrap(); + s.no_permission_tables + .insert(format!("{}.{}", database, table)); + } /// Get the server URL. pub fn url(&self) -> Option { - self.addr.map(|a| format!("http://{a}")) + self.addr.map(|a| format!("http://{}", a)) + } + /// Get the warehouse path. + #[allow(dead_code)] + pub fn warehouse(&self) -> &str { + &self.warehouse } + /// Get the resource paths. 
+ pub fn resource_paths(&self) -> &ResourcePaths { + &self.resource_paths + } /// Get the server address. #[allow(dead_code)] pub fn addr(&self) -> Option { self.addr } + + /// Set ECS metadata role name and token for token loader testing. + #[allow(dead_code)] + pub fn set_ecs_metadata(&self, role_name: &str, token: serde_json::Value) { + let mut s = self.inner.lock().unwrap(); + s.ecs_role_name = Some(role_name.to_string()); + s.ecs_token = Some(token); + } + + /// Handle GET /ram/security-credential/:role - ECS metadata endpoint. + pub async fn get_ecs_metadata( + Path(role): Path, + Extension(state): Extension>, + ) -> impl IntoResponse { + let s = state.inner.lock().unwrap(); + + // If role_name is set and matches, return the token + if let Some(expected_role) = &s.ecs_role_name { + if &role == expected_role { + if let Some(token) = &s.ecs_token { + return (StatusCode::OK, Json(token.clone())).into_response(); + } + } + } + + ( + StatusCode::NOT_FOUND, + Json(json!({"error": "Role not found"})), + ) + .into_response() + } + + /// Handle GET /ram/security-credential/ - ECS metadata endpoint (list roles). 
+ pub async fn list_ecs_roles(Extension(state): Extension>) -> impl IntoResponse { + let s = state.inner.lock().unwrap(); + + if let Some(role_name) = &s.ecs_role_name { + (StatusCode::OK, role_name.clone()).into_response() + } else { + ( + StatusCode::NOT_FOUND, + Json(json!({"error": "No role configured"})), + ) + .into_response() + } + } } impl Drop for RESTServer { @@ -174,8 +678,35 @@ pub async fn start_mock_server( .route("/v1/config", get(RESTServer::get_config)) // Database routes .route( - &format!("{prefix}/databases"), - get(RESTServer::list_databases), + &format!("{}/databases", prefix), + get(RESTServer::list_databases).post(RESTServer::create_database), + ) + .route( + &format!("{}/databases/:name", prefix), + get(RESTServer::get_database) + .post(RESTServer::alter_database) + .delete(RESTServer::drop_database), + ) + .route( + &format!("{}/databases/:db/tables", prefix), + get(RESTServer::list_tables).post(RESTServer::create_table), + ) + .route( + &format!("{}/databases/:db/tables/:table", prefix), + get(RESTServer::get_table).delete(RESTServer::drop_table), + ) + .route( + &format!("{}/tables/rename", prefix), + axum::routing::post(RESTServer::rename_table), + ) + // ECS metadata endpoints (for token loader testing) + .route( + "/ram/security-credentials/", + get(RESTServer::list_ecs_roles), + ) + .route( + "/ram/security-credentials/:role", + get(RESTServer::get_ecs_metadata), ) .layer(Extension(state)); @@ -186,7 +717,7 @@ pub async fn start_mock_server( let server_handle = tokio::spawn(async move { if let Err(e) = axum::serve(listener, app.into_make_service()).await { - eprintln!("mock server error: {e}"); + eprintln!("mock server error: {}", e); } }); diff --git a/crates/paimon/tests/rest_api_test.rs b/crates/paimon/tests/rest_api_test.rs index 0932523..1c06809 100644 --- a/crates/paimon/tests/rest_api_test.rs +++ b/crates/paimon/tests/rest_api_test.rs @@ -18,20 +18,24 @@ //! Integration tests for REST API. //! //! 
These tests use a mock server to verify the REST API client behavior. +//! Both the mock server and API client run asynchronously using tokio. use std::collections::HashMap; +use paimon::api::auth::{DLFECSTokenLoader, DLFToken, DLFTokenLoader}; use paimon::api::rest_api::RESTApi; use paimon::api::ConfigResponse; +use paimon::catalog::Identifier; use paimon::common::Options; +use serde_json::json; mod mock_server; use mock_server::{start_mock_server, RESTServer}; - /// Helper struct to hold test resources. struct TestContext { server: RESTServer, api: RESTApi, + url: String, } /// Helper function to set up a test environment with a custom prefix. @@ -45,14 +49,14 @@ async fn setup_test_server(initial_dbs: Vec<&str>) -> TestContext { let initial: Vec = initial_dbs.iter().map(|s| s.to_string()).collect(); // Start server with config let server = start_mock_server( - "test_warehouse".to_string(), - "/tmp/test_warehouse".to_string(), + "test_warehouse".to_string(), // warehouse + "/tmp/test_warehouse".to_string(), // data_path config, initial, ) .await; let token = "test_token"; - let url = server.url().expect("server url"); + let url = server.url().expect("Failed to get server URL"); let mut options = Options::new(); options.set("uri", &url); options.set("warehouse", "test_warehouse"); @@ -63,14 +67,13 @@ async fn setup_test_server(initial_dbs: Vec<&str>) -> TestContext { .await .expect("Failed to create RESTApi"); - TestContext { server, api } + TestContext { server, api, url } } // ==================== Database Tests ==================== - #[tokio::test] async fn test_list_databases() { - let ctx = setup_test_server(vec!["default", "test_db1", "prod_db"]).await; + let mut ctx = setup_test_server(vec!["default", "test_db1", "prod_db"]).await; let dbs = ctx.api.list_databases().await.unwrap(); @@ -80,21 +83,398 @@ async fn test_list_databases() { } #[tokio::test] -async fn test_list_databases_empty() { - let ctx = setup_test_server(vec![]).await; +async fn 
test_create_database() { + let mut ctx = setup_test_server(vec!["default"]).await; + // Create new database + let result = ctx.api.create_database("new_db", None).await; + assert!(result.is_ok(), "failed to create database: {:?}", result); + + // Verify creation let dbs = ctx.api.list_databases().await.unwrap(); - assert!(dbs.is_empty()); + assert!(dbs.contains(&"new_db".to_string())); + + // Duplicate creation should fail + let result = ctx.api.create_database("new_db", None).await; + assert!(result.is_err(), "creating duplicate database should fail"); } #[tokio::test] -async fn test_list_databases_add_after_creation() { +async fn test_get_database() { + let mut ctx = setup_test_server(vec!["default"]).await; + + let db_resp = ctx.api.get_database("default").await.unwrap(); + assert_eq!(db_resp.name, Some("default".to_string())); +} + +#[tokio::test] +async fn test_error_responses_status_mapping() { let ctx = setup_test_server(vec!["default"]).await; - // Add a new database after server creation - ctx.server.add_database("new_db"); + // Add no-permission database + ctx.server.add_no_permission_database("secret"); + + // GET on no-permission database -> 403 + // Use the prefix from config (v1/mock-test) + let url = format!("{}/v1/mock-test/databases/{}", ctx.url, "secret"); + let result = reqwest::get(&url).await; + match result { + Ok(resp) => { + assert_eq!(resp.status(), 403); + let j: serde_json::Value = resp.json().await.unwrap(); + assert_eq!( + j.get("resourceType").and_then(|v| v.as_str()), + Some("database") + ); + assert_eq!( + j.get("resourceName").and_then(|v| v.as_str()), + Some("secret") + ); + assert_eq!(j.get("code").and_then(|v| v.as_u64()), Some(403)); + } + Err(e) => panic!("Expected 403 response, got error: {:?}", e), + } + + // POST create existing database -> 409 + let body = json!({"name": "default", "properties": {}}); + let client = reqwest::Client::new(); + let resp = client + .post(format!("{}/v1/mock-test/databases", ctx.url)) + 
.json(&body) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 409); + + let j2: serde_json::Value = resp.json().await.unwrap(); + assert_eq!( + j2.get("resourceType").and_then(|v| v.as_str()), + Some("database") + ); + assert_eq!( + j2.get("resourceName").and_then(|v| v.as_str()), + Some("default") + ); + assert_eq!(j2.get("code").and_then(|v| v.as_u64()), Some(409)); +} + +// ==================== Table Tests ==================== + +#[tokio::test] +async fn test_list_tables_and_get_table() { + let mut ctx = setup_test_server(vec!["default"]).await; + + // Add tables + ctx.server.add_table("default", "table1"); + ctx.server.add_table("default", "table2"); + + // List tables + let tables = ctx.api.list_tables("default").await.unwrap(); + assert!(tables.contains(&"table1".to_string())); + assert!(tables.contains(&"table2".to_string())); + + // Get table + let table_resp = ctx + .api + .get_table(&Identifier::new("default", "table1")) + .await + .unwrap(); + assert_eq!(table_resp.id.unwrap_or_default(), "table1"); +} + +#[tokio::test] +async fn test_get_table_not_found() { + let mut ctx = setup_test_server(vec!["default"]).await; + + let result = ctx + .api + .get_table(&Identifier::new("default", "non_existent_table")) + .await; + assert!(result.is_err(), "getting non-existent table should fail"); +} + +#[tokio::test] +async fn test_list_tables_empty_database() { + let mut ctx = setup_test_server(vec!["default"]).await; + let tables = ctx.api.list_tables("default").await.unwrap(); + assert!( + tables.is_empty(), + "expected empty tables list, got: {:?}", + tables + ); +} + +#[tokio::test] +async fn test_multiple_databases_with_tables() { + let mut ctx = setup_test_server(vec!["db1", "db2"]).await; + + // Add tables to different databases + ctx.server.add_table("db1", "table1_db1"); + ctx.server.add_table("db1", "table2_db1"); + ctx.server.add_table("db2", "table1_db2"); + + // Verify db1 tables + let tables_db1 = ctx.api.list_tables("db1").await.unwrap(); + 
assert_eq!(tables_db1.len(), 2); + assert!(tables_db1.contains(&"table1_db1".to_string())); + assert!(tables_db1.contains(&"table2_db1".to_string())); + + // Verify db2 tables + let tables_db2 = ctx.api.list_tables("db2").await.unwrap(); + assert_eq!(tables_db2.len(), 1); + assert!(tables_db2.contains(&"table1_db2".to_string())); +} + +// ==================== Database Alter/Drop Tests ==================== + +#[tokio::test] +async fn test_alter_database() { + let mut ctx = setup_test_server(vec!["default"]).await; + + // Alter database with updates + let mut updates = HashMap::new(); + updates.insert("key1".to_string(), "value1".to_string()); + updates.insert("key2".to_string(), "value2".to_string()); + + let result = ctx.api.alter_database("default", vec![], updates).await; + assert!(result.is_ok(), "failed to alter database: {:?}", result); + + // Verify the updates by getting the database + let db_resp = ctx.api.get_database("default").await.unwrap(); + assert_eq!(db_resp.options.get("key1"), Some(&"value1".to_string())); + assert_eq!(db_resp.options.get("key2"), Some(&"value2".to_string())); + + // Alter database with removals + let result = ctx + .api + .alter_database("default", vec!["key1".to_string()], HashMap::new()) + .await; + assert!(result.is_ok(), "failed to remove key: {:?}", result); + + let db_resp = ctx.api.get_database("default").await.unwrap(); + assert!(!db_resp.options.contains_key("key1")); + assert_eq!(db_resp.options.get("key2"), Some(&"value2".to_string())); +} + +#[tokio::test] +async fn test_alter_database_not_found() { + let mut ctx = setup_test_server(vec!["default"]).await; + + let result = ctx + .api + .alter_database("non_existent", vec![], HashMap::new()) + .await; + assert!( + result.is_err(), + "altering non-existent database should fail" + ); +} + +#[tokio::test] +async fn test_drop_database() { + let mut ctx = setup_test_server(vec!["default", "to_drop"]).await; + + // Verify database exists let dbs = 
ctx.api.list_databases().await.unwrap(); - assert!(dbs.contains(&"default".to_string())); - assert!(dbs.contains(&"new_db".to_string())); + assert!(dbs.contains(&"to_drop".to_string())); + + // Drop database + let result = ctx.api.drop_database("to_drop").await; + assert!(result.is_ok(), "failed to drop database: {:?}", result); + + // Verify database is gone + let dbs = ctx.api.list_databases().await.unwrap(); + assert!(!dbs.contains(&"to_drop".to_string())); + + // Dropping non-existent database should fail + let result = ctx.api.drop_database("to_drop").await; + assert!( + result.is_err(), + "dropping non-existent database should fail" + ); +} + +#[tokio::test] +async fn test_drop_database_no_permission() { + let mut ctx = setup_test_server(vec!["default"]).await; + ctx.server.add_no_permission_database("secret"); + + let result = ctx.api.drop_database("secret").await; + assert!( + result.is_err(), + "dropping no-permission database should fail" + ); +} + +// ==================== Table Create/Drop Tests ==================== + +#[tokio::test] +async fn test_create_table() { + let mut ctx = setup_test_server(vec!["default"]).await; + + // Create a simple schema using builder + use paimon::spec::{DataType, Schema}; + let schema = Schema::builder() + .column("id", DataType::BigInt(paimon::spec::BigIntType::new())) + .column( + "name", + DataType::VarChar(paimon::spec::VarCharType::new(255).unwrap()), + ) + .build() + .expect("Failed to build schema"); + + let result = ctx + .api + .create_table(&Identifier::new("default", "new_table"), schema) + .await; + assert!(result.is_ok(), "failed to create table: {:?}", result); + + // Verify table exists + let tables = ctx.api.list_tables("default").await.unwrap(); + assert!(tables.contains(&"new_table".to_string())); + + // Get the table + let table_resp = ctx + .api + .get_table(&Identifier::new("default", "new_table")) + .await + .unwrap(); + assert_eq!(table_resp.name, Some("new_table".to_string())); +} + +#[tokio::test] 
+async fn test_drop_table() { + let mut ctx = setup_test_server(vec!["default"]).await; + + // Add a table + ctx.server.add_table("default", "table_to_drop"); + + // Verify table exists + let tables = ctx.api.list_tables("default").await.unwrap(); + assert!(tables.contains(&"table_to_drop".to_string())); + + // Drop table + let result = ctx + .api + .drop_table(&Identifier::new("default", "table_to_drop")) + .await; + assert!(result.is_ok(), "failed to drop table: {:?}", result); + + // Verify table is gone + let tables = ctx.api.list_tables("default").await.unwrap(); + assert!(!tables.contains(&"table_to_drop".to_string())); + + // Dropping non-existent table should fail + let result = ctx + .api + .drop_table(&Identifier::new("default", "table_to_drop")) + .await; + assert!(result.is_err(), "dropping non-existent table should fail"); +} + +#[tokio::test] +async fn test_drop_table_no_permission() { + let mut ctx = setup_test_server(vec!["default"]).await; + ctx.server + .add_no_permission_table("default", "secret_table"); + + let result = ctx + .api + .drop_table(&Identifier::new("default", "secret_table")) + .await; + assert!(result.is_err(), "dropping no-permission table should fail"); +} + +// ==================== Rename Table Tests ==================== + +#[tokio::test] +async fn test_rename_table() { + let mut ctx = setup_test_server(vec!["default"]).await; + + // Add a table + ctx.server.add_table("default", "old_table"); + + // Rename table + let result = ctx + .api + .rename_table( + &Identifier::new("default", "old_table"), + &Identifier::new("default", "new_table"), + ) + .await; + assert!(result.is_ok(), "failed to rename table: {:?}", result); + + // Verify old table is gone + let tables = ctx.api.list_tables("default").await.unwrap(); + assert!(!tables.contains(&"old_table".to_string())); + + // Verify new table exists + assert!(tables.contains(&"new_table".to_string())); + + // Get the renamed table + let table_resp = ctx + .api + 
.get_table(&Identifier::new("default", "new_table"))
+        .await
+        .unwrap();
+    assert_eq!(table_resp.name, Some("new_table".to_string()));
+}
+
+// ==================== Token Loader Tests ====================
+
+#[tokio::test]
+async fn test_ecs_loader_token() {
+    let prefix = "mock-test";
+    let mut defaults = HashMap::new();
+    defaults.insert("prefix".to_string(), prefix.to_string());
+    let config = ConfigResponse::new(defaults);
+
+    let initial: Vec<String> = vec!["default".to_string()];
+    let server = start_mock_server(
+        "test_warehouse".to_string(),
+        "/tmp/test_warehouse".to_string(),
+        config,
+        initial,
+    )
+    .await;
+
+    let role_name = "test_role";
+    let token_json = json!({
+        "AccessKeyId": "AccessKeyId",
+        "AccessKeySecret": "AccessKeySecret",
+        "SecurityToken": "AQoDYXdzEJr...",
+        "Expiration": "2023-12-01T12:00:00Z"
+    });
+
+    server.set_ecs_metadata(role_name, token_json.clone());
+
+    let ecs_metadata_url = format!("{}/ram/security-credentials/", server.url().unwrap());
+
+    // Test without role name
+    let loader = DLFECSTokenLoader::new(&ecs_metadata_url, None);
+    let load_token: DLFToken = loader.load_token().await.unwrap();
+
+    assert_eq!(load_token.access_key_id, "AccessKeyId");
+    assert_eq!(load_token.access_key_secret, "AccessKeySecret");
+    assert_eq!(
+        load_token.security_token,
+        Some("AQoDYXdzEJr...".to_string())
+    );
+    assert_eq!(
+        load_token.expiration,
+        Some("2023-12-01T12:00:00Z".to_string())
+    );
+
+    // Test with role name
+    let loader_with_role = DLFECSTokenLoader::new(&ecs_metadata_url, Some(role_name.to_string()));
+    let token: DLFToken = loader_with_role.load_token().await.unwrap();
+
+    assert_eq!(token.access_key_id, "AccessKeyId");
+    assert_eq!(token.access_key_secret, "AccessKeySecret");
+    assert_eq!(
+        token.security_token,
+        Some("AQoDYXdzEJr...".to_string())
+    );
+    assert_eq!(token.expiration, Some("2023-12-01T12:00:00Z".to_string()));
+}