diff --git a/Cargo.lock b/Cargo.lock index ec1f3551..9d26de75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,6 +741,9 @@ name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +dependencies = [ + "getrandom 0.2.16", +] [[package]] name = "figment" @@ -781,6 +784,18 @@ dependencies = [ "spin", ] +[[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "fastrand", + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -940,8 +955,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1189,7 +1206,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.0", + "socket2", "tokio", "tower-service", "tracing", @@ -1999,9 +2016,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", "prost-derive", @@ -2009,9 +2026,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools", @@ -2022,9 +2039,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -2541,13 +2558,14 @@ dependencies = [ [[package]] name = "sentry_protos" -version = "0.4.11" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae1eac4a748b11a2bb5b342bea8546085751cf9a45e30fb1276b072bb5541e6" +checksum = "5d3c4e8bca4c556eec616dc2594e519248891ca84f8bf958016c2c416223d8ff" dependencies = [ "prost", "prost-types", "tonic", + "tonic-prost", ] [[package]] @@ -2688,16 +2706,6 @@ dependencies = [ "serde", ] -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.0" @@ -2902,7 +2910,7 @@ checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" dependencies = [ "atoi", "chrono", - "flume", + "flume 0.11.1", "futures-channel", "futures-core", "futures-executor", @@ -2992,6 +3000,7 @@ dependencies = [ "derive_builder", "elegant-departure", "figment", + "flume 0.12.0", "futures", "futures-util", "hex", @@ -3166,7 +3175,7 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2 0.6.0", + "socket2", "tokio-macros", "windows-sys 0.59.0", ] @@ -3236,9 +3245,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "axum", @@ -3253,8 +3262,8 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "prost", - "socket2 0.5.10", + "socket2", + "sync_wrapper", "tokio", "tokio-stream", "tower", @@ -3265,14 +3274,26 @@ dependencies = [ [[package]] name = "tonic-health" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb87334d340313fefa513b6e60794d44a86d5f039b523229c99c323e4e19ca4b" +checksum = "f4ff0636fef47afb3ec02818f5bceb4377b8abb9d6a386aeade18bd6212f8eb7" dependencies = [ "prost", "tokio", "tokio-stream", "tonic", + "tonic-prost", +] + +[[package]] +name = "tonic-prost" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +dependencies = [ + "bytes", + "prost", + "tonic", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 67119a08..a9758e93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ clap = { version = "4.5.20", features = ["derive"] } derive_builder = "0.20.2" elegant-departure = { version = "0.3.1", features = ["tokio"] } figment = { version = "0.10.19", features = ["env", "yaml", "test"] } +flume = "0.12.0" futures = "0.3.31" futures-util = "0.3.31" hex = "0.4.3" @@ -26,8 +27,8 @@ http-body-util = "0.1.2" libsqlite3-sys = "0.30.1" metrics = "0.24.0" metrics-exporter-statsd = "0.9.0" -prost = "0.13" -prost-types = "0.13.3" +prost = "0.14" +prost-types = "0.14" rand = "0.8.5" rdkafka = { version = "0.37.0", features = ["cmake-build", "ssl"] } sentry = { version = "0.41.0", default-features = false, features = [ @@ -41,7 +42,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = "0.4.11" +sentry_protos = "0.8.5" serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" @@ -49,8 +50,8 @@ sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "po tokio = { version = "1.43.1", features = ["full"] } tokio-stream = { version = "0.1.16", features = ["full"] } tokio-util = "0.7.12" -tonic = "0.13" -tonic-health = "0.13" +tonic = "0.14" +tonic-health = "0.14" tower = "0.5.1" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = [ diff --git a/README.md b/README.md index 860f446b..7a4f79bc 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ The test suite is composed of unit and integration tests in Rust, and end-to-end ```bash # Run unit/integration tests -make test +make unit-test # Run end-to-end tests make integration-test diff --git a/src/config.rs b/src/config.rs index 67b571d2..2831cae1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -239,6 +239,27 @@ pub struct Config { /// Enable additional metrics for the sqlite. pub enable_sqlite_status_metrics: bool, + + /// Run the taskbroker in push mode (as opposed to pull mode). + pub push_mode: bool, + + /// The number of concurrent dispatchers to run. + pub fetch_threads: usize, + + /// The number of concurrent pushers each dispatcher should run. + pub push_threads: usize, + + /// The size of the push queue. + pub push_queue_size: usize, + + /// The worker service endpoint. + pub worker_endpoint: String, + + /// The hostname used to construct `callback_url` for task push requests. + pub callback_addr: String, + + /// The port used to construct `callback_url` for task push requests. + pub callback_port: u32, } impl Default for Config { @@ -308,6 +329,13 @@ impl Default for Config { full_vacuum_on_upkeep: true, vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, + push_mode: false, + fetch_threads: 1, + push_threads: 1, + push_queue_size: 1, + worker_endpoint: "http://127.0.0.1:50052".into(), + callback_addr: "0.0.0.0".into(), + callback_port: 50051, } } } @@ -712,4 +740,48 @@ mod tests { Ok(()) }); } + + #[test] + fn test_default_push_callback_fields() { + let config = Config::default(); + assert_eq!(config.callback_addr, "0.0.0.0"); + assert_eq!(config.callback_port, 50051); + } + + #[test] + fn test_from_args_push_callback_fields_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_CALLBACK_ADDR", "127.0.0.1"); + jail.set_env("TASKBROKER_CALLBACK_PORT", "51000"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "127.0.0.1"); + assert_eq!(config.callback_port, 51000); + + Ok(()) + }); + } + + #[test] + fn test_from_args_push_callback_fields_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file( + "config.yaml", + r#" + callback_addr: 10.0.0.1 + callback_port: 52000 + "#, + )?; + + let args = Args { + config: Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "10.0.0.1"); + assert_eq!(config.callback_port, 52000); + + Ok(()) + }); + } } diff --git a/src/fetch.rs b/src/fetch.rs new file mode 100644 index 00000000..dbd98132 --- /dev/null +++ b/src/fetch.rs @@ -0,0 +1,403 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use tokio::time::sleep; +use tonic::async_trait; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::push::PushPool; +use crate::store::inflight_activation::InflightActivation; +use crate::store::inflight_activation::InflightActivationStore; + +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. +#[async_trait] +pub trait TaskPusher { + /// Push a single task to the worker service. + async fn push_task(&self, activation: InflightActivation) -> Result<()>; +} + +#[async_trait] +impl TaskPusher for PushPool { + async fn push_task(&self, activation: InflightActivation) -> Result<()> { + self.submit(activation).await + } +} + +/// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes is to the push pool, and repeats. +pub struct FetchPool { + /// Inflight activation store. + store: Arc, + + /// Pool of push threads that push activations to the worker service. + pusher: Arc, + + /// Taskbroker configuration. + config: Arc, +} + +impl FetchPool { + /// Initialize a new fetch pool. + pub fn new( + store: Arc, + config: Arc, + pusher: Arc, + ) -> Self { + Self { + store, + config, + pusher, + } + } + + /// Spawn `config.fetch_threads` asynchronous tasks, each of which repeatedly moves pending activations from the store to the push pool until the shutdown signal is received. + pub async fn start(&self) -> Result<()> { + let mut handles = vec![]; + + for _ in 0..self.config.fetch_threads.max(1) { + let guard = get_shutdown_guard().shutdown_on_drop(); + + let store = self.store.clone(); + let task_pusher = self.pusher.clone(); + + let handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = guard.wait() => { + info!("Fetch loop received shutdown signal"); + break; + } + + _ = async { + debug!("About to fetch next activation..."); + fetch_activations(store.clone(), task_pusher.clone()).await; + } => {} + } + } + }); + + handles.push(handle); + } + + for handle in handles { + if let Err(e) = handle.await { + return Err(e.into()); + } + } + + Ok(()) + } +} + +/// Grab the next pending activation from the store, mark it as processing, and send to push channel. +pub async fn fetch_activations( + store: Arc, + pusher: Arc, +) { + let start = Instant::now(); + metrics::counter!("fetch.fetch_activations.runs").increment(1); + + debug!("Fetching next pending activation..."); + + match store.get_pending_activation(None, None).await { + Ok(Some(activation)) => { + let id = activation.id.clone(); + debug!("Atomically fetched and marked task {id} as processing"); + + if let Err(e) = pusher.push_task(activation).await { + error!("Failed to submit task {id} to push pool - {:?}", e); + } + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + + Ok(_) => { + debug!("No pending activations, sleeping briefly..."); + sleep(Duration::from_millis(100)).await; + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + + Err(e) => { + error!("Failed to fetch pending activation - {:?}", e); + sleep(Duration::from_millis(100)).await; + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use anyhow::{Error, anyhow}; + use chrono::{DateTime, Utc}; + use tokio::sync::Mutex; + use tokio::time::{Duration, timeout}; + + use super::*; + use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivationStatus, QueryResult, + }; + use crate::test_utils::make_activations; + + #[allow(clippy::large_enum_variant)] + enum MockPendingResult { + Some(InflightActivation), + None, + Err, + } + + /// Fake store for testing. + struct MockStore { + /// How should all calls to `get_pending_activation` respond? + pending_result: MockPendingResult, + + /// How many calls to `get_pending_activation` have been performed? + pending_calls: AtomicUsize, + } + + impl MockStore { + fn new(pending_result: MockPendingResult) -> Self { + let pending_calls = AtomicUsize::new(0); + + Self { + pending_result, + pending_calls, + } + } + } + + #[async_trait] + impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result, Error> { + unimplemented!() + } + async fn store(&self, _batch: Vec) -> Result { + unimplemented!() + } + + async fn get_pending_activation( + &self, + _application: Option<&str>, + _namespace: Option<&str>, + ) -> Result, Error> { + self.pending_calls.fetch_add(1, Ordering::SeqCst); + match &self.pending_result { + MockPendingResult::Some(activation) => Ok(Some(activation.clone())), + MockPendingResult::None => Ok(None), + MockPendingResult::Err => Err(anyhow!("mock store error")), + } + } + + async fn get_pending_activations_from_namespaces( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + _limit: Option, + ) -> Result, Error> { + unimplemented!() + } + + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { + unimplemented!() + } + + async fn count(&self) -> Result { + unimplemented!() + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result, Error> { + unimplemented!() + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result, Error> { + unimplemented!() + } + + async fn clear(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result { + unimplemented!() + } + + async fn handle_expires_at(&self) -> Result { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec) -> Result { + unimplemented!() + } + + async fn remove_completed(&self) -> Result { + unimplemented!() + } + + async fn remove_killswitched( + &self, + _killswitched_tasks: Vec, + ) -> Result { + unimplemented!() + } + } + + /// Fake push pool for testing. + struct MockTaskPusher { + /// List of the IDs of all the activations that have been pushed. + pushed_ids: Mutex>, + + /// Should `push_task` fail? + should_fail: bool, + } + + impl MockTaskPusher { + fn new(should_fail: bool) -> Self { + let pushed_ids = Mutex::new(vec![]); + + Self { + pushed_ids, + should_fail, + } + } + } + + #[async_trait] + impl TaskPusher for MockTaskPusher { + async fn push_task(&self, activation: InflightActivation) -> Result<()> { + self.pushed_ids.lock().await.push(activation.id); + + if self.should_fail { + return Err(anyhow!("mock push error")); + } + + Ok(()) + } + } + + #[tokio::test] + async fn fetch_activations_submits_when_pending_exists() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation.clone()))); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1); + assert_eq!(pushed[0], activation.id); + } + + #[tokio::test] + async fn fetch_activations_logs_submit_error_but_does_not_fail() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation))); + let pusher = Arc::new(MockTaskPusher::new(true)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1, "should attempt one push even if it fails"); + } + + #[tokio::test] + async fn fetch_activations_no_pending_does_not_submit() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push if no activation is pending" + ); + } + + #[tokio::test] + async fn fetch_activations_store_error_does_not_submit() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Err)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push when pending activation lookup fails" + ); + } + + #[tokio::test] + async fn fetch_pool_start_spawns_at_least_one_worker_when_fetch_threads_zero() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + let config = Arc::new(Config { + fetch_threads: 0, + ..Config::default() + }); + + let pool = FetchPool::new(store, config, pusher); + + let result = timeout(Duration::from_millis(50), pool.start()).await; + assert!( + result.is_err(), + "start() should not complete immediately when fetch_threads is 0 because .max(1) starts one worker loop" + ); + } +} diff --git a/src/grpc/server.rs b/src/grpc/server.rs index f5ac9292..68d06699 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -9,11 +9,13 @@ use std::sync::Arc; use std::time::Instant; use tonic::{Request, Response, Status}; +use crate::config::Config; use crate::store::inflight_activation::{InflightActivationStatus, InflightActivationStore}; use tracing::{error, instrument}; pub struct TaskbrokerServer { pub store: Arc, + pub config: Arc, } #[tonic::async_trait] @@ -23,6 +25,12 @@ impl ConsumerService for TaskbrokerServer { &self, request: Request, ) -> Result, Status> { + if self.config.push_mode { + return Err(Status::permission_denied( + "Cannot call while broker is in PUSH mode", + )); + } + let start_time = Instant::now(); let application = &request.get_ref().application; let namespace = &request.get_ref().namespace; diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index b99911a2..29e36f96 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -7,7 +7,7 @@ use sentry_protos::taskbroker::v1::{ }; use tonic::{Code, Request}; -use crate::test_utils::{create_test_store, make_activations}; +use crate::test_utils::{create_config, create_test_store, make_activations}; #[tokio::test] #[rstest] @@ -15,7 +15,9 @@ use crate::test_utils::{create_test_store, make_activations}; #[case::postgres("postgres")] async fn test_get_task(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -34,7 +36,9 @@ async fn test_get_task(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 5, // Complete @@ -53,7 +57,9 @@ async fn test_set_task_status(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_invalid(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid @@ -76,10 +82,12 @@ async fn test_set_task_status_invalid(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(1); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -99,6 +107,8 @@ async fn test_get_task_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_with_application_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -108,7 +118,7 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: Some("hammers".into()), @@ -129,12 +139,14 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: Some(namespace), application: None, @@ -153,10 +165,12 @@ async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str #[allow(deprecated)] async fn test_set_task_status_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, @@ -192,6 +206,8 @@ async fn test_set_task_status_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -201,7 +217,7 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -229,6 +245,8 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -238,7 +256,7 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; // Request a task from an application without any activations. let request = SetTaskStatusRequest { id: "id_0".to_string(), @@ -261,12 +279,14 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete diff --git a/src/lib.rs b/src/lib.rs index 33567944..baf480d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,12 @@ use clap::Parser; use std::fs; pub mod config; +pub mod fetch; pub mod grpc; pub mod kafka; pub mod logging; pub mod metrics; +pub mod push; pub mod runtime_config; pub mod store; pub mod test_utils; diff --git a/src/main.rs b/src/main.rs index 7970939d..0a2ca923 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,9 +2,11 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use std::{sync::Arc, time::Duration}; +use taskbroker::fetch::FetchPool; use taskbroker::kafka::inflight_activation_batcher::{ ActivationBatcherConfig, InflightActivationBatcher, }; +use taskbroker::push::PushPool; use taskbroker::upkeep::upkeep; use tokio::signal::unix::SignalKind; use tokio::task::JoinHandle; @@ -39,16 +41,16 @@ use taskbroker::store::postgres_activation_store::{ use taskbroker::{Args, get_version}; use tonic_health::ServingStatus; -async fn log_task_completion(name: &str, task: JoinHandle>) { +async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { Ok(Ok(())) => { - info!("Task {} completed", name); + info!("Task {} completed", name.as_ref()); } Ok(Err(e)) => { - error!("Task {} failed: {:?}", name, e); + error!("Task {} failed: {:?}", name.as_ref(), e); } Err(e) => { - error!("Task {} panicked: {:?}", name, e); + error!("Task {} panicked: {:?}", name.as_ref(), e); } } } @@ -190,22 +192,24 @@ async fn main() -> Result<(), Error> { // GRPC server let grpc_server_task = tokio::spawn({ - let grpc_store = store.clone(); - let grpc_config = config.clone(); + let store = store.clone(); + let config = config.clone(); + async move { - let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) + let addr = format!("0.0.0.0:{}", config.grpc_port) .parse() .expect("Failed to parse address"); let layers = tower::ServiceBuilder::new() .layer(MetricsLayer::default()) - .layer(AuthLayer::new(&grpc_config)) + .layer(AuthLayer::new(&config)) .into_inner(); let server = Server::builder() .layer(layers) .add_service(ConsumerServiceServer::new(TaskbrokerServer { - store: grpc_store, + store, + config, })) .add_service(health_service.clone()) .serve(addr); @@ -236,7 +240,25 @@ async fn main() -> Result<(), Error> { } }); - elegant_departure::tokio::depart() + // Initialize push and fetch pools + let push_pool = Arc::new(PushPool::new(config.clone())); + let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); + + // Initialize push threads + let push_task = if config.push_mode { + Some(tokio::spawn(async move { push_pool.start().await })) + } else { + None + }; + + // Initialize fetch threads + let fetch_task = if config.push_mode { + Some(tokio::spawn(async move { fetch_pool.start().await })) + } else { + None + }; + + let mut departure = elegant_departure::tokio::depart() .on_termination() .on_sigint() .on_signal(SignalKind::hangup()) @@ -244,8 +266,16 @@ async fn main() -> Result<(), Error> { .on_completion(log_task_completion("consumer", consumer_task)) .on_completion(log_task_completion("grpc_server", grpc_server_task)) .on_completion(log_task_completion("upkeep_task", upkeep_task)) - .on_completion(log_task_completion("maintenance_task", maintenance_task)) - .await; + .on_completion(log_task_completion("maintenance_task", maintenance_task)); + + if let Some(task) = push_task { + departure = departure.on_completion(log_task_completion("push_task", task)); + } + + if let Some(task) = fetch_task { + departure = departure.on_completion(log_task_completion("fetch_task", task)); + } + departure.await; Ok(()) } diff --git a/src/push.rs b/src/push.rs new file mode 100644 index 00000000..3156ec78 --- /dev/null +++ b/src/push.rs @@ -0,0 +1,289 @@ +use std::sync::Arc; +use std::time::Instant; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use flume::{Receiver, Sender}; +use prost::Message; +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; +use tonic::async_trait; +use tonic::transport::Channel; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::store::inflight_activation::InflightActivation; + +/// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. +#[async_trait] +trait WorkerClient { + /// Send a single `PushTaskRequest` to the worker service. + async fn send(&mut self, request: PushTaskRequest) -> Result<()>; +} + +#[async_trait] +impl WorkerClient for WorkerServiceClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.push_task(request).await?; + Ok(()) + } +} + +/// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. +pub struct PushPool { + /// The sending end of a channel that accepts task activations. + sender: Sender, + + /// The receiving end of a channel that accepts task activations. + receiver: Receiver, + + /// Taskbroker configuration. + config: Arc, +} + +impl PushPool { + /// Initialize a new push pool. + pub fn new(config: Arc) -> Self { + let (sender, receiver) = flume::bounded(config.push_queue_size); + + Self { + sender, + receiver, + config, + } + } + + /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. + pub async fn start(&self) -> Result<()> { + let mut handles = vec![]; + + for _ in 0..self.config.push_threads { + let endpoint = self.config.worker_endpoint.clone(); + + let callback_url = format!( + "{}:{}", + self.config.callback_addr, self.config.callback_port + ); + + let receiver = self.receiver.clone(); + let guard = get_shutdown_guard().shutdown_on_drop(); + + let handle = tokio::spawn(async move { + let mut worker = match WorkerServiceClient::connect(endpoint).await { + Ok(w) => w, + Err(e) => { + error!("Failed to connect to worker - {:?}", e); + return; + } + }; + + loop { + tokio::select! { + _ = guard.wait() => { + info!("Push worker received shutdown signal"); + break; + } + + message = receiver.recv_async() => { + let activation = match message { + // Received activation from fetch thread + Ok(a) => a, + + // Channel closed + Err(_) => break + }; + + let id = activation.id.clone(); + + match push_task(&mut worker, activation, callback_url.clone()).await { + Ok(_) => debug!("Activation {id} was sent to worker!"), + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e) + }; + } + } + } + + // Drain channel before exiting + for activation in receiver.drain() { + let id = activation.id.clone(); + + match push_task(&mut worker, activation, callback_url.clone()).await { + Ok(_) => debug!("Activation {id} was sent to worker!"), + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e), + }; + } + }); + + handles.push(handle); + } + + for handle in handles { + if let Err(e) = handle.await { + return Err(e.into()); + } + } + + Ok(()) + } + + /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. + pub async fn submit(&self, activation: InflightActivation) -> Result<()> { + Ok(self.sender.send_async(activation).await?) + } +} + +/// Decode task activation and push it to a worker. +async fn push_task( + worker: &mut W, + activation: InflightActivation, + callback_url: String, +) -> Result<()> { + let start = Instant::now(); + let id = activation.id.clone(); + + // Try to decode activation (if it fails, we will see the error where `push_task` is called) + let task = TaskActivation::decode(&activation.activation as &[u8])?; + + let request = PushTaskRequest { + task: Some(task), + callback_url, + }; + + let result = match worker.send(request).await { + Ok(_) => { + debug!("Successfully sent activation {id} to worker service!"); + Ok(()) + } + + Err(e) => { + error!("Could not push activation {id} to worker service - {:?}", e); + Err(e) + } + }; + + metrics::histogram!("push.push_task.duration").record(start.elapsed()); + result +} + +#[cfg(test)] +mod tests { + use anyhow::anyhow; + use std::sync::Arc; + use tokio::time::{Duration, timeout}; + + use super::*; + use crate::test_utils::make_activations; + + /// Fake worker client for unit testing. + struct MockWorkerClient { + /// Capture all received requests so we can assert things about them. + captured_requests: Vec, + + /// Should requests to the worker client fail? + should_fail: bool, + } + + impl MockWorkerClient { + fn new(should_fail: bool) -> Self { + let captured_requests = vec![]; + + Self { + captured_requests, + should_fail, + } + } + } + + #[async_trait] + impl WorkerClient for MockWorkerClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.captured_requests.push(request); + + if self.should_fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } + } + + #[tokio::test] + async fn push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + let callback_url = "taskbroker:50051".to_string(); + + let result = push_task(&mut worker, activation.clone(), callback_url.clone()).await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.captured_requests.len(), 1); + + let request = &worker.captured_requests[0]; + assert_eq!(request.callback_url, callback_url); + assert_eq!( + request.task.as_ref().map(|task| task.id.as_str()), + Some(activation.id.as_str()) + ); + } + + #[tokio::test] + async fn push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.captured_requests.is_empty(), + "worker should not be called if decode fails" + ); + } + + #[tokio::test] + async fn push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + assert!(result.is_err(), "worker send errors should propagate"); + assert_eq!(worker.captured_requests.len(), 1); + } + + #[tokio::test] + async fn push_pool_submit_enqueues_item() { + let config = Arc::new(Config { + push_queue_size: 2, + ..Config::default() + }); + + let pool = PushPool::new(config); + let activation = make_activations(1).remove(0); + + let result = pool.submit(activation).await; + assert!(result.is_ok(), "submit should enqueue activation"); + } + + #[tokio::test] + async fn push_pool_submit_backpressures_when_queue_full() { + let config = Arc::new(Config { + push_queue_size: 1, + ..Config::default() + }); + + let pool = PushPool::new(config); + + let first = make_activations(1).remove(0); + let second = make_activations(1).remove(0); + + pool.submit(first) + .await + .expect("first submit should fill queue"); + + let second_submit = timeout(Duration::from_millis(50), pool.submit(second)).await; + assert!( + second_submit.is_err(), + "second submit should block when queue is full" + ); + } +}